├── pycave
│   ├── py.typed
│   ├── clustering
│   │   ├── __init__.py
│   │   └── kmeans
│   │       ├── __init__.py
│   │       ├── types.py
│   │       ├── model.py
│   │       ├── metrics.py
│   │       ├── estimator.py
│   │       └── lightning_module.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── lightning_module.py
│   ├── bayes
│   │   ├── __init__.py
│   │   ├── markov_chain
│   │   │   ├── __init__.py
│   │   │   ├── types.py
│   │   │   ├── lightning_module.py
│   │   │   ├── metrics.py
│   │   │   ├── model.py
│   │   │   └── estimator.py
│   │   ├── gmm
│   │   │   ├── __init__.py
│   │   │   ├── types.py
│   │   │   ├── model.py
│   │   │   ├── metrics.py
│   │   │   ├── estimator.py
│   │   │   └── lightning_module.py
│   │   └── core
│   │       ├── __init__.py
│   │       ├── types.py
│   │       ├── utils.py
│   │       ├── _jit.py
│   │       └── normal.py
│   └── __init__.py
├── tests
│   ├── __init__.py
│   ├── _data
│   │   ├── __init__.py
│   │   ├── gmm.py
│   │   └── normal.py
│   ├── bayes
│   │   ├── gmm
│   │   │   ├── test_gmm_model.py
│   │   │   ├── test_gmm_estimator.py
│   │   │   ├── test_gmm_metrics.py
│   │   │   └── benchmark_gmm_estimator.py
│   │   ├── core
│   │   │   ├── benchmark_precision_cholesky.py
│   │   │   ├── benchmark_log_normal.py
│   │   │   └── test_normal.py
│   │   └── markov_chain
│   │       ├── test_markov_chain_model.py
│   │       └── test_markov_chain_estimator.py
│   └── clustering
│       └── kmeans
│           ├── test_kmeans_model.py
│           ├── test_kmeans_estimator.py
│           └── benchmark_kmeans_estimator.py
├── .github
│   ├── CODEOWNERS
│   ├── dependabot.yml
│   └── workflows
│       ├── deploy.yml
│       ├── ci.yml
│       └── docs.yml
├── .prettierignore
├── docs
│   ├── _static
│   │   ├── favicon.ico
│   │   └── logo.svg
│   ├── _templates
│   │   ├── autosummary
│   │   │   ├── method.rst
│   │   │   └── class.rst
│   │   └── classes
│   │       ├── type_alias.rst
│   │       └── pytorch_module.rst
│   ├── spelling_wordlist.txt
│   ├── sites
│   │   ├── api.rst
│   │   └── benchmark.rst
│   ├── conf.py
│   └── index.rst
├── Makefile
├── .prettierrc
├── .pre-commit-config.yaml
├── LICENSE
├── pyproject.toml
├── README.md
└── .gitignore
/pycave/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/_data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @borchero
2 |
--------------------------------------------------------------------------------
/.prettierignore:
--------------------------------------------------------------------------------
1 | .pytest_cache/
2 | build/
3 |
--------------------------------------------------------------------------------
/docs/_static/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/borchero/pycave/HEAD/docs/_static/favicon.ico
--------------------------------------------------------------------------------
/pycave/clustering/__init__.py:
--------------------------------------------------------------------------------
1 | from .kmeans import KMeans
2 |
3 | __all__ = [
4 | "KMeans",
5 | ]
6 |
--------------------------------------------------------------------------------
/pycave/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .lightning_module import NonparametricLightningModule
2 |
3 | __all__ = ["NonparametricLightningModule"]
4 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: docs
2 |
3 | docs:
4 | rm -rf build
5 | rm -rf docs/generated
6 | rm -rf docs/sites/generated
7 | sphinx-build -W -b html docs build
8 |
--------------------------------------------------------------------------------
/.prettierrc:
--------------------------------------------------------------------------------
1 | {
2 | "bracketSpacing": true,
3 | "endOfLine": "auto",
4 | "printWidth": 99,
5 | "proseWrap": "always",
6 | "singleQuote": false,
7 | "tabWidth": 2
8 | }
9 |
--------------------------------------------------------------------------------
/pycave/bayes/__init__.py:
--------------------------------------------------------------------------------
1 | from .gmm import GaussianMixture
2 | from .markov_chain import MarkovChain
3 |
4 | __all__ = [
5 | "GaussianMixture",
6 | "MarkovChain",
7 | ]
8 |
--------------------------------------------------------------------------------
/docs/_templates/autosummary/method.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 |
3 | {{ (class + "." + name) | underline }}
4 |
5 | .. currentmodule:: {{ module }}
6 |
7 | .. automethod:: {{ class }}.{{ name }}
8 |
--------------------------------------------------------------------------------
/docs/_templates/classes/type_alias.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. role:: hidden
4 |
5 | {{ name | underline }}
6 |
7 | .. currentmodule:: {{ module }}
8 |
9 | .. autodata:: {{ name }}
10 |
--------------------------------------------------------------------------------
/pycave/clustering/kmeans/__init__.py:
--------------------------------------------------------------------------------
1 | from .estimator import KMeans
2 | from .model import KMeansModel, KMeansModelConfig
3 |
4 | __all__ = [
5 | "KMeans",
6 | "KMeansModel",
7 | "KMeansModelConfig",
8 | ]
9 |
--------------------------------------------------------------------------------
/pycave/bayes/markov_chain/__init__.py:
--------------------------------------------------------------------------------
1 | from .estimator import MarkovChain
2 | from .model import MarkovChainModel, MarkovChainModelConfig
3 |
4 | __all__ = [
5 | "MarkovChain",
6 | "MarkovChainModel",
7 | "MarkovChainModelConfig",
8 | ]
9 |
--------------------------------------------------------------------------------
/pycave/bayes/gmm/__init__.py:
--------------------------------------------------------------------------------
1 | from .estimator import GaussianMixture
2 | from .model import GaussianMixtureModel, GaussianMixtureModelConfig
3 |
4 | __all__ = [
5 | "GaussianMixture",
6 | "GaussianMixtureModel",
7 | "GaussianMixtureModelConfig",
8 | ]
9 |
--------------------------------------------------------------------------------
/tests/bayes/gmm/test_gmm_model.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | from torch import jit
3 | from pycave.bayes.gmm import GaussianMixtureModel, GaussianMixtureModelConfig
4 |
5 |
6 | def test_compile():
7 | config = GaussianMixtureModelConfig(num_components=2, num_features=3, covariance_type="full")
8 | model = GaussianMixtureModel(config)
9 | jit.script(model)
10 |
--------------------------------------------------------------------------------
/pycave/bayes/core/__init__.py:
--------------------------------------------------------------------------------
1 | from .normal import cholesky_precision, covariance, log_normal, sample_normal
2 | from .types import CovarianceType
3 | from .utils import covariance_dim, covariance_shape
4 |
5 | __all__ = [
6 | "cholesky_precision",
7 | "log_normal",
8 | "sample_normal",
9 | "covariance",
10 | "CovarianceType",
11 | "covariance_dim",
12 | "covariance_shape",
13 | ]
14 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | # Update GitHub actions
4 | - directory: /
5 | open-pull-requests-limit: 5
6 | package-ecosystem: github-actions
7 | schedule:
8 | interval: weekly
9 | day: saturday
10 | # Update Python dependencies
11 | - directory: /
12 | open-pull-requests-limit: 5
13 | package-ecosystem: pip
14 | schedule:
15 | interval: weekly
16 | day: saturday
17 |
--------------------------------------------------------------------------------
/docs/_static/logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/spelling_wordlist.txt:
--------------------------------------------------------------------------------
1 | abc
2 | bayes
3 | boolean
4 | centroid
5 | Centroids
6 | centroids
7 | Checkpointing
8 | Cholesky
9 | contravariant
10 | covariances
11 | dataclass
12 | dataloader
13 | dataloaders
14 | datapoint
15 | datapoints
16 | dataset
17 | datasets
18 | diag
19 | dimensionality
20 | dtype
21 | Elkan
22 | Frobenius
23 | Gaussians
24 | GiB
25 | gmm
26 | init
27 | initializer
28 | iteratively
29 | KMeans
30 | kmeans
31 | learnt
32 | logits
33 | markov
34 | Mixin
35 | nn
36 | overridable
37 | parallelize
38 | params
39 | precisions
40 | pycave
41 | runtime
42 | scikit
43 | SequenceData
44 | Subclasses
45 | subclasses
46 | stdout
47 | TabularData
48 | tbd
49 | Utils
50 | Xeon
51 |
--------------------------------------------------------------------------------
/.github/workflows/deploy.yml:
--------------------------------------------------------------------------------
1 | name: Deploy Package
2 | on:
3 | release:
4 | types: [published]
5 |
6 | jobs:
7 | build:
8 | name: Publish
9 | runs-on: ubuntu-latest
10 | container:
11 | image: python:3.8-buster
12 | steps:
13 | - name: Checkout
14 | uses: actions/checkout@v3
15 | - name: Install poetry
16 | uses: snok/install-poetry@v1
17 | - name: Tag
18 | run: poetry version ${{ github.event.release.tag_name }}
19 | - name: Build Wheel
20 | run: poetry build
21 | - name: Publish to PyPi
22 | run: poetry publish --username $PYPI_USERNAME --password $PYPI_PASSWORD
23 | env:
24 | PYPI_USERNAME: ${{ secrets.PYPI_USERNAME }}
25 | PYPI_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
26 |
--------------------------------------------------------------------------------
/tests/_data/gmm.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | from typing import Tuple
3 | import torch
4 | from pycave.bayes.core import CovarianceType
5 | from pycave.bayes.gmm import GaussianMixtureModel, GaussianMixtureModelConfig
6 |
7 |
8 | def sample_gmm(
9 | num_datapoints: int, num_features: int, num_components: int, covariance_type: CovarianceType
10 | ) -> Tuple[torch.Tensor, torch.Tensor]:
11 | config = GaussianMixtureModelConfig(num_components, num_features, covariance_type)
12 | model = GaussianMixtureModel(config)
13 |
14 | # Means can simply be scaled and shifted to spread the components apart
15 | model.means.mul_(torch.rand(num_components).unsqueeze(-1) * 10).add_(
16 | torch.rand(num_components).unsqueeze(-1) * 10
17 | )
18 |
19 | return model.sample(num_datapoints), model.means
20 |
--------------------------------------------------------------------------------
/pycave/clustering/kmeans/types.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Literal
3 |
4 | KMeansInitStrategy = Literal["random", "kmeans++"]
5 | KMeansInitStrategy.__doc__ = """
6 | Strategy for initializing KMeans centroids.
7 |
8 | - **random**: Centroids are sampled randomly from the data. This has complexity ``O(n)`` for ``n``
9 | datapoints.
10 | - **kmeans++**: Centroids are computed iteratively. The first centroid is sampled randomly from
11 | the data. Subsequently, centroids are sampled from the remaining datapoints with probability
12 | proportional to ``D(x)^2`` where ``D(x)`` is the distance of datapoint ``x`` to the closest
13 | centroid chosen so far. This has complexity ``O(kn)`` for ``k`` clusters and ``n`` datapoints.
14 | If done on mini-batches, the complexity increases to ``O(k^2 n)``.
15 | """
16 |
--------------------------------------------------------------------------------
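
A standalone sketch (not part of the repository) of the ``kmeans++`` sampling step described above, where the next centroid is drawn with probability proportional to ``D(x)^2``:

    import torch

    def kmeanspp_step(data: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor:
        # D(x): distance of every datapoint to its closest centroid chosen so far
        distances = torch.cdist(data, centroids).min(dim=1).values
        # Sample the next centroid with probability proportional to D(x)^2
        probabilities = distances.square() / distances.square().sum()
        choice = torch.multinomial(probabilities, num_samples=1)
        return data[choice]

    data = torch.randn(1000, 3)
    centroids = data[torch.randint(len(data), (1,))]  # the first centroid is sampled uniformly
    centroids = torch.cat([centroids, kmeanspp_step(data, centroids)])
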
/pycave/bayes/gmm/types.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Literal
3 |
4 | GaussianMixtureInitStrategy = Literal["random", "kmeans", "kmeans++"]
5 | GaussianMixtureInitStrategy.__doc__ = """
6 | Strategy for initializing the parameters of a Gaussian mixture model.
7 |
8 | - **random**: Samples responsibilities of datapoints at random and subsequently initializes means
9 | and covariances from these.
10 | - **kmeans**: Runs K-Means via :class:`pycave.clustering.KMeans` and uses the centroids as the
11 | initial component means. For computing the covariances, responsibilities are given as the
12 | one-hot cluster assignments.
13 | - **kmeans++**: Runs only the K-Means++ initialization procedure to sample means in a smart
14 | fashion. Might be more efficient than ``kmeans`` as it does not actually run clustering. For
15 | many clusters, this is, however, still slow.
16 | """
17 |
--------------------------------------------------------------------------------
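
As a usage sketch for the strategies above (assuming the ``GaussianMixture`` estimator exposes an ``init_strategy`` argument matching this type alias; the data is synthetic):

    import torch
    from pycave.bayes import GaussianMixture

    data = torch.randn(1000, 4)
    # "kmeans" runs a full K-Means fit for the initial means; "kmeans++" only runs the seeding step
    estimator = GaussianMixture(num_components=3, covariance_type="diag", init_strategy="kmeans++")
    estimator.fit(data)
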
/pycave/bayes/core/types.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Literal
3 |
4 | CovarianceType = Literal["full", "tied", "diag", "spherical"]
5 | CovarianceType.__doc__ = """
6 | The type of covariance to use for a set of multivariate Normal distributions.
7 |
8 | - **full**: Each distribution has a full covariance matrix. Covariance matrix is a tensor of shape
9 | ``[num_components, num_features, num_features]``.
10 | - **tied**: All distributions share the same full covariance matrix. Covariance matrix is a tensor
11 | of shape ``[num_features, num_features]``.
12 | - **diag**: Each distribution has a diagonal covariance matrix. Covariance matrix is a tensor of
13 | shape ``[num_components, num_features]``.
14 | - **spherical**: Each distribution has a diagonal covariance matrix which is a multiple of the
15 | identity matrix. Covariance matrix is a tensor of shape ``[num_components]``.
16 | """
17 |
--------------------------------------------------------------------------------
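
A small illustration of the tensor shapes documented above for ``num_components=4`` and ``num_features=3``:

    import torch

    num_components, num_features = 4, 3
    covariances = {
        "full": torch.empty(num_components, num_features, num_features),
        "tied": torch.empty(num_features, num_features),
        "diag": torch.empty(num_components, num_features),
        "spherical": torch.empty(num_components),
    }
    for covariance_type, tensor in covariances.items():
        print(covariance_type, tuple(tensor.shape))
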
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | ci:
2 | autofix_prs: false
3 | skip: [pylint]
4 |
5 | repos:
6 | - repo: https://github.com/psf/black
7 | rev: 22.12.0
8 | hooks:
9 | - id: black
10 | - repo: https://github.com/PyCQA/pylint
11 | rev: v2.15.9
12 | hooks:
13 | - id: pylint
14 | language: system
15 | types: [python]
16 | args: [-rn, -sn]
17 | - repo: https://github.com/pre-commit/mirrors-mypy
18 | rev: v0.991
19 | hooks:
20 | - id: mypy
21 | - repo: https://github.com/PyCQA/isort
22 | rev: v5.11.3
23 | hooks:
24 | - id: isort
25 | - repo: https://github.com/PyCQA/docformatter
26 | rev: v1.5.1
27 | hooks:
28 | - id: docformatter
29 | additional_dependencies: [tomli]
30 | - repo: https://github.com/asottile/pyupgrade
31 | rev: v3.3.1
32 | hooks:
33 | - id: pyupgrade
34 | args: [--py38-plus]
35 | - repo: https://github.com/pre-commit/mirrors-prettier
36 | rev: v3.0.0-alpha.4
37 | hooks:
38 | - id: prettier
39 |
--------------------------------------------------------------------------------
/tests/_data/normal.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | from typing import List
3 | import torch
4 |
5 |
6 | def sample_data(counts: List[int], dims: List[int]) -> List[torch.Tensor]:
7 | return [torch.randn(count, dim) for count, dim in zip(counts, dims)]
8 |
9 |
10 | def sample_means(counts: List[int], dims: List[int]) -> List[torch.Tensor]:
11 | return [torch.randn(count, dim) for count, dim in zip(counts, dims)]
12 |
13 |
14 | def sample_spherical_covars(counts: List[int]) -> List[torch.Tensor]:
15 | return [torch.rand(count) for count in counts]
16 |
17 |
18 | def sample_diag_covars(counts: List[int], dims: List[int]) -> List[torch.Tensor]:
19 | return [torch.rand(count, dim).squeeze() for count, dim in zip(counts, dims)]
20 |
21 |
22 | def sample_full_covars(counts: List[int], dims: List[int]) -> List[torch.Tensor]:
23 | result = []
24 | for count, dim in zip(counts, dims):
25 | A = torch.rand(count, dim * 10, dim)
26 | result.append(A.permute(0, 2, 1).bmm(A).squeeze())
27 | return result
28 |
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 |
3 | {{ name | underline }}
4 |
5 | .. currentmodule:: {{ module }}
6 |
7 | .. autoclass:: {{ name }}
8 | :show-inheritance:
9 |
10 | {% if methods %}
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 | :toctree:
15 | :nosignatures:
16 |
17 | {% for item in methods %}
18 | {%- if not item in inherited_members %}
19 | {%- if not item.startswith("_") %}
20 | ~{{ name }}.{{ item }}
21 | {%- endif %}
22 | {%- endif %}
23 | {%- endfor %}
24 |
25 | .. rubric:: Inherited Methods
26 |
27 | .. autosummary::
28 | :toctree:
29 | :nosignatures:
30 |
31 | {% for item in methods %}
32 | {%- if item in inherited_members %}
33 | {%- if not item.startswith("_") %}
34 | ~{{ name }}.{{ item }}
35 | {%- endif %}
36 | {%- endif %}
37 | {%- endfor %}
38 | {%- endif %}
39 |
40 | {% if attributes %}
41 | .. rubric:: Attributes
42 |
43 | .. autosummary::
44 | {% for item in attributes %}
45 | ~{{ name }}.{{ item }}
46 | {%- endfor %}
47 | {%- endif %}
48 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020 Oliver Borchert
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 |
--------------------------------------------------------------------------------
/tests/clustering/kmeans/test_kmeans_model.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | import torch
3 | from torch import jit
4 | from pycave.clustering.kmeans import KMeansModel, KMeansModelConfig
5 |
6 |
7 | def test_compile():
8 | config = KMeansModelConfig(num_clusters=2, num_features=5)
9 | model = KMeansModel(config)
10 | jit.script(model)
11 |
12 |
13 | def test_forward():
14 | config = KMeansModelConfig(num_clusters=2, num_features=2)
15 | model = KMeansModel(config)
16 | model.centroids.copy_(torch.as_tensor([[0.0, 0.0], [2.0, 2.0]]))
17 |
18 | X = torch.as_tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [-1.0, 4.0]])
19 | distances, assignments, inertias = model.forward(X)
20 |
21 | expected_distances = torch.as_tensor([[0.0, 8.0], [2.0, 2.0], [8.0, 0.0], [17.0, 13.0]]).sqrt()
22 | expected_assignments = torch.as_tensor([0, 0, 1, 1])
23 | expected_inertias = torch.as_tensor([0.0, 2.0, 0.0, 13.0])
24 |
25 | assert torch.allclose(distances, expected_distances)
26 | assert torch.all(assignments == expected_assignments)
27 | assert torch.allclose(inertias, expected_inertias)
28 |
--------------------------------------------------------------------------------
/docs/_templates/classes/pytorch_module.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. role:: hidden
4 |
5 | {{ name | underline }}
6 |
7 | .. currentmodule:: {{ module }}
8 |
9 | .. autoclass:: {{ name }}
10 | :show-inheritance:
11 |
12 | {% if methods and not name.endswith("Config") %}
13 | .. rubric:: Methods
14 |
15 | .. autosummary::
16 | :toctree:
17 | :nosignatures:
18 |
19 | {% for item in methods %}
20 | {%- if not item in inherited_members %}
21 | {%- if not item.startswith("_") %}
22 | ~{{ name }}.{{ item }}
23 | {%- endif %}
24 | {%- endif %}
25 | {%- endfor %}
26 | {%- endif %}
27 |
28 | {% if methods and not name.endswith("Config") %}
29 | .. rubric:: Inherited Methods
30 |
31 | .. autosummary::
32 | :toctree:
33 | :nosignatures:
34 |
35 | {% for item in methods %}
36 | {%- if item in ["load", "save"] %}
37 | ~{{ name }}.{{ item }}
38 | {%- endif %}
39 | {%- endfor %}
40 | {%- endif %}
41 |
42 | {% if attributes %}
43 | .. rubric:: Attributes
44 |
45 | .. autosummary::
46 | {% for item in attributes %}
47 | {%- if not item in ["training", "T_destination", "dump_patches"] %}
48 | ~{{ name }}.{{ item }}
49 | {%- endif %}
50 | {%- endfor %}
51 | {%- endif %}
52 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 | on:
3 | push:
4 | branches: [main]
5 | pull_request:
6 |
7 | jobs:
8 | pylint:
9 | name: Pylint Checks
10 | runs-on: ubuntu-latest
11 | steps:
12 | - name: Checkout
13 | uses: actions/checkout@v3
14 | - name: Setup Python
15 | uses: actions/setup-python@v4
16 | with:
17 | python-version: "3.10"
18 | - name: Install poetry
19 | uses: snok/install-poetry@v1
20 | - name: Install project
21 | run: poetry install --only main,pre-commit
22 | - name: Run pylint
23 | run: poetry run pylint **/*.py
24 |
25 | unit-tests:
26 | name: Unit Tests - Python ${{ matrix.python-version }}
27 | runs-on: ubuntu-latest
28 | strategy:
29 | matrix:
30 | python-version: ["3.8", "3.9", "3.10"]
31 | steps:
32 | - name: Checkout
33 | uses: actions/checkout@v3
34 | - name: Setup Python
35 | uses: actions/setup-python@v4
36 | with:
37 | python-version: ${{ matrix.python-version }}
38 | - name: Install poetry
39 | uses: snok/install-poetry@v1
40 | - name: Install project
41 | run: poetry install --only main,testing
42 | - name: Run Pytest
43 | run: poetry run pytest tests
44 |
--------------------------------------------------------------------------------
/pycave/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import warnings
3 | import lightkit
4 |
5 | # This is taken from PyTorch Lightning and ensures that logging for this package is enabled
6 | _root_logger = logging.getLogger()
7 | _logger = logging.getLogger(__name__)
8 | _logger.setLevel(logging.INFO)
9 | if not _root_logger.hasHandlers():
10 | _logger.addHandler(logging.StreamHandler())
11 | _logger.propagate = False
12 |
13 | # This disables most logs generated by PyTorch Lightning
14 | logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)
15 | warnings.filterwarnings(
16 | action="ignore", message=".*Consider increasing the value of the `num_workers` argument.*"
17 | )
18 | warnings.filterwarnings(
19 | action="ignore", message=".*`LightningModule.configure_optimizers` returned `None`.*"
20 | )
21 | warnings.filterwarnings(
22 | action="ignore", message=".*`LoggerConnector.gpus_metrics` was deprecated in v1.5.*"
23 | )
24 |
25 | # We also want to define a function which silences info logs
26 | def set_logging_level(level: int) -> None:
27 | """
28 | Sets the logging level for the entire module. By default, the level is set to ``INFO``.
29 |
30 | Args:
31 | level: The minimum severity of log messages to emit, e.g. ``logging.WARNING`` to silence info logs.
32 | """
33 | _logger.setLevel(level)
34 | lightkit.set_logging_level(level)
35 |
36 |
37 | # Export
38 | __all__ = ["set_logging_level"]
39 |
--------------------------------------------------------------------------------
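
A brief usage sketch for the exported function above:

    import logging
    import pycave

    # Silence info logs emitted while fitting; pass logging.INFO to turn them back on
    pycave.set_logging_level(logging.WARNING)
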
/docs/sites/api.rst:
--------------------------------------------------------------------------------
1 | API Reference
2 | =============
3 |
4 | Bayesian Models
5 | ---------------
6 |
7 | .. currentmodule:: pycave.bayes
8 |
9 | Gaussian Mixture
10 | ^^^^^^^^^^^^^^^^
11 |
12 | .. autosummary::
13 | :toctree: generated/bayes/gmm
14 | :nosignatures:
15 | :caption: Bayesian Models
16 |
17 | GaussianMixture
18 |
19 | :template: classes/pytorch_module.rst
20 |
21 | ~gmm.GaussianMixtureModel
22 | ~gmm.GaussianMixtureModelConfig
23 |
24 |
25 | Markov Chain
26 | ^^^^^^^^^^^^
27 |
28 | .. autosummary::
29 | :toctree: generated/bayes/markov_chain
30 | :nosignatures:
31 |
32 | MarkovChain
33 |
34 | :template: classes/pytorch_module.rst
35 |
36 | ~markov_chain.MarkovChainModel
37 | ~markov_chain.MarkovChainModelConfig
38 |
39 |
40 | Clustering Models
41 | -----------------
42 |
43 | .. currentmodule:: pycave.clustering
44 |
45 | K-Means
46 | ^^^^^^^
47 |
48 | .. autosummary::
49 | :toctree: generated/clustering/kmeans
50 | :nosignatures:
51 | :caption: Clustering Models
52 |
53 | KMeans
54 |
55 | :template: classes/pytorch_module.rst
56 |
57 | ~kmeans.KMeansModel
58 | ~kmeans.KMeansModelConfig
59 |
60 |
61 | Utility Types
62 | -------------
63 |
64 | .. currentmodule:: pycave
65 | .. autosummary::
66 | :toctree: generated/types
67 | :nosignatures:
68 | :caption: Types
69 | :template: classes/type_alias.rst
70 |
71 | ~bayes.markov_chain.types.SequenceData
72 | ~bayes.core.CovarianceType
73 | ~bayes.gmm.types.GaussianMixtureInitStrategy
74 | ~clustering.kmeans.types.KMeansInitStrategy
75 |
--------------------------------------------------------------------------------
/pycave/bayes/core/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .types import CovarianceType
3 |
4 |
5 | def covariance_dim(covariance_type: CovarianceType) -> int:
6 | """
7 | Returns the number of dimensions of the covariance matrix for a set of components.
8 |
9 | Args:
10 | covariance_type: The type of covariance to obtain the dimension for.
11 |
12 | Returns:
13 | The number of dimensions.
14 | """
15 | if covariance_type == "full":
16 | return 3
17 | if covariance_type in ("tied", "diag"):
18 | return 2
19 | return 1
20 |
21 |
22 | def covariance_shape(
23 | num_components: int, num_features: int, covariance_type: CovarianceType
24 | ) -> torch.Size:
25 | """
26 | Returns the expected shape of the covariance tensor for the given number of components and
27 | number of features, based on the covariance type.
28 |
29 | Args:
30 | num_components: The number of Normal distributions to describe with the covariance.
31 | num_features: The dimensionality of the Normal distributions.
32 | covariance_type: The type of covariance to use.
33 |
34 | Returns:
35 | The expected size of the tensor representing the covariances.
36 | """
37 | if covariance_type == "full":
38 | return torch.Size([num_components, num_features, num_features])
39 | if covariance_type == "tied":
40 | return torch.Size([num_features, num_features])
41 | if covariance_type == "diag":
42 | return torch.Size([num_components, num_features])
43 | # covariance_type == "spherical"
44 | return torch.Size([num_components])
45 |
--------------------------------------------------------------------------------
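
A quick sketch exercising the helpers above; the expected values follow directly from the documented shapes:

    import torch
    from pycave.bayes.core import covariance_dim, covariance_shape

    assert covariance_dim("full") == 3
    assert covariance_dim("diag") == 2
    assert covariance_shape(4, 3, "full") == torch.Size([4, 3, 3])
    assert covariance_shape(4, 3, "tied") == torch.Size([3, 3])
    assert covariance_shape(4, 3, "spherical") == torch.Size([4])
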
/tests/bayes/core/benchmark_precision_cholesky.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | import numpy as np
3 | import torch
4 | from pytest_benchmark.fixture import BenchmarkFixture # type: ignore
5 | from sklearn.mixture._gaussian_mixture import _compute_precision_cholesky # type: ignore
6 | from pycave.bayes.core import cholesky_precision
7 |
8 |
9 | def test_cholesky_precision_spherical(benchmark: BenchmarkFixture):
10 | covars = torch.rand(50)
11 | benchmark(cholesky_precision, covars, "spherical")
12 |
13 |
14 | def test_numpy_cholesky_precision_spherical(benchmark: BenchmarkFixture):
15 | covars = np.random.rand(50)
16 | benchmark(_compute_precision_cholesky, covars, "spherical") # type: ignore
17 |
18 |
19 | # -------------------------------------------------------------------------------------------------
20 |
21 |
22 | def test_cholesky_precision_tied(benchmark: BenchmarkFixture):
23 | A = torch.randn(10000, 100)
24 | covars = A.t().mm(A)
25 | benchmark(cholesky_precision, covars, "tied")
26 |
27 |
28 | def test_numpy_cholesky_precision_tied(benchmark: BenchmarkFixture):
29 | A = np.random.randn(10000, 100)
30 | covars = np.dot(A.T, A)
31 | benchmark(_compute_precision_cholesky, covars, "tied") # type: ignore
32 |
33 |
34 | # -------------------------------------------------------------------------------------------------
35 |
36 |
37 | def test_cholesky_precision_full(benchmark: BenchmarkFixture):
38 | A = torch.randn(50, 10000, 100)
39 | covars = A.permute(0, 2, 1).bmm(A)
40 | benchmark(cholesky_precision, covars, "full")
41 |
42 |
43 | def test_numpy_cholesky_precision_full(benchmark: BenchmarkFixture):
44 | A = np.random.randn(50, 10000, 100)
45 | covars = np.matmul(np.transpose(A, (0, 2, 1)), A)
46 | benchmark(_compute_precision_cholesky, covars, "full") # type: ignore
47 |
--------------------------------------------------------------------------------
/pycave/utils/lightning_module.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import List
3 | import pytorch_lightning as pl
4 | import torch
5 | from torch import nn
6 |
7 |
8 | class NonparametricLightningModule(pl.LightningModule, ABC):
9 | """
10 | A lightning module which sets some defaults for training models that have no trainable
11 | parameters (i.e. whose state consists only of buffers updated by means other than gradient descent).
12 | """
13 |
14 | def __init__(self):
15 | super().__init__()
16 | self.automatic_optimization = False
17 |
18 | # Required parameter to make DDP training work
19 | self.register_parameter("__ddp_dummy__", nn.Parameter(torch.empty(1)))
20 |
21 | def configure_optimizers(self) -> None:
22 | return None
23 |
24 | def training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
25 | self.nonparametric_training_step(batch, batch_idx)
26 |
27 | def training_epoch_end(self, outputs: List[torch.Tensor]) -> None:
28 | self.nonparametric_training_epoch_end()
29 |
30 | @abstractmethod
31 | def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
32 | """
33 | Training step that is not allowed to return any value.
34 | """
35 |
36 | def nonparametric_training_epoch_end(self) -> None:
37 | """
38 | Training epoch end that is not passed any outputs.
39 |
40 | Does nothing by default.
41 | """
42 |
43 | def all_gather_first(self, x: torch.Tensor) -> torch.Tensor:
44 | """
45 | Gathers the provided tensor from all processes.
46 |
47 | If more than one process is available, chooses the value of the first process in every
48 | process.
49 | """
50 | gathered = self.all_gather(x)
51 | if gathered.dim() > x.dim():
52 | return gathered[0]
53 | return x
54 |
--------------------------------------------------------------------------------
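
A bare-bones subclass sketch (hypothetical, not part of the repository) showing how the two hooks above are meant to be used, accumulating sufficient statistics instead of computing a loss:

    import torch
    from pycave.utils import NonparametricLightningModule

    class RunningMeanModule(NonparametricLightningModule):
        """Tracks the mean of all datapoints seen in an epoch, without gradient descent."""

        def __init__(self, num_features: int):
            super().__init__()
            self.register_buffer("sums", torch.zeros(num_features))
            self.register_buffer("count", torch.zeros(1))

        def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
            # Accumulate per-batch sums; no loss and no optimizer involved
            self.sums.add_(batch.sum(0))
            self.count.add_(batch.size(0))

        def nonparametric_training_epoch_end(self) -> None:
            # Turn the accumulated sums into the mean at the end of the epoch
            self.sums.div_(self.count)
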
/pycave/bayes/markov_chain/types.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Union
2 | import numpy as np
3 | import numpy.typing as npt
4 | import torch
5 | from torch.nn.utils.rnn import pack_sequence, PackedSequence
6 | from torch.utils.data import Dataset
7 |
8 | SequenceData = Union[
9 | npt.NDArray[np.float32],
10 | torch.Tensor,
11 | Dataset[torch.Tensor],
12 | ]
13 | SequenceData.__doc__ = """
14 | Data that may be passed to estimators expecting 1-D sequences. Data may be provided in multiple
15 | formats:
16 |
17 | - NumPy array of shape ``[num_sequences, sequence_length]``.
18 | - PyTorch tensor of shape ``[num_sequences, sequence_length]``.
19 | - PyTorch dataset yielding items of shape ``[sequence_length]`` where the sequence length may
20 | differ for different indices.
21 | """
22 |
23 |
24 | def collate_sequences_same_length(data: Tuple[torch.Tensor]) -> PackedSequence:
25 | """
26 | Collates the provided sequences into a packed sequence. Each sequence has to have the same
27 | length.
28 |
29 | Args:
30 | data: A single tensor of shape ``[num_sequences, sequence_length]`` where each item has the
31 | same length.
32 |
33 | Returns:
34 | A packed sequence containing all sequences.
35 | """
36 | (sequences,) = data
37 | num_sequences, sequence_length = sequences.size()
38 | batch_sizes = torch.ones(sequence_length, dtype=torch.long) * num_sequences
39 | return PackedSequence(sequences.t().flatten(), batch_sizes)
40 |
41 |
42 | def collate_sequences(sequences: List[torch.Tensor]) -> PackedSequence:
43 | """
44 | Collates the sequences provided as a list into a packed sequence. The sequences are not
45 | required to be sorted by their lengths.
46 |
47 | Args:
48 | sequences: A list of one-dimensional tensors to batch.
49 |
50 | Returns:
51 | A packed sequence with all the data provided.
52 | """
53 | return pack_sequence(sequences, enforce_sorted=False)
54 |
--------------------------------------------------------------------------------
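
A small sketch of the collate helpers above on toy sequences of differing lengths:

    import torch
    from pycave.bayes.markov_chain.types import collate_sequences

    sequences = [torch.as_tensor([0, 1, 1, 0]), torch.as_tensor([1, 0])]
    packed = collate_sequences(sequences)
    # Batch sizes shrink as the shorter sequence runs out: tensor([2, 2, 1, 1])
    print(packed.batch_sizes)
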
/pycave/bayes/markov_chain/lightning_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.utils.rnn import PackedSequence
3 | from torchmetrics import MeanMetric
4 | from pycave.bayes.markov_chain.metrics import StateCountAggregator
5 | from pycave.utils import NonparametricLightningModule
6 | from .model import MarkovChainModel
7 |
8 |
9 | class MarkovChainLightningModule(NonparametricLightningModule):
10 | """
11 | Lightning module for training and evaluating a Markov chain.
12 | """
13 |
14 | def __init__(self, model: MarkovChainModel, symmetric: bool = False):
15 | """
16 | Args:
17 | model: The model to train or evaluate.
18 | symmetric: Whether transition probabilities should be symmetric.
19 | """
20 | super().__init__()
21 |
22 | self.model = model
23 | self.symmetric = symmetric
24 |
25 | self.aggregator = StateCountAggregator(
26 | num_states=self.model.config.num_states,
27 | symmetric=self.symmetric,
28 | dist_sync_fn=self.all_gather,
29 | )
30 | self.metric_nll = MeanMetric(dist_sync_fn=self.all_gather)
31 |
32 | def on_train_epoch_start(self) -> None:
33 | self.aggregator.reset()
34 |
35 | def nonparametric_training_step(self, batch: PackedSequence, _batch_idx: int) -> None:
36 | self.aggregator.update(batch)
37 |
38 | def nonparametric_training_epoch_end(self) -> None:
39 | initial_probs, transition_probs = self.aggregator.compute()
40 | self.model.initial_probs.copy_(initial_probs)
41 | self.model.transition_probs.copy_(transition_probs)
42 |
43 | def test_step(self, batch: PackedSequence, _batch_idx: int) -> None:
44 | log_probs = self.model.forward(batch)
45 | self.metric_nll.update(-log_probs)
46 | self.log("nll", self.metric_nll)
47 |
48 | def predict_step( # pylint: disable=signature-differs
49 | self, batch: PackedSequence, batch_idx: int
50 | ) -> torch.Tensor:
51 | return -self.model(batch)
52 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: Generate Documentation
2 | on:
3 | release:
4 | types: [published]
5 | push:
6 | branches: [main]
7 | pull_request:
8 |
9 | jobs:
10 | build:
11 | name: Build
12 | runs-on: ubuntu-latest
13 | container:
14 | image: python:3.8-buster
15 | steps:
16 | - name: Checkout
17 | uses: actions/checkout@v3
18 | - name: Install Enchant
19 | run: apt-get update && apt-get install -y enchant
20 | - name: Install poetry
21 | uses: snok/install-poetry@v1
22 | - name: Install Dependencies
23 | run: poetry install --only main,docs
24 | - name: Fix Distutils
25 | run: poetry run pip install setuptools==59.5.0
26 | - name: Check Spelling
27 | run: poetry run sphinx-build -W -b spelling docs build
28 | - name: Generate HTML
29 | run: poetry run sphinx-build -W -b html docs build
30 | - name: Store Artifacts
31 | uses: actions/upload-artifact@v3
32 | with:
33 | name: html
34 | path: build
35 |
36 | deploy:
37 | name: Publish
38 | runs-on: ubuntu-latest
39 | if: ${{ github.event_name == 'release' }}
40 | needs:
41 | - build
42 | steps:
43 | - name: Retrieve Artifacts
44 | uses: actions/download-artifact@v3
45 | with:
46 | name: html
47 | path: build
48 | - name: Configure AWS Credentials
49 | uses: aws-actions/configure-aws-credentials@v1
50 | with:
51 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
52 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
53 | aws-region: ${{ secrets.AWS_REGION }}
54 | - name: Deploy to S3
55 | run: aws s3 sync build s3://pycave.borchero.com --delete --acl public-read
56 | - name: Invalidate Cloudfront
57 | run: |
58 | aws cloudfront create-invalidation \
59 | --distribution-id ${AWS_CLOUDFRONT_DISTRIBUTION} --paths "/*"
60 | env:
61 | AWS_CLOUDFRONT_DISTRIBUTION: ${{ secrets.AWS_CLOUDFRONT_DISTRIBUTION }}
62 |
--------------------------------------------------------------------------------
/tests/bayes/markov_chain/test_markov_chain_model.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | import math
3 | import torch
4 | from torch import jit
5 | from torch.nn.utils.rnn import pack_padded_sequence
6 | from pycave.bayes.markov_chain import MarkovChainModel, MarkovChainModelConfig
7 |
8 |
9 | def test_compile():
10 | config = MarkovChainModelConfig(num_states=2)
11 | model = MarkovChainModel(config)
12 | jit.script(model)
13 |
14 |
15 | def test_forward_tensor():
16 | model = _get_default_model()
17 | sequences = torch.as_tensor([[1, 0, 0, 1], [0, 1, 1, 1]])
18 | expected = torch.as_tensor(
19 | [
20 | math.log(0.8) + math.log(0.1) + math.log(0.5) + math.log(0.5),
21 | math.log(0.2) + math.log(0.5) + math.log(0.9) + math.log(0.9),
22 | ]
23 | )
24 | assert torch.allclose(expected, model(sequences))
25 |
26 |
27 | def test_forward_packed_sequence():
28 | model = _get_default_model()
29 | sequences = torch.as_tensor([[1, 0, 0, 1], [0, 1, 1, -1]])
30 | packed_sequences = pack_padded_sequence(sequences.t(), torch.Tensor([4, 3]))
31 | expected = torch.as_tensor(
32 | [
33 | math.log(0.8) + math.log(0.1) + math.log(0.5) + math.log(0.5),
34 | math.log(0.2) + math.log(0.5) + math.log(0.9),
35 | ]
36 | )
37 | assert torch.allclose(expected, model(packed_sequences))
38 |
39 |
40 | def test_sample():
41 | torch.manual_seed(42)
42 | model = _get_default_model()
43 | n = 100000
44 | samples = model.sample(n, 3)
45 | assert math.isclose((samples[:, 0] == 0).sum() / n, 0.2, abs_tol=0.01)
46 |
47 |
48 | def test_stationary_distribution():
49 | model = _get_default_model()
50 | tol = 1e-7
51 | sd = model.stationary_distribution(tol=tol)
52 | assert math.isclose(sd[0].item(), 1 / 6, abs_tol=tol)
53 | assert math.isclose(sd[1].item(), 5 / 6, abs_tol=tol)
54 |
55 |
56 | # -------------------------------------------------------------------------------------------------
57 |
58 |
59 | def _get_default_model() -> MarkovChainModel:
60 | config = MarkovChainModelConfig(num_states=2)
61 | model = MarkovChainModel(config)
62 |
63 | model.initial_probs.copy_(torch.as_tensor([0.2, 0.8]))
64 | model.transition_probs.copy_(torch.as_tensor([[0.5, 0.5], [0.1, 0.9]]))
65 | return model
66 |
--------------------------------------------------------------------------------
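
For reference, the stationary distribution expected in ``test_stationary_distribution`` can be verified independently by repeatedly applying the transition matrix (a standalone check, not using the model):

    import torch

    transition_probs = torch.as_tensor([[0.5, 0.5], [0.1, 0.9]])
    dist = torch.as_tensor([0.2, 0.8])
    for _ in range(100):
        dist = dist @ transition_probs
    print(dist)  # converges to approximately [1/6, 5/6]
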
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=all
2 | from __future__ import annotations
3 | import datetime
4 | import os
5 | import sys
6 | from typing import Any
7 |
8 | filepath = os.path.abspath(os.path.dirname(__file__))
9 | sys.path.insert(0, os.path.join(filepath, ".."))
10 |
11 | # -------------------------------------------------------------------------------------------------
12 | # BASICS
13 |
14 | project = "PyCave"
15 | copyright = f"{datetime.datetime.now().year}, Oliver Borchert"
16 |
17 | # -------------------------------------------------------------------------------------------------
18 | # PLUGINS
19 |
20 | extensions = [
21 | "sphinx.ext.autodoc",
22 | "sphinx.ext.autosummary",
23 | "sphinx.ext.intersphinx",
24 | "sphinx.ext.napoleon",
25 | "sphinx.ext.viewcode",
26 | "sphinx_autodoc_typehints",
27 | "sphinx_automodapi.smart_resolver",
28 | "sphinx_copybutton",
29 | ]
30 | if os.uname().machine != "arm64":
31 | extensions.append("sphinxcontrib.spelling")
32 | templates_path = ["_templates"]
33 |
34 | # -------------------------------------------------------------------------------------------------
35 | # CONFIGURATION
36 |
37 | html_theme = "pydata_sphinx_theme"
38 | pygments_style = "lovelace"
39 | html_theme_options = {
40 | "show_prev_next": False,
41 | "github_url": "https://github.com/borchero/pycave",
42 | }
43 | html_logo = "_static/logo.svg"
44 | html_favicon = "_static/favicon.ico"
45 | html_permalinks = True
46 |
47 | autosummary_generate = True
48 | autosummary_imported_members = True
49 | autodoc_member_order = "groupwise"
50 | autodoc_type_aliases = {
51 | "CovarianceType": ":class:`~pycave.bayes.core.CovarianceType`",
52 | "SequenceData": ":class:`~pycave.data.SequenceData`",
53 | "TabularData": ":class:`~pycave.data.TabularData`",
54 | "GaussianMixtureInitStrategy": ":class:`~pycave.bayes.gmm.types.GaussianMixtureInitStrategy`",
55 | "KMeansInitStrategy": ":class:`~pycave.clustering.kmeans.types.KMeansInitStrategy`",
56 | }
57 | autoclass_content = "both"
58 |
59 | simplify_optional_unions = False
60 |
61 | spelling_lang = "en_US"
62 | spelling_word_list_filename = "spelling_wordlist.txt"
63 |
64 | intersphinx_mapping = {
65 | "python": ("https://docs.python.org/3", None),
66 | "torch": ("https://pytorch.org/docs/stable/", None),
67 | "numpy": ("https://numpy.org/doc/stable/", None),
68 | "pytorch_lightning": ("https://pytorch-lightning.readthedocs.io/en/stable/", None),
69 | }
70 |
--------------------------------------------------------------------------------
/tests/bayes/markov_chain/test_markov_chain_estimator.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | import math
3 | from typing import Tuple
4 | import pytest
5 | import torch
6 | from pycave.bayes import MarkovChain
7 |
8 |
9 | def test_fit_automatic_config():
10 | chain = MarkovChain()
11 | data = torch.randint(50, size=(100, 20))
12 | chain.fit(data)
13 | assert chain.model_.config.num_states == 50
14 |
15 |
16 | @pytest.mark.flaky(max_runs=3, min_passes=1)
17 | def test_sample_and_fit():
18 | chain = MarkovChain(2)
19 | initial_probs, transition_probs = _set_probs(chain)
20 | sample = chain.sample(1000000, 10)
21 |
22 | new = MarkovChain(2)
23 | new.fit(sample)
24 |
25 | assert torch.allclose(initial_probs, new.model_.initial_probs, atol=1e-3)
26 | assert torch.allclose(transition_probs, new.model_.transition_probs, atol=1e-3)
27 |
28 |
29 | def test_score():
30 | chain = MarkovChain(2)
31 | test_data, expected = _set_sample_data(chain)
32 | actual = chain.score(test_data)
33 | assert math.isclose(actual, -expected.mean())
34 |
35 |
36 | def test_score_samples():
37 | chain = MarkovChain(2)
38 | test_data, expected = _set_sample_data(chain)
39 | actual = chain.score_samples(test_data)
40 | assert torch.allclose(actual, -expected)
41 |
42 |
43 | # -------------------------------------------------------------------------------------------------
44 |
45 |
46 | def _set_probs(chain: MarkovChain) -> Tuple[torch.Tensor, torch.Tensor]:
47 | data = torch.randint(2, size=(100, 20))
48 | chain.fit(data)
49 |
50 | initial_probs = torch.as_tensor([0.8, 0.2])
51 | chain.model_.initial_probs.copy_(initial_probs)
52 |
53 | transition_probs = torch.as_tensor([[0.5, 0.5], [0.1, 0.9]])
54 | chain.model_.transition_probs.copy_(transition_probs)
55 |
56 | return initial_probs, transition_probs
57 |
58 |
59 | def _set_sample_data(chain: MarkovChain) -> Tuple[torch.Tensor, torch.Tensor]:
60 | _set_probs(chain)
61 |
62 | test_data = torch.as_tensor(
63 | [
64 | [1, 1, 0, 1],
65 | [0, 1, 0, 1],
66 | [0, 0, 1, 1],
67 | ]
68 | )
69 | expected = torch.as_tensor(
70 | [
71 | math.log(0.2) + math.log(0.9) + math.log(0.1) + math.log(0.5),
72 | math.log(0.8) + math.log(0.5) + math.log(0.1) + math.log(0.5),
73 | math.log(0.8) + math.log(0.5) + math.log(0.5) + math.log(0.9),
74 | ]
75 | )
76 | return test_data, expected
77 |
--------------------------------------------------------------------------------
/pycave/bayes/markov_chain/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, cast, Optional, Tuple
2 | import torch
3 | from torch.nn.utils.rnn import PackedSequence
4 | from torchmetrics import Metric
5 |
6 |
7 | class StateCountAggregator(Metric):
8 | """
9 | The state count aggregator aggregates initial states and transitions between states.
10 | """
11 |
12 | full_state_update = False
13 |
14 | def __init__(
15 | self,
16 | num_states: int,
17 | symmetric: bool,
18 | *,
19 | dist_sync_fn: Optional[Callable[[Any], Any]] = None,
20 | ):
21 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
22 |
23 | self.num_states = num_states
24 | self.symmetric = symmetric
25 |
26 | self.initial_counts: torch.Tensor
27 | self.add_state("initial_counts", torch.zeros(num_states), dist_reduce_fx="sum")
28 |
29 | self.transition_counts: torch.Tensor
30 | self.add_state(
31 | "transition_counts", torch.zeros(num_states, num_states).view(-1), dist_reduce_fx="sum"
32 | )
33 |
34 | def update(self, sequences: PackedSequence) -> None:
35 | batch_sizes = cast(torch.Tensor, sequences.batch_sizes)
36 | num_sequences = batch_sizes[0]
37 | data = cast(torch.Tensor, sequences.data)
38 |
39 | # First, we count the initial states
40 | initial_counts = torch.bincount(data[:num_sequences], minlength=self.num_states).float()
41 | self.initial_counts.add_(initial_counts)
42 |
43 | # Then, we count the transitions
44 | offset = 0
45 | for prev_size, size in zip(batch_sizes, batch_sizes[1:]):
46 | sources = data[offset : offset + size]
47 | targets = data[offset + prev_size : offset + prev_size + size]
48 | transitions = sources * self.num_states + targets
49 | values = torch.ones_like(transitions, dtype=torch.float)
50 | self.transition_counts.scatter_add_(0, transitions, values)
51 | offset += prev_size
52 |
53 | def compute(self) -> Tuple[torch.Tensor, torch.Tensor]:
54 | initial_probs = self.initial_counts / self.initial_counts.sum()
55 | transition_counts = self.transition_counts.view(self.num_states, self.num_states)
56 |
57 | if self.symmetric:
58 | transition_counts = transition_counts + transition_counts.t()
59 | transition_probs = transition_counts / transition_counts.sum(1, keepdim=True)
60 |
61 | return initial_probs, transition_probs
62 |
--------------------------------------------------------------------------------
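
A minimal sketch of what the aggregator above computes for two short sequences (synthetic data, run on a single process):

    import torch
    from torch.nn.utils.rnn import pack_sequence
    from pycave.bayes.markov_chain.metrics import StateCountAggregator

    aggregator = StateCountAggregator(num_states=2, symmetric=False)
    sequences = [torch.as_tensor([0, 1, 1]), torch.as_tensor([1, 0])]
    aggregator.update(pack_sequence(sequences, enforce_sorted=False))
    initial_probs, transition_probs = aggregator.compute()
    # Initial states are {0, 1} -> [0.5, 0.5]; observed transitions are 0->1, 1->1 and 1->0
    print(initial_probs, transition_probs)
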
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | authors = ["Oliver Borchert "]
3 | classifiers = [
4 | "Development Status :: 5 - Production/Stable",
5 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
6 | ]
7 | description = "Traditional Machine Learning Models in PyTorch."
8 | documentation = "https://pycave.borchero.com"
9 | license = "MIT"
10 | name = "pycave"
11 | readme = "README.md"
12 | repository = "https://github.com/borchero/pycave"
13 | version = "0.0.0"
14 |
15 | [tool.poetry.dependencies]
16 | lightkit = "^0.5.0"
17 | numpy = "^1.20.3"
18 | python = ">=3.8,<3.11"
19 | pytorch-lightning = "^1.6.0"
20 | torch = "^1.8.0"
21 | torchmetrics = ">=0.6,<0.12"
22 |
23 | [tool.poetry.group.pre-commit.dependencies]
24 | black = "^22.12.0"
25 | docformatter = "^1.5.0"
26 | isort = "^5.10.1"
27 | mypy = "^0.991"
28 | pylint = "^2.12.2"
29 | pyupgrade = "^3.3.1"
30 |
31 | [tool.poetry.group.docs.dependencies]
32 | Sphinx = "^5.0.0"
33 | pydata-sphinx-theme = "^0.7.2"
34 | scanpydoc = "^0.7.1"
35 | sphinx-autodoc-typehints = "^1.12.0"
36 | sphinx-automodapi = "^0.13"
37 | sphinx-copybutton = "^0.3.3"
38 | sphinxcontrib-spelling = "^7.2.1"
39 |
40 | [tool.poetry.group.testing.dependencies]
41 | flaky = "^3.7.0"
42 | pytest = "^6.2.4"
43 | pytest-benchmark = "^3.4.1"
44 | scikit-learn = "^0.24.2"
45 |
46 | [tool.poetry.group.dev.dependencies]
47 | jupyter = "^1.0.0"
48 |
49 | [build-system]
50 | build-backend = "poetry.core.masonry.api"
51 | requires = ["poetry-core>=1.0.0"]
52 |
53 | [tool.pylint.messages_control]
54 | disable = [
55 | "arguments-differ",
56 | "duplicate-code",
57 | "missing-module-docstring",
58 | "invalid-name",
59 | "too-few-public-methods",
60 | "too-many-ancestors",
61 | "too-many-arguments",
62 | "too-many-branches",
63 | "too-many-locals",
64 | "too-many-instance-attributes",
65 | ]
66 |
67 | [tool.pylint.typecheck]
68 | generated-members = [
69 | "torch.*",
70 | ]
71 |
72 | [tool.black]
73 | line-length = 99
74 | target-version = ["py38", "py39", "py310"]
75 |
76 | [tool.isort]
77 | force_alphabetical_sort_within_sections = true
78 | include_trailing_comma = true
79 | known_first_party = "pycave,tests"
80 | line_length = 99
81 | lines_between_sections = 0
82 | profile = "black"
83 | skip_gitignore = true
84 |
85 | [tool.docformatter]
86 | make-summary-multi-line = true
87 | pre-summary-newline = true
88 | recursive = true
89 | wrap-descriptions = 99
90 | wrap-summaries = 99
91 |
92 | [tool.pytest.ini_options]
93 | filterwarnings = [
94 | "ignore:.*Create unlinked descriptors is going to go away.*:DeprecationWarning",
95 | "ignore:.*this fit will run with no optimizer.*",
96 | "ignore:.*Consider increasing the value of the `num_workers` argument.*",
97 | ]
98 |
--------------------------------------------------------------------------------
/pycave/clustering/kmeans/model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Tuple
3 | import torch
4 | from lightkit.nn import Configurable
5 | from torch import jit, nn
6 |
7 |
8 | @dataclass
9 | class KMeansModelConfig:
10 | """
11 | Configuration class for a K-Means model.
12 |
13 | See also:
14 | :class:`KMeansModel`
15 | """
16 |
17 | #: The number of clusters.
18 | num_clusters: int
19 | #: The number of features of each cluster.
20 | num_features: int
21 |
22 |
23 | class KMeansModel(Configurable[KMeansModelConfig], nn.Module):
24 | """
25 | PyTorch module for the K-Means model.
26 |
27 | The centroids managed by this model are stored as non-trainable buffers.
28 | """
29 |
30 | def __init__(self, config: KMeansModelConfig):
31 | """
32 | Args:
33 | config: The configuration to use for initializing the module's buffers.
34 | """
35 | super().__init__(config)
36 |
37 | #: The centers of all clusters, buffer of shape ``[num_clusters, num_features].``
38 | self.centroids: torch.Tensor
39 | self.register_buffer("centroids", torch.empty(config.num_clusters, config.num_features))
40 |
41 | self.reset_parameters()
42 |
43 | @jit.unused
44 | def reset_parameters(self) -> None:
45 | """
46 | Resets the parameters of the KMeans model.
47 |
48 | It samples all cluster centers from a standard Normal.
49 | """
50 | nn.init.normal_(self.centroids)
51 |
52 | def forward(self, data: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
53 | """
54 | Computes the distance of each datapoint to each centroid as well as the "inertia", the
55 | squared distance of each datapoint to its closest centroid.
56 |
57 | Args:
58 | data: A tensor of shape ``[num_datapoints, num_features]`` for which to compute the
59 | distances and inertia.
60 |
61 | Returns:
62 | - A tensor of shape ``[num_datapoints, num_centroids]`` with the distance from each
63 | datapoint to each centroid.
64 | - A tensor of shape ``[num_datapoints]`` with the assignments, i.e. the indices of
65 | each datapoint's closest centroid.
66 | - A tensor of shape ``[num_datapoints]`` with the inertia (squared distance to the
67 | closest centroid) of each datapoint.
68 | """
69 | distances = torch.cdist(data, self.centroids)
70 | assignments = distances.min(1, keepdim=True).indices
71 | inertias = distances.gather(1, assignments).square()
72 | return distances, assignments.squeeze(1), inertias.squeeze(1)
73 |
--------------------------------------------------------------------------------
/tests/clustering/kmeans/test_kmeans_estimator.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | import math
3 | from typing import Optional
4 | import pytest
5 | import torch
6 | from sklearn.cluster import KMeans as SklearnKMeans # type: ignore
7 | from pycave.clustering import KMeans
8 | from tests._data.gmm import sample_gmm
9 |
10 |
11 | def test_fit_automatic_config():
12 | estimator = KMeans(4)
13 | data = torch.cat([torch.randn(1000, 3) * 0.1 - 100, torch.randn(1000, 3) * 0.1 + 100])
14 | estimator.fit(data)
15 | assert estimator.model_.config.num_clusters == 4
16 | assert estimator.model_.config.num_features == 3
17 |
18 |
19 | def test_fit_num_iter():
20 | # The k-means++ iterations should find the centroids. Afterwards, it should only take a single
21 | # epoch until the centroids do not change anymore.
22 | data = torch.cat([torch.randn(1000, 4) * 0.1 - 100, torch.randn(1000, 4) * 0.1 + 100])
23 |
24 | estimator = KMeans(2)
25 | estimator.fit(data)
26 |
27 | assert estimator.num_iter_ == 1
28 |
29 |
30 | @pytest.mark.flaky(max_runs=2, min_passes=1)
31 | @pytest.mark.parametrize(
32 | ("num_epochs", "converged"),
33 | [(100, True), (1, False)],
34 | )
35 | def test_fit_converged(num_epochs: int, converged: bool):
36 | data, _ = sample_gmm(
37 | num_datapoints=10000,
38 | num_features=8,
39 | num_components=4,
40 | covariance_type="spherical",
41 | )
42 |
43 | estimator = KMeans(4, trainer_params=dict(max_epochs=num_epochs))
44 | estimator.fit(data)
45 |
46 | assert estimator.converged_ == converged
47 |
48 |
49 | @pytest.mark.flaky(max_runs=5, min_passes=1)
50 | @pytest.mark.parametrize(
51 | ("num_datapoints", "batch_size", "num_features", "num_centroids"),
52 | [
53 | (10000, None, 8, 4),
54 | (10000, 1000, 8, 4),
55 | ],
56 | )
57 | def test_fit_inertia(
58 | num_datapoints: int,
59 | batch_size: Optional[int],
60 | num_features: int,
61 | num_centroids: int,
62 | ):
63 | data, _ = sample_gmm(
64 | num_datapoints=num_datapoints,
65 | num_features=num_features,
66 | num_components=num_centroids,
67 | covariance_type="spherical",
68 | )
69 |
70 | # Ours
71 | estimator = KMeans(
72 | num_centroids,
73 | batch_size=batch_size,
74 | trainer_params=dict(precision=64),
75 | )
76 | ours_inertia = float("inf")
77 | for _ in range(10):
78 | ours_inertia = min(ours_inertia, estimator.fit(data).score(data))
79 |
80 | # Sklearn
81 | gmm = SklearnKMeans(num_centroids, n_init=10)
82 | sklearn_inertia = gmm.fit(data.numpy()).score(data.numpy())
83 |
84 | assert math.isclose(ours_inertia, -sklearn_inertia / data.size(0), rel_tol=0.01, abs_tol=0.01)
85 |
--------------------------------------------------------------------------------
/tests/bayes/gmm/test_gmm_estimator.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | import math
3 | from typing import Optional
4 | import pytest
5 | import torch
6 | from sklearn.mixture import GaussianMixture as SklearnGaussianMixture # type: ignore
7 | from pycave.bayes import GaussianMixture
8 | from pycave.bayes.core import CovarianceType
9 | from tests._data.gmm import sample_gmm
10 |
11 |
12 | def test_fit_model_config():
13 | estimator = GaussianMixture()
14 | data = torch.randn(1000, 4)
15 | estimator.fit(data)
16 |
17 | assert estimator.model_.config.num_components == 1
18 | assert estimator.model_.config.num_features == 4
19 |
20 |
21 | @pytest.mark.parametrize("batch_size", [2, None])
22 | def test_fit_num_iter(batch_size: Optional[int]):
23 | # For the following data, K-means will find centroids [0.5, 3.5]. The estimator first computes
24 | the NLL (first iteration); afterwards, there is no improvement in the NLL (second iteration).
25 | data = torch.as_tensor([[0.0], [1.0], [3.0], [4.0]])
26 | estimator = GaussianMixture(
27 | 2,
28 | batch_size=batch_size,
29 | trainer_params=dict(precision=64),
30 | )
31 | estimator.fit(data)
32 |
33 | assert estimator.num_iter_ == 2
34 |
35 |
36 | @pytest.mark.flaky(max_runs=3, min_passes=1)
37 | @pytest.mark.parametrize(
38 | ("batch_size", "max_epochs", "converged"),
39 | [(2, 1, False), (2, 3, True), (None, 1, False), (None, 3, True)],
40 | )
41 | def test_fit_converged(batch_size: Optional[int], max_epochs: int, converged: bool):
42 | data = torch.as_tensor([[0.0], [1.0], [3.0], [4.0]])
43 |
44 | estimator = GaussianMixture(
45 | 2,
46 | batch_size=batch_size,
47 | trainer_params=dict(precision=64, max_epochs=max_epochs),
48 | )
49 | estimator.fit(data)
50 | assert estimator.converged_ == converged
51 |
52 |
53 | @pytest.mark.flaky(max_runs=25, min_passes=1)
54 | @pytest.mark.parametrize(
55 | ("num_datapoints", "batch_size", "num_features", "num_components", "covariance_type"),
56 | [
57 | (10000, 10000, 4, 4, "spherical"),
58 | (10000, 10000, 4, 4, "diag"),
59 | (10000, 10000, 4, 4, "tied"),
60 | (10000, 10000, 4, 4, "full"),
61 | (10000, 1000, 4, 4, "spherical"),
62 | (10000, 1000, 4, 4, "diag"),
63 | (10000, 1000, 4, 4, "tied"),
64 | (10000, 1000, 4, 4, "full"),
65 | ],
66 | )
67 | def test_fit_nll(
68 | num_datapoints: int,
69 | batch_size: int,
70 | num_features: int,
71 | num_components: int,
72 | covariance_type: CovarianceType,
73 | ):
74 | data, _ = sample_gmm(
75 | num_datapoints=num_datapoints,
76 | num_features=num_features,
77 | num_components=num_components,
78 | covariance_type=covariance_type,
79 | )
80 |
81 | # Ours
82 | estimator = GaussianMixture(
83 | num_components,
84 | covariance_type=covariance_type,
85 | batch_size=batch_size,
86 | trainer_params=dict(precision=64),
87 | )
88 | ours_nll = estimator.fit(data).score(data)
89 |
90 | # Sklearn
91 | gmm = SklearnGaussianMixture(num_components, covariance_type=covariance_type)
92 | sklearn_nll = gmm.fit(data.numpy()).score(data.numpy())
93 |
94 | assert math.isclose(ours_nll, -sklearn_nll, rel_tol=0.01, abs_tol=0.01)
95 |
--------------------------------------------------------------------------------
/pycave/bayes/core/_jit.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | import math
3 | import torch
4 |
5 |
6 | def jit_log_normal(
7 | x: torch.Tensor,
8 | means: torch.Tensor,
9 | precisions_cholesky: torch.Tensor,
10 | covariance_type: str,
11 | ) -> torch.Tensor:
12 | if covariance_type == "full":
13 | # Precision shape is `[num_components, dim, dim]`.
14 | log_prob = x.new_empty((x.size(0), means.size(0)))
15 | # We loop here to not blow up the size of intermediate matrices
16 | for k, (mu, prec_chol) in enumerate(zip(means, precisions_cholesky)):
17 | inner = x.matmul(prec_chol) - mu.matmul(prec_chol)
18 | log_prob[:, k] = inner.square().sum(1)
19 | elif covariance_type == "tied":
20 | # Precision shape is `[dim, dim]`.
21 | a = x.matmul(precisions_cholesky) # [N, D]
22 | b = means.matmul(precisions_cholesky) # [K, D]
23 | log_prob = (a.unsqueeze(1) - b).square().sum(-1)
24 | else:
25 | precisions = precisions_cholesky.square()
26 | if covariance_type == "diag":
27 | # Precision shape is `[num_components, dim]`.
28 | x_prob = torch.matmul(x * x, precisions.t())
29 | m_prob = torch.einsum("ij,ij,ij->i", means, means, precisions)
30 | xm_prob = torch.matmul(x, (means * precisions).t())
31 | else: # covariance_type == "spherical"
32 | # Precision shape is `[num_components]`
33 | x_prob = torch.ger(torch.einsum("ij,ij->i", x, x), precisions)
34 | m_prob = torch.einsum("ij,ij->i", means, means) * precisions
35 | xm_prob = torch.matmul(x, means.t() * precisions)
36 |
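        # The expression below is the squared Mahalanobis distance, expanded as
        # x'Px - 2 * x'P mu + mu'P mu to avoid materializing (x - mu) per component.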
37 | log_prob = x_prob - 2 * xm_prob + m_prob
38 |
39 | num_features = x.size(1)
40 | logdet = _cholesky_logdet(num_features, precisions_cholesky, covariance_type)
41 | constant = math.log(2 * math.pi) * num_features
42 | return logdet - 0.5 * (constant + log_prob)
43 |
44 |
45 | def _cholesky_logdet(
46 | num_features: int,
47 | precisions_cholesky: torch.Tensor,
48 | covariance_type: str,
49 | ) -> torch.Tensor:
50 | if covariance_type == "full":
51 | return precisions_cholesky.diagonal(dim1=-2, dim2=-1).log().sum(-1)
52 | if covariance_type == "tied":
53 | return precisions_cholesky.diagonal().log().sum(-1)
54 | if covariance_type == "diag":
55 | return precisions_cholesky.log().sum(1)
56 | # covariance_type == "spherical"
57 | return precisions_cholesky.log() * num_features
58 |
59 |
60 | # -------------------------------------------------------------------------------------------------
61 |
62 |
63 | def jit_sample_normal(
64 | num: int,
65 | mean: torch.Tensor,
66 | cholesky_precisions: torch.Tensor,
67 | covariance_type: str,
68 | ) -> torch.Tensor:
69 | samples = torch.randn(num, mean.size(0), dtype=mean.dtype, device=mean.device)
70 | chol_covariance = _cholesky_covariance(cholesky_precisions, covariance_type)
71 |
72 | if covariance_type in ("tied", "full"):
73 | scale = chol_covariance.matmul(samples.unsqueeze(-1)).squeeze(-1)
74 | else:
75 | scale = chol_covariance * samples
76 |
77 | return mean + scale
78 |
79 |
80 | def _cholesky_covariance(chol_precision: torch.Tensor, covariance_type: str) -> torch.Tensor:
81 |     # For "tied" and "full" covariances, invert the Cholesky factor of the precision to obtain that of the covariance.
82 | if covariance_type in ("tied", "full"):
83 | num_features = chol_precision.size(-1)
84 | target = torch.eye(num_features, dtype=chol_precision.dtype, device=chol_precision.device)
85 | return torch.linalg.solve_triangular(chol_precision, target, upper=True).t()
86 |
87 | # Simple covariance type
88 | return chol_precision.reciprocal()
89 |
--------------------------------------------------------------------------------
/tests/bayes/gmm/test_gmm_metrics.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=protected-access,missing-function-docstring
2 | from typing import Any, Callable
3 | import numpy as np
4 | import sklearn.mixture._gaussian_mixture as skgmm # type: ignore
5 | import torch
6 | from pycave.bayes.core import CovarianceType
7 | from pycave.bayes.gmm.metrics import CovarianceAggregator, MeanAggregator, PriorAggregator
8 |
9 |
10 | def test_prior_aggregator():
11 | aggregator = PriorAggregator(3)
12 | aggregator.reset()
13 |
14 | # Step 1: single batch
15 | responsibilities1 = torch.tensor([[0.3, 0.3, 0.4], [0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
16 | actual = aggregator.forward(responsibilities1)
17 | expected = torch.tensor([0.5, 0.3, 0.2])
18 | assert torch.allclose(actual, expected)
19 |
20 | # Step 2: batch aggregation
21 | responsibilities2 = torch.tensor([[0.7, 0.2, 0.1], [0.5, 0.4, 0.1]])
22 | aggregator.update(responsibilities2)
23 | actual = aggregator.compute()
24 | expected = torch.tensor([0.54, 0.3, 0.16])
25 | assert torch.allclose(actual, expected)
26 |
27 |
28 | def test_mean_aggregator():
29 | aggregator = MeanAggregator(3, 2)
30 | aggregator.reset()
31 |
32 | # Step 1: single batch
33 | data1 = torch.tensor([[5.0, 2.0], [3.0, 4.0], [1.0, 0.0]])
34 | responsibilities1 = torch.tensor([[0.3, 0.3, 0.4], [0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
35 | actual = aggregator.forward(data1, responsibilities1)
36 | expected = torch.tensor([[2.8667, 2.5333], [2.5556, 1.1111], [4.0, 2.0]])
37 | assert torch.allclose(actual, expected, atol=1e-4)
38 |
39 | # Step 2: batch aggregation
40 | data2 = torch.tensor([[8.0, 2.5], [1.5, 4.0]])
41 | responsibilities2 = torch.tensor([[0.7, 0.2, 0.1], [0.5, 0.4, 0.1]])
42 | aggregator.update(data2, responsibilities2)
43 | actual = aggregator.compute()
44 | expected = torch.tensor([[3.9444, 2.7963], [3.0, 2.0667], [4.1875, 2.3125]])
45 | assert torch.allclose(actual, expected, atol=1e-4)
46 |
47 |
48 | def test_covariance_aggregator_spherical():
49 | _test_covariance("spherical", skgmm._estimate_gaussian_covariances_spherical) # type: ignore
50 |
51 |
52 | def test_covariance_aggregator_diag():
53 | _test_covariance("diag", skgmm._estimate_gaussian_covariances_diag) # type: ignore
54 |
55 |
56 | def test_covariance_aggregator_tied():
57 | _test_covariance("tied", skgmm._estimate_gaussian_covariances_tied) # type: ignore
58 |
59 |
60 | def test_covariance_aggregator_full():
61 | _test_covariance("full", skgmm._estimate_gaussian_covariances_full) # type: ignore
62 |
63 |
64 | def _test_covariance(
65 | covariance_type: CovarianceType,
66 | sk_aggregator: Callable[[Any, Any, Any, Any, Any], Any],
67 | ):
68 | reg = 1e-5
69 | aggregator = CovarianceAggregator(3, 2, covariance_type, reg=reg)
70 | aggregator.reset()
71 | means = torch.tensor([[3.0, 2.5], [2.5, 1.0], [4.0, 2.0]])
72 |
73 | # Step 1: single batch
74 | data1 = torch.tensor([[5.0, 2.0], [3.0, 4.0], [1.0, 0.0]])
75 | responsibilities1 = torch.tensor([[0.3, 0.3, 0.4], [0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
76 | actual = aggregator.forward(data1, responsibilities1, means)
77 | expected = sk_aggregator( # type: ignore
78 | responsibilities1.numpy(),
79 | data1.numpy(),
80 | responsibilities1.sum(0).numpy(),
81 | means.numpy(),
82 | reg,
83 | ).astype(np.float32)
84 | assert torch.allclose(actual, torch.from_numpy(expected))
85 |
86 | # Step 2: batch aggregation
87 | data2 = torch.tensor([[8.0, 2.5], [1.5, 4.0]])
88 | responsibilities2 = torch.tensor([[0.7, 0.2, 0.1], [0.5, 0.4, 0.1]])
89 | aggregator.update(data2, responsibilities2, means)
90 | actual = aggregator.compute()
91 | expected = sk_aggregator( # type: ignore
92 | torch.cat([responsibilities1, responsibilities2]).numpy(),
93 | torch.cat([data1, data2]).numpy(),
94 | (responsibilities1.sum(0) + responsibilities2.sum(0)).numpy(),
95 | means.numpy(),
96 | reg,
97 | ).astype(np.float32)
98 | assert torch.allclose(actual, torch.from_numpy(expected))
99 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | PyCave Documentation
2 | ====================
3 |
4 | PyCave allows you to run traditional machine learning models on CPU, GPU, and even on multiple nodes. All models are implemented in `PyTorch <https://pytorch.org/>`_ and provide an ``Estimator`` API that is fully compatible with `scikit-learn <https://scikit-learn.org/stable/>`_.
5 |
6 | .. image:: https://img.shields.io/pypi/v/pycave?label=version
7 | .. image:: https://img.shields.io/pypi/l/pycave
8 |
9 |
10 | Features
11 | --------
12 |
13 | - Support for GPU and multi-node training by implementing models in PyTorch and relying on `PyTorch Lightning <https://www.pytorchlightning.ai/>`_
14 | - Mini-batch training for all models such that they can be used on huge datasets
15 | - Well-structured implementation of models
16 |
17 | - High-level ``Estimator`` API allows for easy usage such that models feel and behave like in
18 | scikit-learn
19 |   - Medium-level ``LightningModule`` implements the training algorithm
20 | - Low-level PyTorch ``Module`` manages the model parameters
21 |
22 |
23 | Installation
24 | ------------
25 |
26 | PyCave is available via ``pip``:
27 |
28 | .. code-block:: bash
29 |
30 | pip install pycave
31 |
32 | If you are using `Poetry <https://python-poetry.org/>`_:
33 |
34 | .. code-block:: bash
35 |
36 | poetry add pycave
37 |
38 |
39 | Usage
40 | -----
41 |
42 | If you've ever used scikit-learn, you'll feel right at home when using PyCave. First, let's create
43 | some artificial data to work with:
44 |
45 | .. code-block:: python
46 |
47 | import torch
48 |
49 | X = torch.cat([
50 | torch.randn(10000, 8) - 5,
51 | torch.randn(10000, 8),
52 | torch.randn(10000, 8) + 5,
53 | ])
54 |
55 | This dataset consists of three clusters with 8-dimensional datapoints. If you want to fit a K-Means
56 | model to find the clusters' centroids, it's as easy as:
57 |
58 | .. code-block:: python
59 |
60 | from pycave.clustering import KMeans
61 |
62 | estimator = KMeans(3)
63 | estimator.fit(X)
64 |
65 | # Once the estimator is fitted, it provides various properties. One of them is
66 | # the `model_` property which yields the PyTorch module with the fitted parameters.
67 | print("Centroids are:")
68 | print(estimator.model_.centroids)
69 |
70 | Due to the high-level estimator API, the usage for all machine learning models is similar. The API
71 | documentation provides more detailed information about parameters that can be passed to estimators
72 | and which methods are available.
73 |
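For instance, a Gaussian mixture model is fitted on the same data in exactly the same way. A
minimal sketch (as in the test suite, ``score`` returns the average negative log-likelihood of
the data):

.. code-block:: python

   from pycave.bayes import GaussianMixture

   estimator = GaussianMixture(3, covariance_type="spherical")
   estimator.fit(X)

   print("Average negative log-likelihood:")
   print(estimator.score(X))
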
74 | GPU and Multi-Node training
75 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^
76 |
77 | For GPU- and multi-node training, PyCave leverages PyTorch Lightning. The hardware that training
78 | runs on is determined by the :class:`pytorch_lightning.trainer.Trainer` class. Its
79 | :meth:`~pytorch_lightning.trainer.Trainer.__init__` method provides various configuration options.
80 |
81 | If you want to run K-Means with a GPU, you can pass the options ``accelerator='gpu'`` and ``devices=1`` to the estimator's
82 | initializer:
83 |
84 | .. code-block:: python
85 |
86 | estimator = KMeans(3, trainer_params=dict(accelerator='gpu', devices=1))
87 |
88 | Similarly, if you want to train on 4 nodes simultaneously where each node has one GPU available,
89 | you can specify this as follows:
90 |
91 | .. code-block:: python
92 |
93 |    estimator = KMeans(3, trainer_params=dict(num_nodes=4, accelerator='gpu', devices=1))
94 |
95 | In fact, **you do not need to change anything else in your code**.
96 |
97 |
98 | Implemented Models
99 | ^^^^^^^^^^^^^^^^^^
100 |
101 | Currently, PyCave implements three different models. Some of these models are also available in
102 | scikit-learn. In this case, we benchmark our implementation against theirs (see
103 | :doc:`here <sites/benchmark>`).
104 |
105 | .. currentmodule:: pycave
106 |
107 | .. autosummary::
108 | :nosignatures:
109 |
110 | ~bayes.GaussianMixture
111 | ~bayes.MarkovChain
112 | ~clustering.KMeans
113 |
114 | Reference
115 | ---------
116 |
117 | .. toctree::
118 | :maxdepth: 2
119 |
120 | sites/benchmark
121 | sites/api
122 |
123 | Index
124 | ^^^^^
125 |
126 | - :ref:`genindex`
127 |
--------------------------------------------------------------------------------
/tests/clustering/kmeans/benchmark_kmeans_estimator.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | from typing import Optional
3 | import pytest
4 | import pytorch_lightning as pl
5 | import torch
6 | from pytest_benchmark.fixture import BenchmarkFixture # type: ignore
7 | from sklearn.cluster import KMeans as SklearnKMeans # type: ignore
8 | from pycave.clustering import KMeans
9 | from pycave.clustering.kmeans.types import KMeansInitStrategy
10 | from tests._data.gmm import sample_gmm
11 |
12 |
13 | @pytest.mark.parametrize(
14 | ("num_datapoints", "num_features", "num_centroids", "init_strategy"),
15 | [
16 | (10000, 8, 4, "k-means++"),
17 | (100000, 32, 16, "k-means++"),
18 | (1000000, 64, 64, "k-means++"),
19 | (10000000, 128, 128, "k-means++"),
20 | (10000, 8, 4, "random"),
21 | (100000, 32, 16, "random"),
22 | (1000000, 64, 64, "random"),
23 | (10000000, 128, 128, "random"),
24 | ],
25 | )
26 | def test_sklearn(
27 | benchmark: BenchmarkFixture,
28 | num_datapoints: int,
29 | num_features: int,
30 | num_centroids: int,
31 | init_strategy: str,
32 | ):
33 | pl.seed_everything(0)
34 | data, _ = sample_gmm(num_datapoints, num_features, num_centroids, "spherical")
35 |
36 | estimator = SklearnKMeans(
37 | num_centroids,
38 | algorithm="full",
39 | n_init=1,
40 | max_iter=100,
41 | tol=0,
42 | init=init_strategy,
43 | )
44 | benchmark(estimator.fit, data.numpy())
45 |
46 |
47 | @pytest.mark.parametrize(
48 | ("num_datapoints", "batch_size", "num_features", "num_centroids", "init_strategy"),
49 | [
50 | (10000, None, 8, 4, "kmeans++"),
51 | (10000, 1000, 8, 4, "kmeans++"),
52 | (100000, None, 32, 16, "kmeans++"),
53 | (100000, 10000, 32, 16, "kmeans++"),
54 | (1000000, None, 64, 64, "kmeans++"),
55 | (1000000, 100000, 64, 64, "kmeans++"),
56 | (10000, None, 8, 4, "random"),
57 | (10000, 1000, 8, 4, "random"),
58 | (100000, None, 32, 16, "random"),
59 | (100000, 10000, 32, 16, "random"),
60 | (1000000, None, 64, 64, "random"),
61 | (1000000, 100000, 64, 64, "random"),
62 | ],
63 | )
64 | def test_pycave(
65 | benchmark: BenchmarkFixture,
66 | num_datapoints: int,
67 | batch_size: Optional[int],
68 | num_features: int,
69 | num_centroids: int,
70 | init_strategy: KMeansInitStrategy,
71 | ):
72 | pl.seed_everything(0)
73 | data, _ = sample_gmm(num_datapoints, num_features, num_centroids, "spherical")
74 |
75 | estimator = KMeans(
76 | num_centroids,
77 | init_strategy=init_strategy,
78 | batch_size=batch_size,
79 | trainer_params=dict(max_epochs=100),
80 | )
81 | benchmark(estimator.fit, data)
82 |
83 |
84 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
85 | @pytest.mark.parametrize(
86 | ("num_datapoints", "batch_size", "num_features", "num_centroids", "init_strategy"),
87 | [
88 | (10000, None, 8, 4, "kmeans++"),
89 | (10000, 1000, 8, 4, "kmeans++"),
90 | (100000, None, 32, 16, "kmeans++"),
91 | (100000, 10000, 32, 16, "kmeans++"),
92 | (1000000, None, 64, 64, "kmeans++"),
93 | (1000000, 100000, 64, 64, "kmeans++"),
94 | (10000000, 1000000, 128, 128, "kmeans++"),
95 | (10000, None, 8, 4, "random"),
96 | (10000, 1000, 8, 4, "random"),
97 | (100000, None, 32, 16, "random"),
98 | (100000, 10000, 32, 16, "random"),
99 | (1000000, None, 64, 64, "random"),
100 | (1000000, 100000, 64, 64, "random"),
101 | (10000000, 1000000, 128, 128, "random"),
102 | ],
103 | )
104 | def test_pycave_gpu(
105 | benchmark: BenchmarkFixture,
106 | num_datapoints: int,
107 | batch_size: Optional[int],
108 | num_features: int,
109 | num_centroids: int,
110 | init_strategy: KMeansInitStrategy,
111 | ):
112 | # Initialize GPU
113 | torch.empty(1, device="cuda:0")
114 |
115 | pl.seed_everything(0)
116 | data, _ = sample_gmm(num_datapoints, num_features, num_centroids, "spherical")
117 |
118 | estimator = KMeans(
119 | num_centroids,
120 | init_strategy=init_strategy,
121 | batch_size=batch_size,
122 | convergence_tolerance=0,
123 | trainer_params=dict(max_epochs=100, accelerator="gpu", devices=1),
124 | )
125 | benchmark(estimator.fit, data)
126 |
--------------------------------------------------------------------------------
/tests/bayes/gmm/benchmark_gmm_estimator.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | from typing import Optional
3 | import pytest
4 | import pytorch_lightning as pl
5 | import torch
6 | from pytest_benchmark.fixture import BenchmarkFixture # type: ignore
7 | from sklearn.mixture import GaussianMixture as SklearnGaussianMixture # type: ignore
8 | from pycave.bayes import GaussianMixture
9 | from pycave.bayes.core.types import CovarianceType
10 | from tests._data.gmm import sample_gmm
11 |
12 |
13 | @pytest.mark.parametrize(
14 | ("num_datapoints", "num_features", "num_components", "covariance_type"),
15 | [
16 | (10000, 8, 4, "diag"),
17 | (10000, 8, 4, "tied"),
18 | (10000, 8, 4, "full"),
19 | (100000, 32, 16, "diag"),
20 | (100000, 32, 16, "tied"),
21 | (100000, 32, 16, "full"),
22 | (1000000, 64, 64, "diag"),
23 | ],
24 | )
25 | def test_sklearn(
26 | benchmark: BenchmarkFixture,
27 | num_datapoints: int,
28 | num_features: int,
29 | num_components: int,
30 | covariance_type: CovarianceType,
31 | ):
32 | pl.seed_everything(0)
33 | data, means = sample_gmm(num_datapoints, num_features, num_components, covariance_type)
34 |
35 | estimator = SklearnGaussianMixture(
36 | num_components,
37 | covariance_type=covariance_type,
38 | tol=0,
39 | n_init=1,
40 | max_iter=100,
41 | reg_covar=1e-3,
42 | init_params="random",
43 | means_init=means.numpy(),
44 | )
45 | benchmark(estimator.fit, data.numpy())
46 |
47 |
48 | @pytest.mark.parametrize(
49 | ("num_datapoints", "num_features", "num_components", "covariance_type", "batch_size"),
50 | [
51 | (10000, 8, 4, "diag", None),
52 | (10000, 8, 4, "tied", None),
53 | (10000, 8, 4, "full", None),
54 | (100000, 32, 16, "diag", None),
55 | (100000, 32, 16, "tied", None),
56 | (100000, 32, 16, "full", None),
57 | (1000000, 64, 64, "diag", None),
58 | (10000, 8, 4, "diag", 1000),
59 | (10000, 8, 4, "tied", 1000),
60 | (10000, 8, 4, "full", 1000),
61 | (100000, 32, 16, "diag", 10000),
62 | (100000, 32, 16, "tied", 10000),
63 | (100000, 32, 16, "full", 10000),
64 | (1000000, 64, 64, "diag", 100000),
65 | ],
66 | )
67 | def test_pycave(
68 | benchmark: BenchmarkFixture,
69 | num_datapoints: int,
70 | num_features: int,
71 | num_components: int,
72 | covariance_type: CovarianceType,
73 | batch_size: Optional[int],
74 | ):
75 | pl.seed_everything(0)
76 | data, means = sample_gmm(num_datapoints, num_features, num_components, covariance_type)
77 |
78 | estimator = GaussianMixture(
79 | num_components,
80 | covariance_type=covariance_type,
81 | init_means=means,
82 | convergence_tolerance=0,
83 | covariance_regularization=1e-3,
84 | batch_size=batch_size,
85 | trainer_params=dict(max_epochs=100),
86 | )
87 | benchmark(estimator.fit, data)
88 |
89 |
90 | @pytest.mark.parametrize(
91 | ("num_datapoints", "num_features", "num_components", "covariance_type", "batch_size"),
92 | [
93 | (10000, 8, 4, "diag", None),
94 | (10000, 8, 4, "tied", None),
95 | (10000, 8, 4, "full", None),
96 | (100000, 32, 16, "diag", None),
97 | (100000, 32, 16, "tied", None),
98 | (100000, 32, 16, "full", None),
99 | (1000000, 64, 64, "diag", None),
100 | (10000, 8, 4, "diag", 1000),
101 | (10000, 8, 4, "tied", 1000),
102 | (10000, 8, 4, "full", 1000),
103 | (100000, 32, 16, "diag", 10000),
104 | (100000, 32, 16, "tied", 10000),
105 | (100000, 32, 16, "full", 10000),
106 | (1000000, 64, 64, "diag", 100000),
107 | (1000000, 64, 64, "tied", 100000),
108 | ],
109 | )
110 | def test_pycave_gpu(
111 | benchmark: BenchmarkFixture,
112 | num_datapoints: int,
113 | num_features: int,
114 | num_components: int,
115 | covariance_type: CovarianceType,
116 | batch_size: Optional[int],
117 | ):
118 | # Initialize GPU
119 | torch.empty(1, device="cuda:0")
120 |
121 | pl.seed_everything(0)
122 | data, means = sample_gmm(num_datapoints, num_features, num_components, covariance_type)
123 |
124 | estimator = GaussianMixture(
125 | num_components,
126 | covariance_type=covariance_type,
127 | init_means=means,
128 | convergence_tolerance=0,
129 | covariance_regularization=1e-3,
130 | batch_size=batch_size,
131 | trainer_params=dict(max_epochs=100, accelerator="gpu", devices=1),
132 | )
133 | benchmark(estimator.fit, data)
134 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyCave
2 |
3 | 
4 | 
5 |
6 | PyCave allows you to run traditional machine learning models on CPU, GPU, and even on multiple
7 | nodes. All models are implemented in [PyTorch](https://pytorch.org/) and provide an `Estimator` API
8 | that is fully compatible with [scikit-learn](https://scikit-learn.org/stable/).
9 |
10 | For Gaussian mixture models, PyCave allows for 100x speed-ups when using a GPU and enables training
11 | on markedly larger datasets via mini-batches. The full suite of benchmarks run to compare
12 | PyCave models against scikit-learn models is available on the
13 | [documentation website](https://pycave.borchero.com/sites/benchmark.html).
14 |
15 | _PyCave version 3 is a complete rewrite of PyCave which is tested much more rigorously, depends on
16 | well-maintained libraries and is tuned for better performance. While you are, thus, highly
17 | encouraged to upgrade, refer to [pycave-v2.borchero.com](https://pycave-v2.borchero.com) for
18 | documentation on PyCave 2._
19 |
20 | ## Features
21 |
22 | - Support for GPU and multi-node training by implementing models in PyTorch and relying on
23 | [PyTorch Lightning](https://www.pytorchlightning.ai/)
24 | - Mini-batch training for all models such that they can be used on huge datasets
25 | - Well-structured implementation of models
26 |
27 | - High-level `Estimator` API allows for easy usage such that models feel and behave like in
28 | scikit-learn
29 |   - Medium-level `LightningModule` implements the training algorithm
30 | - Low-level PyTorch `Module` manages the model parameters
31 |
32 | ## Installation
33 |
34 | PyCave is available via `pip`:
35 |
36 | ```bash
37 | pip install pycave
38 | ```
39 |
40 | If you are using [Poetry](https://python-poetry.org/):
41 |
42 | ```bash
43 | poetry add pycave
44 | ```
45 |
46 | ## Usage
47 |
48 | If you've ever used scikit-learn, you'll feel right at home when using PyCave. First, let's create
49 | some artificial data to work with:
50 |
51 | ```python
52 | import torch
53 |
54 | X = torch.cat([
55 | torch.randn(10000, 8) - 5,
56 | torch.randn(10000, 8),
57 | torch.randn(10000, 8) + 5,
58 | ])
59 | ```
60 |
61 | This dataset consists of three clusters with 8-dimensional datapoints. If you want to fit a K-Means
62 | model to find the clusters' centroids, it's as easy as:
63 |
64 | ```python
65 | from pycave.clustering import KMeans
66 |
67 | estimator = KMeans(3)
68 | estimator.fit(X)
69 |
70 | # Once the estimator is fitted, it provides various properties. One of them is
71 | # the `model_` property which yields the PyTorch module with the fitted parameters.
72 | print("Centroids are:")
73 | print(estimator.model_.centroids)
74 | ```
75 |
76 | Due to the high-level estimator API, the usage for all machine learning models is similar. The API
77 | documentation provides more detailed information about parameters that can be passed to estimators
78 | and which methods are available.
79 |
80 | ### GPU and Multi-Node training
81 |
82 | For GPU- and multi-node training, PyCave leverages PyTorch Lightning. The hardware that training
83 | runs on is determined by the
84 | [Trainer](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.trainer.html#pytorch_lightning.trainer.trainer.Trainer)
85 | class. Its
86 | [**init**](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.trainer.html#pytorch_lightning.trainer.trainer.Trainer.__init__)
87 | method provides various configuration options.
88 |
89 | If you want to run K-Means with a GPU, you can pass the options `accelerator='gpu'` and `devices=1`
90 | to the estimator's initializer:
91 |
92 | ```python
93 | estimator = KMeans(3, trainer_params=dict(accelerator='gpu', devices=1))
94 | ```
95 |
96 | Similarly, if you want to train on 4 nodes simultaneously where each node has one GPU available,
97 | you can specify this as follows:
98 |
99 | ```python
100 | estimator = KMeans(3, trainer_params=dict(num_nodes=4, accelerator='gpu', devices=1))
101 | ```
102 |
103 | In fact, **you do not need to change anything else in your code**.
104 |
105 | ### Implemented Models
106 |
107 | Currently, PyCave implements three different models:
108 |
109 | - [GaussianMixture](https://pycave.borchero.com/sites/generated/bayes/gmm/pycave.bayes.GaussianMixture.html)
110 | - [MarkovChain](https://pycave.borchero.com/sites/generated/bayes/markov_chain/pycave.bayes.MarkovChain.html)
111 | - [K-Means](https://pycave.borchero.com/sites/generated/clustering/kmeans/pycave.clustering.KMeans.html)
112 |
113 | ## License
114 |
115 | PyCave is licensed under the [MIT License](https://github.com/borchero/pycave/blob/main/LICENSE).
116 |
--------------------------------------------------------------------------------
/pycave/bayes/core/normal.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ._jit import jit_log_normal, jit_sample_normal
3 | from .types import CovarianceType
4 |
5 |
6 | def cholesky_precision(covariances: torch.Tensor, covariance_type: CovarianceType) -> torch.Tensor:
7 | """
8 | Computes the Cholesky decompositions of the precision matrices induced by the provided
9 | covariance matrices.
10 |
11 | Args:
12 | covariances: A tensor of shape ``[num_components, dim, dim]``, ``[dim, dim]``,
13 | ``[num_components, dim]``, ``[dim]`` or ``[num_components]`` depending on the
14 | ``covariance_type``. These are the covariance matrices of multivariate Normal
15 | distributions.
16 | covariance_type: The type of covariance for the covariance matrices given.
17 |
18 | Returns:
19 | A tensor of the same shape as ``covariances``, providing the lower-triangular Cholesky
20 | decompositions of the precision matrices.
21 | """
22 | if covariance_type in ("tied", "full"):
23 | # Compute Cholesky decomposition
24 | cholesky = torch.linalg.cholesky(covariances)
25 | # Invert
26 | num_features = covariances.size(-1)
27 | target = torch.eye(num_features, dtype=covariances.dtype, device=covariances.device)
28 | if covariance_type == "full":
29 | num_components = covariances.size(0)
30 | target = target.unsqueeze(0).expand(num_components, -1, -1)
31 | return torch.linalg.solve_triangular(cholesky, target, upper=False).transpose(-2, -1)
32 |
33 | # "Simple" kind of covariance
34 | return covariances.sqrt().reciprocal()
35 |
36 |
37 | def covariance(cholesky_precisions: torch.Tensor, covariance_type: CovarianceType) -> torch.Tensor:
38 | """
39 |     Computes the covariance matrices from the provided Cholesky decompositions of the precision
40 | matrices. This function is the inverse of :meth:`cholesky_precision`.
41 |
42 | Args:
43 | cholesky_precisions: A tensor of shape ``[num_components, dim, dim]``, ``[dim, dim]``,
44 | ``[num_components, dim]``, ``[dim]`` or ``[num_components]`` depending on the
45 | ``covariance_type``. These are the Cholesky decompositions of the precisions of
46 | multivariate Normal distributions.
47 | covariance_type: The type of covariance for the covariance matrices given.
48 |
49 | Returns:
50 | A tensor of the same shape as ``cholesky_precisions``, providing the covariance matrices
51 | corresponding to the given Cholesky-decomposed precision matrices.
52 | """
53 | if covariance_type in ("tied", "full"):
54 |         cholesky_covars = torch.linalg.inv(cholesky_precisions)
55 |         if covariance_type == "tied":
56 |             return torch.matmul(cholesky_covars.T, cholesky_covars)
57 |         return torch.bmm(cholesky_covars.transpose(1, 2), cholesky_covars)
58 |
59 | # "Simple" kind of covariance
60 | return (cholesky_precisions**2).reciprocal()
61 |
62 |
63 | def log_normal(
64 | x: torch.Tensor,
65 | means: torch.Tensor,
66 | precisions_cholesky: torch.Tensor,
67 | covariance_type: CovarianceType,
68 | ) -> torch.Tensor:
69 | """
70 | Computes the log-probability of the given data for multiple multivariate Normal distributions
71 | defined by their means and covariances.
72 |
73 | Args:
74 | x: A tensor of shape ``[num_datapoints, dim]``. This is the data to compute the
75 | log-probability for.
76 | means: A tensor of shape ``[num_components, dim]``. These are the means of the multivariate
77 | Normal distributions.
78 | precisions_cholesky: A tensor of shape ``[num_components, dim, dim]``, ``[dim, dim]``,
79 | ``[num_components, dim]``, ``[dim]`` or ``[num_components]`` depending on the
80 | ``covariance_type``. These are the upper-triangular Cholesky matrices for the inverse
81 | covariance matrices (i.e. precision matrices) of the multivariate Normal distributions.
82 | covariance_type: The type of covariance for the covariance matrices given.
83 |
84 | Returns:
85 | A tensor of shape ``[num_datapoints, num_components]`` with the log-probabilities for each
86 | datapoint and each multivariate Normal distribution.
87 | """
88 | return jit_log_normal(x, means, precisions_cholesky, covariance_type)
89 |
90 |
91 | def sample_normal(
92 | num: int,
93 | mean: torch.Tensor,
94 | cholesky_precisions: torch.Tensor,
95 | covariance_type: CovarianceType,
96 | ) -> torch.Tensor:
97 | """
98 | Samples the given number of times from the multivariate Normal distribution described by the
99 | mean and Cholesky precision.
100 |
101 | Args:
102 | num: The number of times to sample.
103 |         mean: A tensor of shape ``[dim]`` with the mean of the distribution to sample from.
104 |         cholesky_precisions: A tensor of shape ``[dim, dim]``, ``[dim]`` or ``[1]``
105 | depending on the ``covariance_type``. This is the corresponding Cholesky precision
106 | matrix for the mean.
107 | covariance_type: The type of covariance for the covariance matrix given.
108 |
109 | Returns:
110 | A tensor of shape ``[num_samples, dim]`` with the samples from the Normal distribution.
111 | """
112 | return jit_sample_normal(num, mean, cholesky_precisions, covariance_type)
113 |
--------------------------------------------------------------------------------
/tests/bayes/core/benchmark_log_normal.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | import numpy as np
3 | import torch
4 | from pytest_benchmark.fixture import BenchmarkFixture # type: ignore
5 | from sklearn.mixture._gaussian_mixture import _compute_precision_cholesky # type: ignore
6 | from sklearn.mixture._gaussian_mixture import _estimate_log_gaussian_prob # type: ignore
7 | from torch.distributions import MultivariateNormal
8 | from pycave.bayes.core import cholesky_precision, log_normal
9 |
10 |
11 | def test_log_normal_spherical(benchmark: BenchmarkFixture):
12 | data = torch.randn(10000, 100)
13 | means = torch.randn(50, 100)
14 | precisions = cholesky_precision(torch.rand(50), "spherical")
15 | benchmark(log_normal, data, means, precisions, covariance_type="spherical")
16 |
17 |
18 | def test_torch_log_normal_spherical(benchmark: BenchmarkFixture):
19 | data = torch.randn(10000, 100)
20 | means = torch.randn(50, 100)
21 | covars = torch.rand(50)
22 | covar_matrices = torch.stack([torch.eye(means.size(-1)) * c for c in covars])
23 |
24 | cholesky = torch.linalg.cholesky(covar_matrices)
25 | distribution = MultivariateNormal(means, scale_tril=cholesky, validate_args=False)
26 | benchmark(distribution.log_prob, data.unsqueeze(1))
27 |
28 |
29 | def test_numpy_log_normal_spherical(benchmark: BenchmarkFixture):
30 | data = np.random.randn(10000, 100)
31 | means = np.random.randn(50, 100)
32 | covars = np.random.rand(50)
33 | benchmark(_estimate_log_gaussian_prob, data, means, covars, "spherical") # type: ignore
34 |
35 |
36 | # -------------------------------------------------------------------------------------------------
37 |
38 |
39 | def test_log_normal_diag(benchmark: BenchmarkFixture):
40 | data = torch.randn(10000, 100)
41 | means = torch.randn(50, 100)
42 | precisions = cholesky_precision(torch.rand(50, 100), "diag")
43 | benchmark(log_normal, data, means, precisions, covariance_type="diag")
44 |
45 |
46 | def test_torch_log_normal_diag(benchmark: BenchmarkFixture):
47 | data = torch.randn(10000, 100)
48 | means = torch.randn(50, 100)
49 | covars = torch.rand(50, 100)
50 | covar_matrices = torch.stack([torch.diag(c) for c in covars])
51 |
52 | cholesky = torch.linalg.cholesky(covar_matrices)
53 | distribution = MultivariateNormal(means, scale_tril=cholesky, validate_args=False)
54 | benchmark(distribution.log_prob, data.unsqueeze(1))
55 |
56 |
57 | def test_numpy_log_normal_diag(benchmark: BenchmarkFixture):
58 | data = np.random.randn(10000, 100)
59 | means = np.random.randn(50, 100)
60 | covars = np.random.rand(50, 100)
61 | benchmark(_estimate_log_gaussian_prob, data, means, covars, "diag") # type: ignore
62 |
63 |
64 | # -------------------------------------------------------------------------------------------------
65 |
66 |
67 | def test_log_normal_full(benchmark: BenchmarkFixture):
68 | data = torch.randn(10000, 100)
69 | means = torch.randn(50, 100)
70 | A = torch.randn(50, 1000, 100)
71 | covars = A.permute(0, 2, 1).bmm(A)
72 | precisions = cholesky_precision(covars, "full")
73 | benchmark(log_normal, data, means, precisions, covariance_type="full")
74 |
75 |
76 | def test_torch_log_normal_full(benchmark: BenchmarkFixture):
77 | data = torch.randn(10000, 100)
78 | means = torch.randn(50, 100)
79 | A = torch.randn(50, 1000, 100)
80 | covars = A.permute(0, 2, 1).bmm(A)
81 |
82 | cholesky = torch.linalg.cholesky(covars)
83 | distribution = MultivariateNormal(means, scale_tril=cholesky, validate_args=False)
84 | benchmark(distribution.log_prob, data.unsqueeze(1))
85 |
86 |
87 | def test_numpy_log_normal_full(benchmark: BenchmarkFixture):
88 | data = np.random.randn(10000, 100)
89 | means = np.random.randn(50, 100)
90 | A = np.random.randn(50, 1000, 100)
91 | covars = np.matmul(np.transpose(A, (0, 2, 1)), A)
92 |
93 | precisions = _compute_precision_cholesky(covars, "full") # type: ignore
94 | benchmark(
95 | _estimate_log_gaussian_prob, # type: ignore
96 | data,
97 | means,
98 | precisions,
99 | covariance_type="full",
100 | )
101 |
102 |
103 | # -------------------------------------------------------------------------------------------------
104 |
105 |
106 | def test_log_normal_tied(benchmark: BenchmarkFixture):
107 | data = torch.randn(10000, 100)
108 | means = torch.randn(50, 100)
109 | A = torch.randn(1000, 100)
110 | covars = A.t().mm(A)
111 | precisions = cholesky_precision(covars, "tied")
112 | benchmark(log_normal, data, means, precisions, covariance_type="tied")
113 |
114 |
115 | def test_torch_log_normal_tied(benchmark: BenchmarkFixture):
116 | data = torch.randn(10000, 100)
117 | means = torch.randn(50, 100)
118 | A = torch.randn(1000, 100)
119 | covars = A.t().mm(A)
120 |
121 | cholesky = torch.linalg.cholesky(covars)
122 | distribution = MultivariateNormal(means, scale_tril=cholesky, validate_args=False)
123 | benchmark(distribution.log_prob, data.unsqueeze(1))
124 |
125 |
126 | def test_numpy_log_normal_tied(benchmark: BenchmarkFixture):
127 | data = np.random.randn(10000, 100)
128 | means = np.random.randn(50, 100)
129 | A = np.random.randn(1000, 100)
130 | covars = A.T.dot(A)
131 |
132 | precisions = _compute_precision_cholesky(covars, "tied") # type: ignore
133 | benchmark(
134 | _estimate_log_gaussian_prob, # type: ignore
135 | data,
136 | means,
137 | precisions,
138 | covariance_type="tied",
139 | )
140 |
--------------------------------------------------------------------------------
/pycave/bayes/gmm/model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Tuple
3 | import numpy as np
4 | import torch
5 | from lightkit.nn import Configurable
6 | from torch import jit, nn
7 | from pycave.bayes.core import covariance, covariance_shape, CovarianceType
8 | from pycave.bayes.core._jit import jit_log_normal, jit_sample_normal
9 |
10 |
11 | @dataclass
12 | class GaussianMixtureModelConfig:
13 | """
14 | Configuration class for a Gaussian mixture model.
15 |
16 | See also:
17 | :class:`GaussianMixtureModel`
18 | """
19 |
20 | #: The number of components in the GMM.
21 | num_components: int
22 | #: The number of features for the GMM's components.
23 | num_features: int
24 | #: The type of covariance to use for the components.
25 | covariance_type: CovarianceType
26 |
27 |
28 | class GaussianMixtureModel(Configurable[GaussianMixtureModelConfig], nn.Module):
29 | """
30 | PyTorch module for a Gaussian mixture model.
31 |
32 | Covariances are represented via their Cholesky decomposition for computational efficiency. The
33 | model does not have trainable parameters.
34 | """
35 |
36 | #: The probabilities of each component, buffer of shape ``[num_components]``.
37 | component_probs: torch.Tensor
38 | #: The means of each component, buffer of shape ``[num_components, num_features]``.
39 | means: torch.Tensor
40 |     #: The Cholesky factors of the components' precision matrices, buffer with a shape dependent
41 | #: on the covariance type, see :class:`CovarianceType`.
42 | precisions_cholesky: torch.Tensor
43 |
44 | def __init__(self, config: GaussianMixtureModelConfig):
45 | """
46 | Args:
47 | config: The configuration to use for initializing the module's buffers.
48 | """
49 | super().__init__(config)
50 |
51 | self.covariance_type = config.covariance_type
52 |
53 | self.register_buffer("component_probs", torch.empty(config.num_components))
54 | self.register_buffer("means", torch.empty(config.num_components, config.num_features))
55 |
56 | shape = covariance_shape(
57 | config.num_components, config.num_features, config.covariance_type
58 | )
59 | self.register_buffer("precisions_cholesky", torch.empty(shape))
60 |
61 | self.reset_parameters()
62 |
63 | @jit.unused # type: ignore
64 | @property
65 | def covariances(self) -> torch.Tensor:
66 | """
67 | The covariance matrices learnt for the GMM's components.
68 |
69 | The shape of the tensor depends on the covariance type, see :class:`CovarianceType`.
70 | """
71 | return covariance(self.precisions_cholesky, self.covariance_type) # type: ignore
72 |
73 | @jit.unused
74 | def reset_parameters(self) -> None:
75 | """
76 | Resets the parameters of the GMM.
77 |
78 | - Component probabilities are initialized via uniform sampling and normalization.
79 | - Means are initialized randomly from a Standard Normal.
80 |         - Cholesky precisions are initialized via uniform sampling. For the ``full`` and ``tied``
81 |           covariance types, the sampled matrices are additionally made lower-triangular.
82 | """
83 | nn.init.uniform_(self.component_probs)
84 | self.component_probs.div_(self.component_probs.sum())
85 |
86 | nn.init.normal_(self.means)
87 |
88 | nn.init.uniform_(self.precisions_cholesky)
89 | if self.covariance_type in ("full", "tied"):
90 | self.precisions_cholesky.tril_()
91 |
92 | def forward(self, data: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
93 | """
94 | Computes the log-probability of observing each of the provided datapoints for each of the
95 | GMM's components.
96 |
97 | Args:
98 | data: A tensor of shape ``[num_datapoints, num_features]`` for which to compute the
99 | log-probabilities.
100 |
101 | Returns:
102 | - A tensor of shape ``[num_datapoints, num_components]`` with the log-responsibilities
103 |               for each datapoint and component. These are the logits of the Categorical
104 |               distribution over the components.
105 | - A tensor of shape ``[num_datapoints]`` with the log-likelihood of each datapoint.
106 | """
107 | log_probabilities = jit_log_normal(
108 | data, self.means, self.precisions_cholesky, self.covariance_type
109 | )
110 | log_responsibilities = log_probabilities + self.component_probs.log()
111 | log_prob = log_responsibilities.logsumexp(1, keepdim=True)
112 | return log_responsibilities - log_prob, log_prob.squeeze(1)
113 |
114 | def sample(self, num_datapoints: int) -> torch.Tensor:
115 | """
116 | Samples the provided number of datapoints from the GMM.
117 |
118 | Args:
119 | num_datapoints: The number of datapoints to sample.
120 |
121 | Returns:
122 | A tensor of shape ``[num_datapoints, num_features]`` with the random samples.
123 |
124 | Attention:
125 | This method does not automatically perform batching. If you need to sample many
126 | datapoints, call this method multiple times.
127 | """
128 |         # First, we sample the number of datapoints to generate for each component
129 | component_counts = np.random.multinomial(num_datapoints, self.component_probs.numpy())
130 |
131 |         # Then, we generate the datapoints for each component
132 | result = []
133 | for i, count in enumerate(component_counts):
134 | sample = jit_sample_normal(
135 | count.item(),
136 | self.means[i],
137 | self._get_component_precision(i),
138 | self.covariance_type,
139 | )
140 | result.append(sample)
141 |
142 | return torch.cat(result, dim=0)
143 |
144 | def _get_component_precision(self, component: int) -> torch.Tensor:
145 | if self.covariance_type == "tied":
146 | return self.precisions_cholesky
147 | return self.precisions_cholesky[component]
148 |
--------------------------------------------------------------------------------
/pycave/bayes/gmm/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional
2 | import torch
3 | from torchmetrics import Metric
4 | from pycave.bayes.core import covariance_shape, CovarianceType
5 |
6 |
7 | class PriorAggregator(Metric):
8 | """
9 |     The prior aggregator aggregates component probabilities over batches and processes.
10 | """
11 |
12 | full_state_update = False
13 |
14 | def __init__(
15 | self,
16 | num_components: int,
17 | *,
18 | dist_sync_fn: Optional[Callable[[Any], Any]] = None,
19 | ):
20 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
21 |
22 | self.responsibilities: torch.Tensor
23 | self.add_state("responsibilities", torch.zeros(num_components), dist_reduce_fx="sum")
24 |
25 | def update(self, responsibilities: torch.Tensor) -> None:
26 | # Responsibilities have shape [N, K]
27 | self.responsibilities.add_(responsibilities.sum(0))
28 |
29 | def compute(self) -> torch.Tensor:
30 | return self.responsibilities / self.responsibilities.sum()
31 |
32 |
33 | class MeanAggregator(Metric):
34 | """
35 | The mean aggregator aggregates component means over batches and processes.
36 | """
37 |
38 | full_state_update = False
39 |
40 | def __init__(
41 | self,
42 | num_components: int,
43 | num_features: int,
44 | *,
45 | dist_sync_fn: Optional[Callable[[Any], Any]] = None,
46 | ):
47 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
48 |
49 | self.mean_sum: torch.Tensor
50 | self.add_state("mean_sum", torch.zeros(num_components, num_features), dist_reduce_fx="sum")
51 |
52 | self.component_weights: torch.Tensor
53 | self.add_state("component_weights", torch.zeros(num_components), dist_reduce_fx="sum")
54 |
55 | def update(self, data: torch.Tensor, responsibilities: torch.Tensor) -> None:
56 | # Data has shape [N, D]
57 | # Responsibilities have shape [N, K]
58 | self.mean_sum.add_(responsibilities.t().matmul(data))
59 | self.component_weights.add_(responsibilities.sum(0))
60 |
61 | def compute(self) -> torch.Tensor:
62 | return self.mean_sum / self.component_weights.unsqueeze(1)
63 |
64 |
65 | class CovarianceAggregator(Metric):
66 | """
67 | The covariance aggregator aggregates component covariances over batches and processes.
68 | """
69 |
70 | full_state_update = False
71 |
72 | def __init__(
73 | self,
74 | num_components: int,
75 | num_features: int,
76 | covariance_type: CovarianceType,
77 | reg: float,
78 | *,
79 | dist_sync_fn: Optional[Callable[[Any], Any]] = None,
80 | ):
81 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
82 |
83 | self.num_components = num_components
84 | self.num_features = num_features
85 | self.covariance_type = covariance_type
86 | self.reg = reg
87 |
88 | self.covariance_sum: torch.Tensor
89 | self.add_state(
90 | "covariance_sum",
91 | torch.zeros(covariance_shape(num_components, num_features, covariance_type)),
92 | dist_reduce_fx="sum",
93 | )
94 |
95 | self.component_weights: torch.Tensor
96 | self.add_state("component_weights", torch.zeros(num_components), dist_reduce_fx="sum")
97 |
98 | def update(
99 | self, data: torch.Tensor, responsibilities: torch.Tensor, means: torch.Tensor
100 | ) -> None:
101 | data_component_weights = responsibilities.sum(0)
102 | self.component_weights.add_(data_component_weights)
103 |
104 | if self.covariance_type in ("spherical", "diag"):
105 | x_prob = torch.matmul(responsibilities.t(), data.square())
106 | m_prob = data_component_weights.unsqueeze(-1) * means.square()
107 | xm_prob = means * torch.matmul(responsibilities.t(), data)
108 | covars = x_prob - 2 * xm_prob + m_prob
109 | if self.covariance_type == "diag":
110 | self.covariance_sum.add_(covars)
111 | else: # covariance_type == "spherical"
112 | self.covariance_sum.add_(covars.mean(1))
113 | elif self.covariance_type == "tied":
114 | # This is taken from https://github.com/scikit-learn/scikit-learn/blob/
115 | # 844b4be24d20fc42cc13b957374c718956a0db39/sklearn/mixture/_gaussian_mixture.py#L183
116 | x_sq = data.T.matmul(data)
117 | mean_sq = (data_component_weights * means.T).matmul(means)
118 | self.covariance_sum.add_(x_sq - mean_sq)
119 | else: # covariance_type == "full":
120 | # We iterate over each component since this is typically faster...
121 | for i in range(self.num_components):
122 | component_diff = data - means[i]
123 | covars = (responsibilities[:, i].unsqueeze(1) * component_diff).T.matmul(
124 | component_diff
125 | )
126 | self.covariance_sum[i].add_(covars)
127 |
128 | def compute(self) -> torch.Tensor:
129 | if self.covariance_type == "diag":
130 | return self.covariance_sum / self.component_weights.unsqueeze(-1) + self.reg
131 | if self.covariance_type == "spherical":
132 | return self.covariance_sum / self.component_weights + self.reg * self.num_features
133 | if self.covariance_type == "tied":
134 | result = self.covariance_sum / self.component_weights.sum()
135 | shape = result.size()
136 | result = result.flatten()
137 | result[:: self.num_features + 1].add_(self.reg)
138 | return result.view(shape)
139 | # covariance_type == "full"
140 | result = self.covariance_sum / self.component_weights.unsqueeze(-1).unsqueeze(-1)
141 | diag_mask = (
142 | torch.eye(self.num_features, device=result.device, dtype=result.dtype)
143 | .bool()
144 | .unsqueeze(0)
145 | .expand(self.num_components, -1, -1)
146 | )
147 | result[diag_mask] += self.reg
148 | return result
149 |
--------------------------------------------------------------------------------
/pycave/bayes/markov_chain/model.py:
--------------------------------------------------------------------------------
1 | # pyright: reportPrivateUsage=false, reportUnknownParameterType=false
2 | from dataclasses import dataclass
3 | from typing import overload
4 | import torch
5 | import torch._jit_internal as _jit
6 | from lightkit.nn import Configurable
7 | from torch import jit, nn
8 | from torch.nn.utils.rnn import PackedSequence
9 |
10 |
11 | @dataclass
12 | class MarkovChainModelConfig:
13 | """
14 | Configuration class for a Markov chain model.
15 |
16 | See also:
17 | :class:`MarkovChainModel`
18 | """
19 |
20 | #: The number of states that are managed by the Markov chain.
21 | num_states: int
22 |
23 |
24 | class MarkovChainModel(Configurable[MarkovChainModelConfig], nn.Module):
25 | """
26 | PyTorch module for a Markov chain.
27 |
28 | The initial state probabilities as well as the transition probabilities are non-trainable
29 | parameters.
30 | """
31 |
32 | def __init__(self, config: MarkovChainModelConfig):
33 | """
34 | Args:
35 | config: The configuration to use for initializing the module's buffers.
36 | """
37 | super().__init__(config)
38 |
39 | #: The probabilities for the initial states, buffer of shape ``[num_states]``.
40 | self.initial_probs: torch.Tensor
41 | self.register_buffer("initial_probs", torch.empty(config.num_states))
42 |
43 | #: The transition probabilities between all states, buffer of shape
44 | #: ``[num_states, num_states]``.
45 | self.transition_probs: torch.Tensor
46 | self.register_buffer("transition_probs", torch.empty(config.num_states, config.num_states))
47 |
48 | self.reset_parameters()
49 |
50 | @jit.unused
51 | def reset_parameters(self) -> None:
52 | """
53 | Resets the parameters of the Markov model.
54 |
55 | Initial and transition probabilities are sampled uniformly.
56 | """
57 | nn.init.uniform_(self.initial_probs)
58 | self.initial_probs.div_(self.initial_probs.sum())
59 |
60 | nn.init.uniform_(self.transition_probs)
61 | self.transition_probs.div_(self.transition_probs.sum(1, keepdim=True))
62 |
63 | @overload
64 | @_jit._overload_method # pylint: disable=protected-access
65 | def forward(self, sequences: torch.Tensor) -> torch.Tensor:
66 | ...
67 |
68 | @overload
69 | @_jit._overload_method # pylint: disable=protected-access
70 | def forward(self, sequences: PackedSequence) -> torch.Tensor: # type: ignore
71 | ...
72 |
73 | def forward(self, sequences) -> torch.Tensor: # type: ignore
74 | """
75 | Computes the log-probability of observing each of the provided sequences.
76 |
77 | Args:
78 | sequences: Tensor of shape ``[num_sequences, sequence_length]`` or a packed sequence.
79 | Packed sequences should be used whenever the sequence lengths differ. All
80 | sequences must contain state indices of dtype ``long``.
81 |
82 | Returns:
83 |             A tensor of shape ``[num_sequences]``, providing the log-probability of each
84 | sequence.
85 | """
86 | if isinstance(sequences, torch.Tensor):
87 | log_probs = self.initial_probs[sequences[:, 0]].log()
88 | sources = sequences[:, :-1]
89 | targets = sequences[:, 1:].unsqueeze(-1)
90 | transition_probs = self.transition_probs[sources].gather(-1, targets).squeeze(-1)
91 | return log_probs + transition_probs.log().sum(-1)
92 | if isinstance(sequences, PackedSequence):
93 | data = sequences.data
94 | batch_sizes = sequences.batch_sizes
95 |
96 | log_probs = self.initial_probs[data[: batch_sizes[0]]].log()
97 | offset = 0
98 | for prev_size, curr_size in zip(batch_sizes, batch_sizes[1:]):
99 | log_probs[:curr_size] += self.transition_probs[
100 | data[offset : offset + curr_size],
101 | data[offset + prev_size : offset + prev_size + curr_size],
102 | ].log()
103 | offset += prev_size
104 |
105 | if sequences.unsorted_indices is not None:
106 | return log_probs[sequences.unsorted_indices]
107 | return log_probs
108 | raise ValueError("unsupported input type")
109 |
110 | def sample(self, num_sequences: int, sequence_length: int) -> torch.Tensor:
111 | """
112 | Samples random sequences from the Markov chain.
113 |
114 | Args:
115 | num_sequences: The number of sequences to sample.
116 | sequence_length: The length of all sequences to sample.
117 |
118 | Returns:
119 | Tensor of shape ``[num_sequences, sequence_length]`` with dtype ``long``, providing the
120 | sampled states.
121 | """
122 | samples = torch.empty(
123 | num_sequences, sequence_length, device=self.transition_probs.device, dtype=torch.long
124 | )
125 | samples[:, 0] = self.initial_probs.multinomial(num_sequences, replacement=True)
126 | for i in range(1, sequence_length):
127 | samples[:, i] = self.transition_probs[samples[:, i - 1]].multinomial(1).squeeze(-1)
128 | return samples
129 |
130 | def stationary_distribution(
131 | self, tol: float = 1e-7, max_iterations: int = 1000
132 | ) -> torch.Tensor:
133 | """
134 | Computes the stationary distribution of the Markov chain using power iteration.
135 |
136 | Args:
137 | tol: The tolerance to use when checking if the power iteration has converged. As soon
138 |                 as the norm of the difference between the vectors of two successive iterations is
139 |                 below this value, the iteration is stopped.
140 | max_iterations: The maximum number of iterations to run if the tolerance does not
141 | indicate convergence.
142 |
143 | Returns:
144 | A tensor of shape ``[num_states]`` with the stationary distribution (i.e. the
145 |             eigenvector corresponding to the largest eigenvalue of the transition matrix,
146 | normalized to describe a probability distribution).
147 | """
148 | A = self.transition_probs.t()
149 | v = torch.rand(A.size(0), device=A.device, dtype=A.dtype)
150 |
151 | for _ in range(max_iterations):
152 | v_old = v
153 | v = A.mv(v)
154 | v = v / v.norm()
155 | if (v - v_old).norm() < tol:
156 | break
157 |
158 | return v / v.sum()
159 |
--------------------------------------------------------------------------------
/docs/sites/benchmark.rst:
--------------------------------------------------------------------------------
1 | Benchmarks
2 | ==========
3 |
4 | In order to evaluate the runtime performance of PyCave, we run an exhaustive set of experiments to
5 | compare against the implementation found in scikit-learn. Evaluations are run at varying dataset
6 | sizes.
7 |
8 | All benchmarks are run on an instance with an Intel Xeon E5-2630 v4 CPU (2.2 GHz). We use at most 4
9 | cores and 60 GiB of memory. Also, there is a single GeForce GTX 1080 Ti GPU (11 GiB memory)
10 | available. For the performance measures, each benchmark is run at least 5 times.
11 |
12 | Gaussian Mixture
13 | ----------------
14 |
15 | Setup
16 | ^^^^^
17 |
18 | For measuring the performance of fitting a Gaussian mixture model, we fix the number of iterations
19 | after initialization to 100 so that differences in the convergence criterion do not affect the
20 | measurements. For initialization, we further pass the known means that were used to generate the
21 | data to avoid degenerate covariance matrices. Thus, all benchmarks essentially measure the
22 | performance after K-means initialization has been run. Benchmarks for K-means itself are listed below.
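
A minimal sketch of this setup, mirroring the configuration used in
``tests/bayes/gmm/benchmark_gmm_estimator.py``:

.. code-block:: python

   from pycave.bayes import GaussianMixture
   from tests._data.gmm import sample_gmm

   # Generate data from a known mixture so that the true means can be used for initialization.
   data, means = sample_gmm(
       num_datapoints=100000, num_features=32, num_components=16, covariance_type="diag"
   )

   estimator = GaussianMixture(
       16,
       covariance_type="diag",
       init_means=means,                     # initialize with the known generating means
       convergence_tolerance=0,              # never stop early due to the convergence criterion
       covariance_regularization=1e-3,       # same regularization as passed to scikit-learn
       trainer_params=dict(max_epochs=100),  # fixed number of iterations after initialization
   )
   estimator.fit(data)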
23 |
24 | Results
25 | ^^^^^^^
26 |
27 | .. list-table:: Training Duration for Diagonal Covariance (``[num_datapoints, num_features] -> num_components``)
28 | :header-rows: 1
29 | :stub-columns: 1
30 | :widths: 3 2 2 2 2 2
31 |
32 | * -
33 | - Scikit-Learn
34 | - PyCave CPU (full)
35 | - PyCave CPU (batches)
36 | - PyCave GPU (full)
37 | - PyCave GPU (batches)
38 | * - ``[10k, 8] -> 4``
39 | - **352 ms**
40 | - 649 ms
41 | - 3.9 s
42 | - 358 ms
43 | - 3.6 s
44 | * - ``[100k, 32] -> 16``
45 | - 18.4 s
46 | - 4.3 s
47 | - 10.0 s
48 | - **527 ms**
49 | - 3.9 s
50 | * - ``[1M, 64] -> 64``
51 | - 730 s
52 | - 196 s
53 | - 284 s
54 | - **7.7 s**
55 | - 15.3 s
56 |
57 | .. list-table:: Training Duration for Tied Covariance (``[num_datapoints, num_features] -> num_components``)
58 | :header-rows: 1
59 | :stub-columns: 1
60 | :widths: 3 2 2 2 2 2
61 |
62 | * -
63 | - Scikit-Learn
64 | - PyCave CPU (full)
65 | - PyCave CPU (batches)
66 | - PyCave GPU (full)
67 | - PyCave GPU (batches)
68 | * - ``[10k, 8] -> 4``
69 | - 699 ms
70 | - 570 ms
71 | - 3.6 s
72 | - **356 ms**
73 | - 3.3 s
74 | * - ``[100k, 32] -> 16``
75 | - 72.2 s
76 | - 12.1 s
77 | - 16.1 s
78 | - **919 ms**
79 | - 3.8 s
80 | * - ``[1M, 64] -> 64``
81 | - --
82 | - --
83 | - --
84 | - --
85 | - **63.4 s**
86 |
87 | .. list-table:: Training Duration for Full Covariance (``[num_datapoints, num_features] -> num_components``)
88 | :header-rows: 1
89 | :stub-columns: 1
90 | :widths: 3 2 2 2 2 2
91 |
92 | * -
93 | - Scikit-Learn
94 | - PyCave CPU (full)
95 | - PyCave CPU (batches)
96 | - PyCave GPU (full)
97 | - PyCave GPU (batches)
98 | * - ``[10k, 8] -> 4``
99 | - 1.1 s
100 | - 679 ms
101 | - 4.1 s
102 | - **648 ms**
103 | - 4.4 s
104 | * - ``[100k, 32] -> 16``
105 | - 110 s
106 | - 13.5 s
107 | - 21.2 s
108 | - **2.4 s**
109 | - 7.8 s
110 |
111 | Summary
112 | ^^^^^^^
113 |
114 | PyCave's implementation of the Gaussian mixture model is markedly more efficient than the one found
115 | in scikit-learn. Even on the CPU, PyCave already outperforms scikit-learn significantly at 100k
116 | datapoints. When moving to the GPU, however, PyCave unfolds its full potential and yields speedups
117 | of around 100x. For datasets that do not fit into memory, mini-batch training is the only
118 | alternative. PyCave fully supports this, although training takes approximately twice as long as
119 | training on the full data. The reason is that the M-step of the EM algorithm needs to be split
120 | across epochs, which, in turn, requires replaying the E-step.
121 |
122 |
123 | K-Means
124 | -------
125 |
126 | Setup
127 | ^^^^^
128 |
129 | For the scikit-learn implementation, we use Lloyd's algorithm instead of Elkan's algorithm to have
130 | a useful comparison with PyCave (which implements Lloyd's algorithm).
131 |
132 | Further, we fix the number of iterations after initialization to 100 so that differences in the
133 | convergence criterion do not skew the measurements (see the sketch at the end of this page).
134 |
135 | Results
136 | ^^^^^^^
137 |
138 | .. list-table:: Training Duration for Random Initialization (``[num_datapoints, num_features] -> num_clusters``)
139 | :header-rows: 1
140 | :stub-columns: 1
141 | :widths: 3 2 2 2 2 2
142 |
143 | * -
144 | - Scikit-Learn
145 | - PyCave CPU (full)
146 | - PyCave CPU (batches)
147 | - PyCave GPU (full)
148 | - PyCave GPU (batches)
149 | * - ``[10k, 8] -> 4``
150 | - **13 ms**
151 | - 412 ms
152 | - 797 ms
153 | - 387 ms
154 | - 2.1 s
155 | * - ``[100k, 32] -> 16``
156 | - **311 ms**
157 | - 2.1 s
158 | - 3.4 s
159 | - 707 ms
160 | - 2.5 s
161 | * - ``[1M, 64] -> 64``
162 | - 10.0 s
163 | - 73.6 s
164 | - 58.1 s
165 | - **8.2 s**
166 | - 10.0 s
167 | * - ``[10M, 128] -> 128``
168 | - 254 s
169 | - --
170 | - --
171 | - --
172 | - **133 s**
173 |
174 | .. list-table:: Training Duration for K-Means++ Initialization (``[num_datapoints, num_features] -> num_clusters``)
175 | :header-rows: 1
176 | :stub-columns: 1
177 | :widths: 3 2 2 2 2 2
178 |
179 | * -
180 | - Scikit-Learn
181 | - PyCave CPU (full)
182 | - PyCave CPU (batches)
183 | - PyCave GPU (full)
184 | - PyCave GPU (batches)
185 | * - ``[10k, 8] -> 4``
186 | - **15 ms**
187 | - 170 ms
188 | - 930 ms
189 | - 431 ms
190 | - 2.4 s
191 | * - ``[100k, 32] -> 16``
192 | - **542 ms**
193 | - 2.3 s
194 | - 4.3 s
195 | - 840 ms
196 | - 3.2 s
197 | * - ``[1M, 64] -> 64``
198 | - 25.3 s
199 | - 93.4 s
200 | - 83.7 s
201 | - **13.1 s**
202 | - 17.1 s
203 | * - ``[10M, 128] -> 128``
204 | - 827 s
205 | - --
206 | - --
207 | - --
208 | - **369 s**
209 |
210 | Summary
211 | ^^^^^^^
212 |
213 | As it turns out, it is really hard to outperform the implementation found in scikit-learn.
214 | Especially when little data is available, the overhead of PyTorch and PyTorch Lightning renders
215 | PyCave comparatively slow. However, as more data becomes available, PyCave becomes comparatively
216 | faster and, when leveraging the GPU, it finally outperforms scikit-learn at a dataset size of 1M
217 | datapoints. Nonetheless, the improvement is marginal.
218 |
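219 | For reference, the K-Means configurations benchmarked above map to the estimator interface roughly
220 | as follows. This is a minimal sketch rather than the exact benchmark script; in particular, the
221 | trainer options required for GPU execution depend on the installed PyTorch Lightning version.
222 |
223 | .. code-block:: python
224 |
225 |     import torch
226 |     from pycave.clustering import KMeans
227 |
228 |     data = torch.randn(100_000, 32)  # [num_datapoints, num_features]
229 |
230 |     estimator = KMeans(
231 |         num_clusters=16,
232 |         init_strategy="kmeans++",  # or "random"
233 |         convergence_tolerance=0,   # disable early stopping to fix the number of iterations
234 |         batch_size=None,           # None trains on the full data as a single batch
235 |         trainer_params=dict(max_epochs=100),
236 |     )
237 |     estimator.fit(data)
238 |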
--------------------------------------------------------------------------------
/pycave/bayes/markov_chain/estimator.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import logging
3 | from typing import Any, cast, List
4 | import numpy as np
5 | import torch
6 | from lightkit import ConfigurableBaseEstimator
7 | from lightkit.data import DataLoader, dataset_from_tensors
8 | from torch.nn.utils.rnn import PackedSequence
9 | from torch.utils.data import Dataset
10 | from .lightning_module import MarkovChainLightningModule
11 | from .model import MarkovChainModel, MarkovChainModelConfig
12 | from .types import collate_sequences, collate_sequences_same_length, SequenceData
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class MarkovChain(ConfigurableBaseEstimator[MarkovChainModel]): # type: ignore
18 | """
19 | Probabilistic model for observed state transitions. The Markov chain is similar to the hidden
20 |     Markov model, except that the states are observed rather than hidden. More information on the
21 |     Markov chain is available on `Wikipedia <https://en.wikipedia.org/wiki/Markov_chain>`_.
22 |
23 | See also:
24 | .. currentmodule:: pycave.bayes.markov_chain
25 | .. autosummary::
26 | :nosignatures:
27 | :template: classes/pytorch_module.rst
28 |
29 | MarkovChainModel
30 | MarkovChainModelConfig
31 | """
32 |
33 | #: The fitted PyTorch module with all estimated parameters.
34 | model_: MarkovChainModel
35 |
36 | def __init__(
37 | self,
38 | num_states: int | None = None,
39 | *,
40 | symmetric: bool = False,
41 | batch_size: int | None = None,
42 | trainer_params: dict[str, Any] | None = None,
43 | ):
44 | """
45 | Args:
46 | num_states: The number of states that the Markov chain has. If not provided, it will
47 | be derived automatically when calling :meth:`fit`. Note that this requires a pass
48 | through the data. Consider setting this option explicitly if you're fitting a lot
49 | of data.
50 | symmetric: Whether the transitions between states should be considered symmetric.
51 | batch_size: The batch size to use when fitting the model. If not provided, the full
52 | data will be used as a single batch. Set this if the full data does not fit into
53 | memory.
54 | num_workers: The number of workers to use for loading the data. Only used if a PyTorch
55 | dataset is passed to :meth:`fit` or related methods.
56 | trainer_params: Initialization parameters to use when initializing a PyTorch Lightning
57 | trainer. By default, it disables various stdout logs unless PyCave is configured to
58 | do verbose logging. Checkpointing and logging are disabled regardless of the log
59 | level. This estimator further enforces the following parameters:
60 |
61 | - ``max_epochs=1``
62 | """
63 | super().__init__(
64 | user_params=trainer_params,
65 | overwrite_params=dict(max_epochs=1),
66 | )
67 |
68 | self.num_states = num_states
69 | self.symmetric = symmetric
70 | self.batch_size = batch_size
71 |
72 | def fit(self, sequences: SequenceData) -> MarkovChain:
73 | """
74 | Fits the Markov chain on the provided data and returns the fitted estimator.
75 |
76 | Args:
77 | sequences: The sequences to fit the Markov chain on.
78 |
79 | Returns:
80 | The fitted Markov chain.
81 | """
82 | config = MarkovChainModelConfig(
83 | num_states=self.num_states or _get_num_states(sequences),
84 | )
85 | self.model_ = MarkovChainModel(config)
86 |
87 | logger.info("Fitting Markov chain...")
88 | self.trainer().fit(
89 | MarkovChainLightningModule(self.model_, self.symmetric),
90 | self._init_data_loader(sequences),
91 | )
92 | return self
93 |
94 | def sample(self, num_sequences: int, sequence_length: int) -> torch.Tensor:
95 | """
96 | Samples state sequences from the fitted Markov chain.
97 |
98 | Args:
99 | num_sequences: The number of sequences to sample.
100 | sequence_length: The length of the sequences to sample.
101 |
102 | Returns:
103 | The sampled sequences as a tensor of shape ``[num_sequences, sequence_length]``.
104 |
105 | Note:
106 | This method does not parallelize across multiple processes, i.e. performs no
107 | synchronization.
108 | """
109 | return self.model_.sample(num_sequences, sequence_length)
110 |
111 | def score(self, sequences: SequenceData) -> float:
112 | """
113 | Computes the average negative log-likelihood (NLL) of observing the provided sequences. If
114 | you want to have NLLs for each individual sequence, use :meth:`score_samples` instead.
115 |
116 | Args:
117 |             sequences: The sequences for which to compute the average negative log-likelihood.
118 |
119 | Returns:
120 | The average NLL for all sequences.
121 |
122 | Note:
123 | See :meth:`score_samples` to obtain the NLL values for individual sequences.
124 | """
125 | result = self.trainer().test(
126 | MarkovChainLightningModule(self.model_),
127 | self._init_data_loader(sequences),
128 | verbose=False,
129 | )
130 | return result[0]["nll"]
131 |
132 | def score_samples(self, sequences: SequenceData) -> torch.Tensor:
133 | """
134 |         Computes the negative log-likelihood (NLL) of observing each of the provided sequences.
135 |
136 | Args:
137 | sequences: The sequences for which to compute the NLL.
138 |
139 | Returns:
140 | A tensor of shape ``[num_sequences]`` with the NLLs for each individual sequence.
141 |
142 | Attention:
143 | When calling this function in a multi-process environment, each process receives only
144 | a subset of the predictions. If you want to aggregate predictions, make sure to gather
145 | the values returned from this method.
146 | """
147 | result = self.trainer().predict(
148 | MarkovChainLightningModule(self.model_),
149 | self._init_data_loader(sequences),
150 | return_predictions=True,
151 | )
152 | return torch.stack(cast(List[torch.Tensor], result))
153 |
154 | def _init_data_loader(self, sequences: SequenceData) -> DataLoader[PackedSequence]:
155 | if isinstance(sequences, Dataset):
156 | return DataLoader(
157 | sequences,
158 | batch_size=self.batch_size or len(sequences), # type: ignore
159 | collate_fn=collate_sequences, # type: ignore
160 | )
161 |
162 | return DataLoader( # type: ignore
163 | dataset_from_tensors(sequences),
164 | batch_size=self.batch_size or len(sequences),
165 | collate_fn=collate_sequences_same_length,
166 | )
167 |
168 |
169 | def _get_num_states(data: SequenceData) -> int:
170 | if isinstance(data, np.ndarray):
171 | assert data.dtype == np.int64, "array states must have type `np.int64`"
172 | return int(data.max() + 1)
173 | if isinstance(data, torch.Tensor):
174 | assert data.dtype == torch.long, "tensor states must have type `torch.long`"
175 | return int(data.max().item() + 1)
176 | return max(_get_num_states(entry) for entry in data)
177 |
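178 | # A minimal usage sketch with fixed-length sequences (hypothetical data; see ``SequenceData`` for
179 | # the full set of accepted inputs):
180 | #
181 | #   import torch
182 | #
183 | #   sequences = torch.randint(4, size=(100, 20))  # 100 sequences over 4 states, each of length 20
184 | #   chain = MarkovChain(num_states=4).fit(sequences)
185 | #   nll = chain.score(sequences)                  # average negative log-likelihood
186 | #   samples = chain.sample(num_sequences=5, sequence_length=10)
187 |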
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom ignore
2 | docs/sites/generated/
3 | lightning_logs/
4 |
5 | # Created by https://www.toptal.com/developers/gitignore/api/macos,visualstudiocode,pycharm+all,python
6 | # Edit at https://www.toptal.com/developers/gitignore?templates=macos,visualstudiocode,pycharm+all,python
7 |
8 | ### macOS ###
9 | # General
10 | .DS_Store
11 | .AppleDouble
12 | .LSOverride
13 |
14 | # Icon must end with two \r
15 | Icon
16 |
17 |
18 | # Thumbnails
19 | ._*
20 |
21 | # Files that might appear in the root of a volume
22 | .DocumentRevisions-V100
23 | .fseventsd
24 | .Spotlight-V100
25 | .TemporaryItems
26 | .Trashes
27 | .VolumeIcon.icns
28 | .com.apple.timemachine.donotpresent
29 |
30 | # Directories potentially created on remote AFP share
31 | .AppleDB
32 | .AppleDesktop
33 | Network Trash Folder
34 | Temporary Items
35 | .apdisk
36 |
37 | ### macOS Patch ###
38 | # iCloud generated files
39 | *.icloud
40 |
41 | ### PyCharm+all ###
42 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
43 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
44 |
45 | # User-specific stuff
46 | .idea/**/workspace.xml
47 | .idea/**/tasks.xml
48 | .idea/**/usage.statistics.xml
49 | .idea/**/dictionaries
50 | .idea/**/shelf
51 |
52 | # AWS User-specific
53 | .idea/**/aws.xml
54 |
55 | # Generated files
56 | .idea/**/contentModel.xml
57 |
58 | # Sensitive or high-churn files
59 | .idea/**/dataSources/
60 | .idea/**/dataSources.ids
61 | .idea/**/dataSources.local.xml
62 | .idea/**/sqlDataSources.xml
63 | .idea/**/dynamic.xml
64 | .idea/**/uiDesigner.xml
65 | .idea/**/dbnavigator.xml
66 |
67 | # Gradle
68 | .idea/**/gradle.xml
69 | .idea/**/libraries
70 |
71 | # Gradle and Maven with auto-import
72 | # When using Gradle or Maven with auto-import, you should exclude module files,
73 | # since they will be recreated, and may cause churn. Uncomment if using
74 | # auto-import.
75 | # .idea/artifacts
76 | # .idea/compiler.xml
77 | # .idea/jarRepositories.xml
78 | # .idea/modules.xml
79 | # .idea/*.iml
80 | # .idea/modules
81 | # *.iml
82 | # *.ipr
83 |
84 | # CMake
85 | cmake-build-*/
86 |
87 | # Mongo Explorer plugin
88 | .idea/**/mongoSettings.xml
89 |
90 | # File-based project format
91 | *.iws
92 |
93 | # IntelliJ
94 | out/
95 |
96 | # mpeltonen/sbt-idea plugin
97 | .idea_modules/
98 |
99 | # JIRA plugin
100 | atlassian-ide-plugin.xml
101 |
102 | # Cursive Clojure plugin
103 | .idea/replstate.xml
104 |
105 | # SonarLint plugin
106 | .idea/sonarlint/
107 |
108 | # Crashlytics plugin (for Android Studio and IntelliJ)
109 | com_crashlytics_export_strings.xml
110 | crashlytics.properties
111 | crashlytics-build.properties
112 | fabric.properties
113 |
114 | # Editor-based Rest Client
115 | .idea/httpRequests
116 |
117 | # Android studio 3.1+ serialized cache file
118 | .idea/caches/build_file_checksums.ser
119 |
120 | ### PyCharm+all Patch ###
121 | # Ignore everything but code style settings and run configurations
122 | # that are supposed to be shared within teams.
123 |
124 | .idea/*
125 |
126 | !.idea/codeStyles
127 | !.idea/runConfigurations
128 |
129 | ### Python ###
130 | # Byte-compiled / optimized / DLL files
131 | __pycache__/
132 | *.py[cod]
133 | *$py.class
134 |
135 | # C extensions
136 | *.so
137 |
138 | # Distribution / packaging
139 | .Python
140 | build/
141 | develop-eggs/
142 | dist/
143 | downloads/
144 | eggs/
145 | .eggs/
146 | lib/
147 | lib64/
148 | parts/
149 | sdist/
150 | var/
151 | wheels/
152 | share/python-wheels/
153 | *.egg-info/
154 | .installed.cfg
155 | *.egg
156 | MANIFEST
157 |
158 | # PyInstaller
159 | # Usually these files are written by a python script from a template
160 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
161 | *.manifest
162 | *.spec
163 |
164 | # Installer logs
165 | pip-log.txt
166 | pip-delete-this-directory.txt
167 |
168 | # Unit test / coverage reports
169 | htmlcov/
170 | .tox/
171 | .nox/
172 | .coverage
173 | .coverage.*
174 | .cache
175 | nosetests.xml
176 | coverage.xml
177 | *.cover
178 | *.py,cover
179 | .hypothesis/
180 | .pytest_cache/
181 | cover/
182 |
183 | # Translations
184 | *.mo
185 | *.pot
186 |
187 | # Django stuff:
188 | *.log
189 | local_settings.py
190 | db.sqlite3
191 | db.sqlite3-journal
192 |
193 | # Flask stuff:
194 | instance/
195 | .webassets-cache
196 |
197 | # Scrapy stuff:
198 | .scrapy
199 |
200 | # Sphinx documentation
201 | docs/_build/
202 |
203 | # PyBuilder
204 | .pybuilder/
205 | target/
206 |
207 | # Jupyter Notebook
208 | .ipynb_checkpoints
209 |
210 | # IPython
211 | profile_default/
212 | ipython_config.py
213 |
214 | # pyenv
215 | # For a library or package, you might want to ignore these files since the code is
216 | # intended to run in multiple environments; otherwise, check them in:
217 | .python-version
218 |
219 | # pipenv
220 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
221 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
222 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
223 | # install all needed dependencies.
224 | #Pipfile.lock
225 |
226 | # poetry
227 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
228 | # This is especially recommended for binary packages to ensure reproducibility, and is more
229 | # commonly ignored for libraries.
230 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
231 | #poetry.lock
232 |
233 | # pdm
234 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
235 | #pdm.lock
236 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
237 | # in version control.
238 | # https://pdm.fming.dev/#use-with-ide
239 | .pdm.toml
240 |
241 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
242 | __pypackages__/
243 |
244 | # Celery stuff
245 | celerybeat-schedule
246 | celerybeat.pid
247 |
248 | # SageMath parsed files
249 | *.sage.py
250 |
251 | # Environments
252 | .env
253 | .venv
254 | env/
255 | venv/
256 | ENV/
257 | env.bak/
258 | venv.bak/
259 |
260 | # Spyder project settings
261 | .spyderproject
262 | .spyproject
263 |
264 | # Rope project settings
265 | .ropeproject
266 |
267 | # mkdocs documentation
268 | /site
269 |
270 | # mypy
271 | .mypy_cache/
272 | .dmypy.json
273 | dmypy.json
274 |
275 | # Pyre type checker
276 | .pyre/
277 |
278 | # pytype static type analyzer
279 | .pytype/
280 |
281 | # Cython debug symbols
282 | cython_debug/
283 |
284 | # PyCharm
285 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
286 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
287 | # and can be added to the global gitignore or merged into this file. For a more nuclear
288 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
289 | #.idea/
290 |
291 | ### Python Patch ###
292 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
293 | poetry.toml
294 |
295 |
296 | ### VisualStudioCode ###
297 | .vscode/*
298 | !.vscode/settings.json
299 | !.vscode/tasks.json
300 | !.vscode/launch.json
301 | !.vscode/extensions.json
302 | !.vscode/*.code-snippets
303 |
304 | # Local History for Visual Studio Code
305 | .history/
306 |
307 | # Built Visual Studio Code Extensions
308 | *.vsix
309 |
310 | ### VisualStudioCode Patch ###
311 | # Ignore all local history of files
312 | .history
313 | .ionide
314 |
315 | # End of https://www.toptal.com/developers/gitignore/api/macos,visualstudiocode,pycharm+all,python
316 |
--------------------------------------------------------------------------------
/pycave/clustering/kmeans/metrics.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Any, Callable, Optional
3 | import torch
4 | from torchmetrics import Metric
5 |
6 |
7 | class CentroidAggregator(Metric):
8 | """
9 | The centroid aggregator aggregates kmeans centroids over batches and processes.
10 | """
11 |
12 | full_state_update = False
13 |
14 | def __init__(
15 | self,
16 | num_clusters: int,
17 | num_features: int,
18 | *,
19 | dist_sync_fn: Optional[Callable[[Any], Any]] = None,
20 | ):
21 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
22 |
23 | self.num_clusters = num_clusters
24 | self.num_features = num_features
25 |
26 | self.centroids: torch.Tensor
27 | self.add_state("centroids", torch.zeros(num_clusters, num_features), dist_reduce_fx="sum")
28 |
29 | self.cluster_counts: torch.Tensor
30 | self.add_state("cluster_counts", torch.zeros(num_clusters), dist_reduce_fx="sum")
31 |
32 | def update(self, data: torch.Tensor, assignments: torch.Tensor) -> None:
33 | indices = assignments.unsqueeze(1).expand(-1, self.num_features)
34 | self.centroids.scatter_add_(0, indices, data)
35 |
36 | counts = assignments.bincount(minlength=self.num_clusters).float()
37 | self.cluster_counts.add_(counts)
38 |
39 | def compute(self) -> torch.Tensor:
40 | return self.centroids / self.cluster_counts.unsqueeze(-1)
41 |
42 |
43 | class UniformSampler(Metric):
44 | """
45 | The uniform sampler randomly samples a specified number of datapoints uniformly from all
46 | datapoints.
47 |
48 |     The idea is the following: sample the required number of choices from each batch and track the
49 |     number of datapoints that have already been sampled from. When sampling from the union of the
50 |     existing choices and a new batch, more weight is put on the existing choices (according to the
51 |     number of datapoints they were already sampled from).
52 | """
53 |
54 | full_state_update = False
55 |
56 | def __init__(
57 | self,
58 | num_choices: int,
59 | num_features: int,
60 | *,
61 | dist_sync_fn: Optional[Callable[[Any], Any]] = None,
62 | ):
63 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
64 |
65 | self.num_choices = num_choices
66 |
67 | self.choices: torch.Tensor
68 | self.add_state("choices", torch.empty(num_choices, num_features), dist_reduce_fx="cat")
69 |
70 | self.choice_weights: torch.Tensor
71 | self.add_state("choice_weights", torch.zeros(num_choices), dist_reduce_fx="cat")
72 |
73 | def update(self, data: torch.Tensor) -> None:
74 | if self.num_choices == 1:
75 | # If there is only one choice, the fastest thing is to use the `random` package. The
76 | # cumulative weight of the data is its size, the cumulative weight of the current
77 | # choice is some value.
78 | cum_weight = data.size(0) + self.choice_weights.item()
79 | if random.random() * cum_weight < data.size(0):
80 | # Use some item from the data, else keep the current choice
81 | self.choices.copy_(data[random.randrange(data.size(0))])
82 | else:
83 | # The choices are computed from scratch every time, weighting the current choices by
84 | # the cumulative weight put on them
85 | weights = torch.cat(
86 | [
87 | torch.ones(data.size(0), device=data.device, dtype=data.dtype),
88 | self.choice_weights,
89 | ]
90 | )
91 | pool = torch.cat([data, self.choices])
92 | samples = weights.multinomial(self.num_choices)
93 | self.choices.copy_(pool[samples])
94 |
95 | # The weights are the cumulative counts, divided by the number of choices
96 | self.choice_weights.add_(data.size(0) / self.num_choices)
97 |
98 | def compute(self) -> torch.Tensor:
99 | # In the ddp setting, there are "too many" choices, so we sample
100 | if self.choices.size(0) > self.num_choices:
101 | samples = self.choice_weights.multinomial(self.num_choices)
102 | return self.choices[samples]
103 | return self.choices
104 |
105 |
106 | class DistanceSampler(Metric):
107 | """
108 | The distance sampler may be used for kmeans++ initialization, to iteratively select centroids
109 | according to their squared distances to existing choices.
110 |
111 | Computing the distance to existing choices is not part of this sampler. Within each "cycle", it
112 | computes a given number of candidates. Candidates are sampled independently and may be
113 | duplicates.
114 | """
115 |
116 | full_state_update = False
117 |
118 | def __init__(
119 | self,
120 | num_choices: int,
121 | num_features: int,
122 | *,
123 | dist_sync_fn: Optional[Callable[[Any], Any]] = None,
124 | ):
125 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
126 |
127 | self.num_choices = num_choices
128 | self.num_features = num_features
129 |
130 | self.choices: torch.Tensor
131 | self.add_state("choices", torch.empty(num_choices, num_features), dist_reduce_fx="cat")
132 |
133 | # Cumulative distance is the same for all choices
134 | self.cumulative_squared_distance: torch.Tensor
135 | self.add_state("cumulative_squared_distance", torch.zeros(1), dist_reduce_fx="cat")
136 |
137 | def update(self, data: torch.Tensor, shortest_distances: torch.Tensor) -> None:
138 | eps = torch.finfo(data.dtype).eps
139 | squared_distances = shortest_distances.square()
140 |
141 | # For all choices, check if we should use a sample from the data or the existing choice
142 | data_dist = squared_distances.sum()
143 | cum_dist = data_dist + eps + self.cumulative_squared_distance
144 | use_choice_from_data = (
145 | torch.rand(self.num_choices, device=data.device, dtype=data.dtype) * cum_dist
146 | < data_dist + eps
147 | )
148 |
149 | # Then, we sample from the data `num_choices` times and replace if needed
150 | choices = (squared_distances + eps).multinomial(self.num_choices, replacement=True)
151 | self.choices.masked_scatter_(
152 | use_choice_from_data.unsqueeze(1), data[choices[use_choice_from_data]]
153 | )
154 |
155 | # In any case, the cumulative distances are updated
156 | self.cumulative_squared_distance.add_(data_dist)
157 |
158 | def compute(self) -> torch.Tensor:
159 | # Upon computation, we sample if there is more than one choice (ddp setting)
160 | if self.choices.size(0) > self.num_choices:
161 | # choices now have shape [num_choices, num_processes, num_features]
162 | choices = self.choices.reshape(-1, self.num_choices, self.num_features).transpose(0, 1)
163 | # For each choice, we sample across processes
164 | choice_indices = torch.arange(self.num_choices, device=self.choices.device)
165 | process_indices = self.cumulative_squared_distance.multinomial(
166 | self.num_choices, replacement=True
167 | )
168 | return choices[choice_indices, process_indices]
169 | # Otherwise, we can return the choices
170 | return self.choices
171 |
172 |
173 | class BatchSummer(Metric):
174 | """
175 | Sums the values for a batch of items independently.
176 | """
177 |
178 | full_state_update = True
179 |
180 | def __init__(self, num_values: int, *, dist_sync_fn: Optional[Callable[[Any], Any]] = None):
181 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
182 |
183 | self.sums: torch.Tensor
184 | self.add_state("sums", torch.zeros(num_values), dist_reduce_fx="sum")
185 |
186 | def update(self, values: torch.Tensor) -> None:
187 | self.sums.add_(values.sum(0))
188 |
189 | def compute(self) -> torch.Tensor:
190 | return self.sums
191 |
192 |
193 | class BatchAverager(Metric):
194 | """
195 | Averages the values for a batch of items independently.
196 | """
197 |
198 | full_state_update = False
199 |
200 | def __init__(
201 | self,
202 | num_values: int,
203 | for_variance: bool,
204 | *,
205 | dist_sync_fn: Optional[Callable[[Any], Any]] = None,
206 | ):
207 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
208 |
209 | self.for_variance = for_variance
210 |
211 | self.sums: torch.Tensor
212 | self.add_state("sums", torch.zeros(num_values), dist_reduce_fx="sum")
213 |
214 | self.counts: torch.Tensor
215 | self.add_state("counts", torch.zeros(num_values), dist_reduce_fx="sum")
216 |
217 | def update(self, values: torch.Tensor) -> None:
218 | self.sums.add_(values.sum(0))
219 | self.counts.add_(values.size(0))
220 |
221 | def compute(self) -> torch.Tensor:
222 | return self.sums / (self.counts - 1 if self.for_variance else self.counts)
223 |
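224 | # A minimal single-process usage sketch for the centroid aggregator (hypothetical shapes):
225 | #
226 | #   data = torch.randn(1000, 8)
227 | #   assignments = torch.randint(4, size=(1000,))
228 | #   aggregator = CentroidAggregator(num_clusters=4, num_features=8)
229 | #   aggregator.update(data, assignments)  # typically called once per batch
230 | #   centroids = aggregator.compute()      # tensor of shape [4, 8]
231 |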
--------------------------------------------------------------------------------
/pycave/clustering/kmeans/estimator.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import logging
3 | from typing import Any, cast, List
4 | import torch
5 | from lightkit import ConfigurableBaseEstimator
6 | from lightkit.data import collate_tensor, DataLoader, dataset_from_tensors, TensorLike
7 | from lightkit.estimator import PredictorMixin, TransformerMixin
8 | from .lightning_module import (
9 | FeatureVarianceLightningModule,
10 | KMeansLightningModule,
11 | KmeansPlusPlusInitLightningModule,
12 | KmeansRandomInitLightningModule,
13 | )
14 | from .model import KMeansModel, KMeansModelConfig
15 | from .types import KMeansInitStrategy
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | class KMeans(
21 | ConfigurableBaseEstimator[KMeansModel], # type: ignore
22 | TransformerMixin[TensorLike, torch.Tensor],
23 | PredictorMixin[TensorLike, torch.Tensor],
24 | ):
25 | """
26 | Model for clustering data into a predefined number of clusters. More information on K-means
27 |     clustering is available on `Wikipedia <https://en.wikipedia.org/wiki/K-means_clustering>`_.
28 |
29 | See also:
30 | .. currentmodule:: pycave.clustering.kmeans
31 | .. autosummary::
32 | :nosignatures:
33 | :template: classes/pytorch_module.rst
34 |
35 | KMeansModel
36 | KMeansModelConfig
37 | """
38 |
39 | #: The fitted PyTorch module with all estimated parameters.
40 | model_: KMeansModel
41 | #: A boolean indicating whether the model converged during training.
42 | converged_: bool
43 | #: The number of iterations the model was fitted for, excluding initialization.
44 | num_iter_: int
45 | #: The mean squared distance of all datapoints to their closest cluster centers.
46 | inertia_: float
47 |
48 | def __init__(
49 | self,
50 | num_clusters: int = 1,
51 | *,
52 | init_strategy: KMeansInitStrategy = "kmeans++",
53 | convergence_tolerance: float = 1e-4,
54 | batch_size: int | None = None,
55 | trainer_params: dict[str, Any] | None = None,
56 | ):
57 | """
58 | Args:
59 | num_clusters: The number of clusters.
60 | init_strategy: The strategy for initializing centroids.
61 | convergence_tolerance: Training is conducted until the Frobenius norm of the change
62 | between cluster centroids falls below this threshold. The tolerance is multiplied
63 | by the average variance of the features.
64 | batch_size: The batch size to use when fitting the model. If not provided, the full
65 | data will be used as a single batch. Set this if the full data does not fit into
66 | memory.
67 | trainer_params: Initialization parameters to use when initializing a PyTorch Lightning
68 | trainer. By default, it disables various stdout logs unless PyCave is configured to
69 | do verbose logging. Checkpointing and logging are disabled regardless of the log
70 | level. This estimator further sets the following overridable defaults:
71 |
72 | - ``max_epochs=300``
73 |
74 | Note:
75 |             The number of epochs passed to the initializer only defines the number of optimization
76 | epochs. Prior to that, initialization is run which may perform additional iterations
77 | through the data.
78 | """
79 | super().__init__(
80 | default_params=dict(max_epochs=300),
81 | user_params=trainer_params,
82 | )
83 |
84 | # Assign other properties
85 | self.batch_size = batch_size
86 | self.num_clusters = num_clusters
87 | self.init_strategy = init_strategy
88 | self.convergence_tolerance = convergence_tolerance
89 |
90 | def fit(self, data: TensorLike) -> KMeans:
91 | """
92 | Fits the KMeans model on the provided data by running Lloyd's algorithm.
93 |
94 | Args:
95 | data: The tabular data to fit on. The dimensionality of the KMeans model is
96 | automatically inferred from this data.
97 |
98 | Returns:
99 | The fitted KMeans model.
100 | """
101 | # Initialize model
102 | num_features = len(data[0])
103 | config = KMeansModelConfig(
104 | num_clusters=self.num_clusters,
105 | num_features=num_features,
106 | )
107 | self.model_ = KMeansModel(config)
108 |
109 | # Setup the data loading
110 | loader = DataLoader(
111 | dataset_from_tensors(data),
112 | batch_size=self.batch_size or len(data),
113 | collate_fn=collate_tensor,
114 | )
115 | is_batch_training = self._num_batches_per_epoch(loader) > 1
116 |
117 | # First, initialize the centroids
118 | if self.init_strategy == "random":
119 | module = KmeansRandomInitLightningModule(self.model_)
120 | num_epochs = 1
121 | else:
122 | module = KmeansPlusPlusInitLightningModule(
123 | self.model_,
124 | is_batch_training=is_batch_training,
125 | )
126 | num_epochs = 2 * config.num_clusters - 1
127 |
128 | logger.info("Running initialization...")
129 | self.trainer(max_epochs=num_epochs).fit(module, loader)
130 |
131 | # Then, in order to find the right convergence tolerance, we need to compute the variance
132 | # of the data.
133 | if self.convergence_tolerance != 0:
134 | variances = torch.empty(config.num_features)
135 | module = FeatureVarianceLightningModule(variances)
136 | self.trainer().fit(module, loader)
137 |
138 | tolerance_multiplier = cast(float, variances.mean().item())
139 | convergence_tolerance = self.convergence_tolerance * tolerance_multiplier
140 | else:
141 | convergence_tolerance = 0
142 |
143 | # Then, we can fit the actual model. We need a new trainer for that
144 | logger.info("Fitting K-Means...")
145 | trainer = self.trainer()
146 | module = KMeansLightningModule(
147 | self.model_,
148 | convergence_tolerance=convergence_tolerance,
149 | )
150 | trainer.fit(module, loader)
151 |
152 | # Assign convergence properties
153 | self.num_iter_ = module.current_epoch
154 | self.converged_ = module.current_epoch < trainer.max_epochs
155 | if "inertia" in trainer.callback_metrics:
156 | self.inertia_ = cast(float, trainer.callback_metrics["inertia"].item())
157 | return self
158 |
159 | def predict(self, data: TensorLike) -> torch.Tensor:
160 | """
161 | Predicts the closest cluster for each item provided.
162 |
163 | Args:
164 | data: The datapoints for which to predict the clusters.
165 |
166 | Returns:
167 | Tensor of shape ``[num_datapoints]`` with the index of the closest cluster for each
168 | datapoint.
169 |
170 | Attention:
171 | When calling this function in a multi-process environment, each process receives only
172 | a subset of the predictions. If you want to aggregate predictions, make sure to gather
173 | the values returned from this method.
174 | """
175 | loader = DataLoader(
176 | dataset_from_tensors(data),
177 | batch_size=self.batch_size or len(data),
178 | collate_fn=collate_tensor,
179 | )
180 | result = self.trainer().predict(
181 | KMeansLightningModule(self.model_, predict_target="assignments"), loader
182 | )
183 | return torch.cat(cast(List[torch.Tensor], result))
184 |
185 | def score(self, data: TensorLike) -> float:
186 | """
187 | Computes the average inertia of all the provided datapoints. That is, it computes the mean
188 | squared distance to each datapoint's closest centroid.
189 |
190 | Args:
191 | data: The data for which to compute the average inertia.
192 |
193 | Returns:
194 | The average inertia.
195 |
196 | Note:
197 |             See :meth:`score_samples` to obtain the inertia for individual datapoints.
198 | """
199 | loader = DataLoader(
200 | dataset_from_tensors(data),
201 | batch_size=self.batch_size or len(data),
202 | collate_fn=collate_tensor,
203 | )
204 | result = self.trainer().test(KMeansLightningModule(self.model_), loader, verbose=False)
205 | return result[0]["inertia"]
206 |
207 | def score_samples(self, data: TensorLike) -> torch.Tensor:
208 | """
209 |         Computes the inertia for each of the provided datapoints. That is, it computes the squared
210 |         distance of each datapoint to its closest centroid.
211 |
212 | Args:
213 | data: The data for which to compute the inertia values.
214 |
215 | Returns:
216 | A tensor of shape ``[num_datapoints]`` with the inertia of each datapoint.
217 |
218 | Attention:
219 | When calling this function in a multi-process environment, each process receives only
220 | a subset of the predictions. If you want to aggregate predictions, make sure to gather
221 | the values returned from this method.
222 | """
223 | loader = DataLoader(
224 | dataset_from_tensors(data),
225 | batch_size=self.batch_size or len(data),
226 | collate_fn=collate_tensor,
227 | )
228 | result = self.trainer().predict(
229 | KMeansLightningModule(self.model_, predict_target="inertias"), loader
230 | )
231 | return torch.cat(cast(List[torch.Tensor], result))
232 |
233 | def transform(self, data: TensorLike) -> torch.Tensor:
234 | """
235 | Transforms the provided data into the cluster-distance space. That is, it returns the
236 | distance of each datapoint to each cluster centroid.
237 |
238 | Args:
239 | data: The data to transform.
240 |
241 | Returns:
242 | A tensor of shape ``[num_datapoints, num_clusters]`` with the distances to the cluster
243 | centroids.
244 |
245 | Attention:
246 | When calling this function in a multi-process environment, each process receives only
247 | a subset of the predictions. If you want to aggregate predictions, make sure to gather
248 | the values returned from this method.
249 | """
250 | loader = DataLoader(
251 | dataset_from_tensors(data),
252 | batch_size=self.batch_size or len(data),
253 | collate_fn=collate_tensor,
254 | )
255 | result = self.trainer().predict(
256 | KMeansLightningModule(self.model_, predict_target="distances"), loader
257 | )
258 | return torch.cat(cast(List[torch.Tensor], result))
259 |
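260 | # A minimal end-to-end usage sketch (hypothetical data):
261 | #
262 | #   import torch
263 | #
264 | #   data = torch.randn(10_000, 8)
265 | #   estimator = KMeans(num_clusters=4).fit(data)
266 | #   assignments = estimator.predict(data)  # tensor of shape [10_000]
267 | #   distances = estimator.transform(data)  # tensor of shape [10_000, 4]
268 | #   inertia = estimator.score(data)        # float
269 |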
--------------------------------------------------------------------------------
/tests/bayes/core/test_normal.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | import pytest
3 | import torch
4 | from sklearn.mixture._gaussian_mixture import _compute_log_det_cholesky # type: ignore
5 | from sklearn.mixture._gaussian_mixture import _compute_precision_cholesky # type: ignore
6 | from torch.distributions import MultivariateNormal
7 | from pycave.bayes.core import cholesky_precision, covariance, log_normal, sample_normal
8 | from pycave.bayes.core._jit import _cholesky_logdet # type: ignore
9 | from tests._data.normal import (
10 | sample_data,
11 | sample_diag_covars,
12 | sample_full_covars,
13 | sample_means,
14 | sample_spherical_covars,
15 | )
16 |
17 | # -------------------------------------------------------------------------------------------------
18 | # CHOLESKY PRECISIONS
19 | # -------------------------------------------------------------------------------------------------
20 |
21 |
22 | @pytest.mark.parametrize("covars", sample_spherical_covars([70, 5, 200]))
23 | def test_cholesky_precision_spherical(covars: torch.Tensor):
24 | expected = _compute_precision_cholesky(covars.numpy(), "spherical") # type: ignore
25 | actual = cholesky_precision(covars, "spherical")
26 | assert torch.allclose(
27 | torch.as_tensor(expected, dtype=torch.float), actual, rtol=1e-4, atol=1e-4
28 | )
29 |
30 |
31 | @pytest.mark.parametrize("covars", sample_diag_covars([70, 5, 200], [3, 50, 100]))
32 | def test_cholesky_precision_diag(covars: torch.Tensor):
33 | expected = _compute_precision_cholesky(covars.numpy(), "diag") # type: ignore
34 | actual = cholesky_precision(covars, "diag")
35 | assert torch.allclose(
36 | torch.as_tensor(expected, dtype=torch.float), actual, rtol=1e-4, atol=1e-4
37 | )
38 |
39 |
40 | @pytest.mark.parametrize("covars", sample_full_covars([70, 5, 200], [3, 50, 100]))
41 | def test_cholesky_precision_full(covars: torch.Tensor):
42 | expected = _compute_precision_cholesky(covars.numpy(), "full") # type: ignore
43 | actual = cholesky_precision(covars, "full")
44 | assert torch.allclose(
45 | torch.as_tensor(expected, dtype=torch.float), actual, rtol=1e-4, atol=1e-4
46 | )
47 |
48 |
49 | @pytest.mark.parametrize("covars", sample_full_covars([1, 1, 1], [3, 50, 100]))
50 | def test_cholesky_precision_tied(covars: torch.Tensor):
51 | expected = _compute_precision_cholesky(covars.numpy(), "tied") # type: ignore
52 | actual = cholesky_precision(covars, "tied")
53 | assert torch.allclose(
54 | torch.as_tensor(expected, dtype=torch.float), actual, rtol=1e-4, atol=1e-4
55 | )
56 |
57 |
58 | # -------------------------------------------------------------------------------------------------
59 | # COVARIANCES
60 | # -------------------------------------------------------------------------------------------------
61 |
62 |
63 | @pytest.mark.parametrize("covars", sample_spherical_covars([70, 5, 200]))
64 | def test_covariances_spherical(covars: torch.Tensor):
65 | precision_cholesky = _compute_precision_cholesky(covars.numpy(), "spherical") # type: ignore
66 | actual = covariance(torch.as_tensor(precision_cholesky, dtype=torch.float), "spherical")
67 | assert torch.allclose(covars, actual)
68 |
69 |
70 | @pytest.mark.parametrize("covars", sample_diag_covars([70, 5, 200], [3, 50, 100]))
71 | def test_covariances_diag(covars: torch.Tensor):
72 | precision_cholesky = _compute_precision_cholesky(covars.numpy(), "diag") # type: ignore
73 | actual = covariance(torch.as_tensor(precision_cholesky, dtype=torch.float), "diag")
74 | assert torch.allclose(covars, actual)
75 |
76 |
77 | @pytest.mark.parametrize("covars", sample_full_covars([70, 5, 200], [3, 50, 100]))
78 | def test_covariances_full(covars: torch.Tensor):
79 | precision_cholesky = _compute_precision_cholesky(covars.numpy(), "full") # type: ignore
80 | actual = covariance(torch.as_tensor(precision_cholesky, dtype=torch.double), "full")
81 | assert torch.allclose(covars, covars.transpose(1, 2))
82 | assert torch.allclose(covars.to(torch.double), actual)
83 |
84 |
85 | @pytest.mark.parametrize("covars", sample_full_covars([1, 1, 1], [3, 50, 100]))
86 | def test_covariances_tied(covars: torch.Tensor):
87 | precision_cholesky = _compute_precision_cholesky(covars.numpy(), "tied") # type: ignore
88 | actual = covariance(torch.as_tensor(precision_cholesky, dtype=torch.double), "tied")
89 | assert torch.allclose(covars, covars.T)
90 | assert torch.allclose(covars.to(torch.double), actual)
91 |
92 |
93 | # -------------------------------------------------------------------------------------------------
94 | # CHOLESKY LOG DETERMINANTS
95 | # -------------------------------------------------------------------------------------------------
96 |
97 |
98 | @pytest.mark.parametrize("covars", sample_spherical_covars([70, 5, 200]))
99 | def test_cholesky_logdet_spherical(covars: torch.Tensor):
100 | expected = _compute_log_det_cholesky( # type: ignore
101 | _compute_precision_cholesky(covars.numpy(), "spherical"), "spherical", 100 # type: ignore
102 | )
103 | actual = _cholesky_logdet( # type: ignore
104 | 100,
105 | cholesky_precision(covars, "spherical"),
106 | "spherical",
107 | )
108 | assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual)
109 |
110 |
111 | @pytest.mark.parametrize("covars", sample_diag_covars([70, 5, 200], [3, 50, 100]))
112 | def test_cholesky_logdet_diag(covars: torch.Tensor):
113 | expected = _compute_log_det_cholesky( # type: ignore
114 | _compute_precision_cholesky(covars.numpy(), "diag"), # type: ignore
115 | "diag",
116 | covars.size(1),
117 | )
118 | actual = _cholesky_logdet( # type: ignore
119 | covars.size(1),
120 | cholesky_precision(covars, "diag"),
121 | "diag",
122 | )
123 | assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual)
124 |
125 |
126 | @pytest.mark.parametrize("covars", sample_full_covars([70, 5, 200], [3, 50, 100]))
127 | def test_cholesky_logdet_full(covars: torch.Tensor):
128 | expected = _compute_log_det_cholesky( # type: ignore
129 | _compute_precision_cholesky(covars.numpy(), "full"), # type: ignore
130 | "full",
131 | covars.size(1),
132 | )
133 | actual = _cholesky_logdet( # type: ignore
134 | covars.size(1),
135 | cholesky_precision(covars, "full"),
136 | "full",
137 | )
138 | assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual)
139 |
140 |
141 | @pytest.mark.parametrize("covars", sample_full_covars([1, 1, 1], [3, 50, 100]))
142 | def test_cholesky_logdet_tied(covars: torch.Tensor):
143 | expected = _compute_log_det_cholesky( # type: ignore
144 | _compute_precision_cholesky(covars.numpy(), "tied"), # type: ignore
145 | "tied",
146 | covars.size(0),
147 | )
148 | actual = _cholesky_logdet( # type: ignore
149 | covars.size(0),
150 | cholesky_precision(covars, "tied"),
151 | "tied",
152 | )
153 | assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual)
154 |
155 |
156 | # -------------------------------------------------------------------------------------------------
157 | # LOG NORMAL
158 | # -------------------------------------------------------------------------------------------------
159 |
160 |
161 | @pytest.mark.parametrize(
162 | "x, means, covars",
163 | zip(
164 | sample_data([10, 50, 100], [3, 50, 100]),
165 | sample_means([70, 5, 200], [3, 50, 100]),
166 | sample_spherical_covars([70, 5, 200]),
167 | ),
168 | )
169 | def test_log_normal_spherical(x: torch.Tensor, means: torch.Tensor, covars: torch.Tensor):
170 | covar_matrices = torch.stack([torch.eye(means.size(-1)) * c for c in covars])
171 | precisions_cholesky = cholesky_precision(covars, "spherical")
172 | actual = log_normal(x, means, precisions_cholesky, covariance_type="spherical")
173 | _assert_log_prob(actual, x, means, covar_matrices)
174 |
175 |
176 | @pytest.mark.parametrize(
177 | "x, means, covars",
178 | zip(
179 | sample_data([10, 50, 100], [3, 50, 100]),
180 | sample_means([70, 5, 200], [3, 50, 100]),
181 | sample_diag_covars([70, 5, 200], [3, 50, 100]),
182 | ),
183 | )
184 | def test_log_normal_diag(x: torch.Tensor, means: torch.Tensor, covars: torch.Tensor):
185 | covar_matrices = torch.stack([torch.diag(c) for c in covars])
186 | precisions_cholesky = cholesky_precision(covars, "diag")
187 | actual = log_normal(x, means, precisions_cholesky, covariance_type="diag")
188 | _assert_log_prob(actual, x, means, covar_matrices)
189 |
190 |
191 | @pytest.mark.parametrize(
192 | "x, means, covars",
193 | zip(
194 | sample_data([10, 50, 100], [3, 50, 100]),
195 | sample_means([70, 5, 200], [3, 50, 100]),
196 | sample_full_covars([70, 5, 200], [3, 50, 100]),
197 | ),
198 | )
199 | def test_log_normal_full(x: torch.Tensor, means: torch.Tensor, covars: torch.Tensor):
200 | precisions_cholesky = cholesky_precision(covars, "full")
201 | actual = log_normal(x, means, precisions_cholesky, covariance_type="full")
202 | _assert_log_prob(actual.float(), x, means, covars)
203 |
204 |
205 | @pytest.mark.parametrize(
206 | "x, means, covars",
207 | zip(
208 | sample_data([10, 50, 100], [3, 50, 100]),
209 | sample_means([70, 5, 200], [3, 50, 100]),
210 | sample_full_covars([1, 1, 1], [3, 50, 100]),
211 | ),
212 | )
213 | def test_log_normal_tied(x: torch.Tensor, means: torch.Tensor, covars: torch.Tensor):
214 | precisions_cholesky = cholesky_precision(covars, "tied")
215 | actual = log_normal(x, means, precisions_cholesky, covariance_type="tied")
216 | _assert_log_prob(actual, x, means, covars)
217 |
218 |
219 | # -------------------------------------------------------------------------------------------------
220 | # SAMPLING
221 | # -------------------------------------------------------------------------------------------------
222 |
223 |
224 | @pytest.mark.flaky(max_runs=3, min_passes=1)
225 | def test_sample_normal_spherical():
226 | mean = torch.tensor([1.5, 3.5])
227 | covar = torch.tensor(4.0)
228 | target_covar = torch.tensor([[4.0, 0.0], [0.0, 4.0]])
229 |
230 | n = 1_000_000
231 | precisions = cholesky_precision(covar, "spherical")
232 | samples = sample_normal(n, mean, precisions, "spherical")
233 |
234 | sample_mean = samples.mean(0)
235 | sample_covar = (samples - sample_mean).t().matmul(samples - sample_mean) / n
236 |
237 | assert torch.allclose(mean, sample_mean, atol=1e-2)
238 | assert torch.allclose(target_covar, sample_covar, atol=1e-2)
239 |
240 |
241 | @pytest.mark.flaky(max_runs=3, min_passes=1)
242 | def test_sample_normal_diag():
243 | mean = torch.tensor([1.5, 3.5])
244 | covar = torch.tensor([0.5, 4.5])
245 | target_covar = torch.tensor([[0.5, 0.0], [0.0, 4.5]])
246 |
247 | n = 1_000_000
248 | precisions = cholesky_precision(covar, "diag")
249 | samples = sample_normal(n, mean, precisions, "diag")
250 |
251 | sample_mean = samples.mean(0)
252 | sample_covar = (samples - sample_mean).t().matmul(samples - sample_mean) / n
253 |
254 | assert torch.allclose(mean, sample_mean, atol=1e-2)
255 | assert torch.allclose(target_covar, sample_covar, atol=1e-2)
256 |
257 |
258 | @pytest.mark.flaky(max_runs=3, min_passes=1)
259 | def test_sample_normal_full():
260 | mean = torch.tensor([1.5, 3.5])
261 | covar = torch.tensor([[4.0, 2.5], [2.5, 2.0]])
262 |
263 | n = 1_000_000
264 | precisions = cholesky_precision(covar, "tied")
265 | samples = sample_normal(n, mean, precisions, "full")
266 |
267 | sample_mean = samples.mean(0)
268 | sample_covar = (samples - sample_mean).t().matmul(samples - sample_mean) / n
269 |
270 | assert torch.allclose(mean, sample_mean, atol=1e-2)
271 | assert torch.allclose(covar, sample_covar, atol=1e-2)
272 |
273 |
274 | # -------------------------------------------------------------------------------------------------
275 |
276 |
277 | def _assert_log_prob(
278 | actual: torch.Tensor, x: torch.Tensor, means: torch.Tensor, covars: torch.Tensor
279 | ) -> None:
280 | distribution = MultivariateNormal(means, covars)
281 | expected = distribution.log_prob(x.unsqueeze(1))
282 | assert torch.allclose(actual, expected, rtol=1e-3)
283 |
--------------------------------------------------------------------------------
/pycave/clustering/kmeans/lightning_module.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=abstract-method
2 | import math
3 | from typing import List, Literal
4 | import pytorch_lightning as pl
5 | import torch
6 | from pytorch_lightning.callbacks import EarlyStopping
7 | from torchmetrics import MeanMetric
8 | from pycave.utils import NonparametricLightningModule
9 | from .metrics import (
10 | BatchAverager,
11 | BatchSummer,
12 | CentroidAggregator,
13 | DistanceSampler,
14 | UniformSampler,
15 | )
16 | from .model import KMeansModel
17 |
18 | # -------------------------------------------------------------------------------------------------
19 | # TRAINING
20 |
21 |
22 | class KMeansLightningModule(NonparametricLightningModule):
23 | """
24 | Lightning module for training and evaluating a K-Means model.
25 | """
26 |
27 | def __init__(
28 | self,
29 | model: KMeansModel,
30 | convergence_tolerance: float = 1e-4,
31 | predict_target: Literal["assignments", "distances", "inertias"] = "assignments",
32 | ):
33 | """
34 | Args:
35 | model: The model to train.
36 | convergence_tolerance: Training is conducted until the Frobenius norm of the change
37 | between cluster centroids falls below this threshold.
38 |             predict_target: Whether to predict cluster assignments, distances to all cluster centroids, or per-datapoint inertias.
39 | """
40 | super().__init__()
41 |
42 | self.model = model
43 | self.convergence_tolerance = convergence_tolerance
44 | self.predict_target = predict_target
45 |
46 | # Initialize aggregators
47 | self.centroid_aggregator = CentroidAggregator(
48 | num_clusters=self.model.config.num_clusters,
49 | num_features=self.model.config.num_features,
50 | dist_sync_fn=self.all_gather,
51 | )
52 |
53 | # Initialize metrics
54 | self.metric_inertia = MeanMetric()
55 |
56 | def configure_callbacks(self) -> List[pl.Callback]:
57 | if self.convergence_tolerance == 0:
58 | return []
59 | early_stopping = EarlyStopping(
60 | "frobenius_norm_change",
61 | patience=100000,
62 | stopping_threshold=self.convergence_tolerance,
63 | check_on_train_epoch_end=True,
64 | )
65 | return [early_stopping]
66 |
67 | def on_train_epoch_start(self) -> None:
68 | self.centroid_aggregator.reset()
69 |
70 | def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
71 | # First, we compute the cluster assignments
72 | _, assignments, inertias = self.model.forward(batch)
73 |
74 | # Then, we update the centroids
75 | self.centroid_aggregator.update(batch, assignments)
76 |
77 | # And log the inertia
78 | self.metric_inertia.update(inertias)
79 | self.log("inertia", self.metric_inertia, on_step=False, on_epoch=True, prog_bar=True)
80 |
81 | def nonparametric_training_epoch_end(self) -> None:
82 | centroids = self.centroid_aggregator.compute()
83 | self.log("frobenius_norm_change", torch.linalg.norm(self.model.centroids - centroids))
84 | self.model.centroids.copy_(centroids)
85 |
86 | def test_step(self, batch: torch.Tensor, _batch_idx: int) -> None:
87 | _, _, inertias = self.model.forward(batch)
88 | self.metric_inertia.update(inertias)
89 | self.log("inertia", self.metric_inertia)
90 |
91 | def predict_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
92 | distances, assignments, inertias = self.model.forward(batch)
93 | if self.predict_target == "assignments":
94 | return assignments
95 | if self.predict_target == "inertias":
96 | return inertias
97 | return distances
98 |
99 |
100 | # -------------------------------------------------------------------------------------------------
101 | # INIT STRATEGIES
102 |
103 |
104 | class KmeansRandomInitLightningModule(NonparametricLightningModule):
105 | """
106 | Lightning module for initializing K-Means centroids randomly.
107 |
108 |     Within the first epoch, centroids are sampled uniformly from all items. Thus, this module
109 |     should only be trained for a single epoch.
110 | """
111 |
112 | def __init__(self, model: KMeansModel):
113 | """
114 | Args:
115 | model: The model to initialize.
116 | """
117 | super().__init__()
118 |
119 | self.model = model
120 |
121 | self.sampler = UniformSampler(
122 | num_choices=self.model.config.num_clusters,
123 | num_features=self.model.config.num_features,
124 | dist_sync_fn=self.all_gather_first,
125 | )
126 |
127 | def on_train_epoch_start(self) -> None:
128 | self.sampler.reset()
129 |
130 | def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
131 | self.sampler.update(batch)
132 |
133 | def nonparametric_training_epoch_end(self) -> None:
134 | choices = self.sampler.compute()
135 | self.model.centroids.copy_(choices)
136 |
137 |
138 | class KmeansPlusPlusInitLightningModule(NonparametricLightningModule):
139 | """
140 | Lightning module for K-Means++ initialization. It performs the following operations:
141 |
142 | - In the first epoch, a centroid is chosen at random.
143 | - In even epochs, candidates for the next centroid are sampled, based on the squared distance
144 | to their nearest cluster center.
145 | - In odd epochs, a candidate is selected deterministically as the next centroid.
146 |
147 | In total, initialization thus requires ``2 * k - 1`` epochs where ``k`` is the number of
148 | clusters.
149 | """
150 |
151 | def __init__(self, model: KMeansModel, is_batch_training: bool):
152 | """
153 | Args:
154 | model: The model to initialize.
155 | is_batch_training: Whether training is performed on mini-batches instead of the entire
156 | data at once.
157 | """
158 | super().__init__()
159 |
160 | self.model = model
161 | self.is_batch_training = is_batch_training
162 |
163 | self.uniform_sampler = UniformSampler(
164 | num_choices=1,
165 | num_features=self.model.config.num_features,
166 | dist_sync_fn=self.all_gather_first,
167 | )
168 | num_candidates = 2 + int(math.log(self.model.config.num_clusters))
169 | self.distance_sampler = DistanceSampler(
170 | num_choices=num_candidates,
171 | num_features=self.model.config.num_features,
172 | dist_sync_fn=self.all_gather_first,
173 | )
174 | self.candidate_inertia_summer = BatchSummer(
175 | num_candidates,
176 | dist_sync_fn=self.all_gather,
177 | )
178 |
179 | # Some buffers required for running initialization
180 | self.centroid_candidates: torch.Tensor
181 | self.register_buffer(
182 | "centroid_candidates",
183 | torch.empty(num_candidates, self.model.config.num_features),
184 | persistent=False,
185 | )
186 |
187 | if not self.is_batch_training:
188 | self.shortest_distance_cache: torch.Tensor
189 | self.register_buffer("shortest_distance_cache", torch.empty(1), persistent=False)
190 |
191 | def on_train_epoch_start(self) -> None:
192 | if self.current_epoch == 0:
193 | self.uniform_sampler.reset()
194 | elif self._is_current_epoch_sampling:
195 | self.distance_sampler.reset()
196 | else:
197 | self.candidate_inertia_summer.reset()
198 |
199 | def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
200 | if self.current_epoch == 0:
201 | self.uniform_sampler.update(batch)
202 | return
203 | # In all other epochs, we either sample a number of candidates from the remaining
204 | # datapoints or select a candidate deterministically. In any case, the shortest
205 | # distance is required.
206 | if self.current_epoch == 1:
207 |             # In the first sampling epoch, we can skip any argmin as the shortest distances
208 |             # are computed with respect to the single centroid chosen so far.
209 | shortest_distances = torch.cdist(batch, self.model.centroids[:1]).squeeze(1)
210 | if not self.is_batch_training:
211 | self.shortest_distance_cache = shortest_distances
212 | elif self.is_batch_training:
213 | # For batch training, we always need to recompute all distances since we can't
214 | # cache them (this is the whole reason for batch training).
215 |             distances = torch.cdist(batch, self.model.centroids[: self._init_epoch])
216 | shortest_distances = distances.gather(
217 | 1, distances.min(1, keepdim=True).indices # min is faster than argmin on CPU
218 | ).squeeze(1)
219 | else:
220 | # If we're not doing batch training, we only need to compute the distance to the
221 | # newest centroid (and only if we're currently sampling)
222 | if self._is_current_epoch_sampling:
223 | latest_distance = torch.cdist(
224 | batch, self.model.centroids[self._init_epoch - 1].unsqueeze(0)
225 | ).squeeze(1)
226 | shortest_distances = torch.minimum(self.shortest_distance_cache, latest_distance)
227 | self.shortest_distance_cache = shortest_distances
228 | else:
229 | shortest_distances = self.shortest_distance_cache
230 |
231 | if self._is_current_epoch_sampling:
232 | # After computing the shortest distances, we can finally do the sampling
233 | self.distance_sampler.update(batch, shortest_distances)
234 | else:
235 | # Or, we select a candidate by the lowest resulting inertia
236 | distances = torch.cdist(batch, self.centroid_candidates)
237 | updated_distances = torch.minimum(distances, shortest_distances.unsqueeze(1))
238 | self.candidate_inertia_summer.update(updated_distances)
239 |
240 | def nonparametric_training_epoch_end(self) -> None:
241 | if self.current_epoch == 0:
242 | choice = self.uniform_sampler.compute()
243 | self.model.centroids[0].copy_(choice[0] if choice.dim() > 0 else choice)
244 | elif self._is_current_epoch_sampling:
245 | candidates = self.distance_sampler.compute()
246 | self.centroid_candidates.copy_(candidates)
247 | else:
248 | new_inertias = self.candidate_inertia_summer.compute()
249 | choice = new_inertias.argmin()
250 | self.model.centroids[self._init_epoch].copy_(self.centroid_candidates[choice])
251 |
252 | @property
253 | def _init_epoch(self) -> int:
254 | return (self.current_epoch + 1) // 2
255 |
256 | @property
257 | def _is_current_epoch_sampling(self) -> bool:
258 | return self.current_epoch % 2 == 1
259 |
260 |
261 | # -------------------------------------------------------------------------------------------------
262 | # MISC
263 |
264 |
265 | class FeatureVarianceLightningModule(NonparametricLightningModule):
266 | """
267 | Lightning module for computing the average variance of a dataset's features.
268 |
269 |     In the first epoch, it computes the features' means; in the second epoch, their variances.
270 | """
271 |
272 | def __init__(self, variances: torch.Tensor):
273 | """
274 | Args:
275 | variances: The output tensor where the variances are stored.
276 | """
277 | super().__init__()
278 |
279 | self.mean_aggregator = BatchAverager(
280 | num_values=variances.size(0),
281 | for_variance=False,
282 | dist_sync_fn=self.all_gather,
283 | )
284 | self.variance_aggregator = BatchAverager(
285 | num_values=variances.size(0),
286 | for_variance=True,
287 | dist_sync_fn=self.all_gather,
288 | )
289 |
290 | self.means: torch.Tensor
291 | self.register_buffer("means", torch.empty(variances.size(0)), persistent=False)
292 |
293 | self.variances: torch.Tensor
294 | self.register_buffer("variances", variances, persistent=False)
295 |
296 | def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
297 | if self.current_epoch == 0:
298 | self.mean_aggregator.update(batch)
299 | else:
300 | self.variance_aggregator.update((batch - self.means.unsqueeze(0)).square())
301 |
302 | def nonparametric_training_epoch_end(self) -> None:
303 | if self.current_epoch == 0:
304 | self.means.copy_(self.mean_aggregator.compute())
305 | else:
306 | self.variances.copy_(self.variance_aggregator.compute())
307 |
--------------------------------------------------------------------------------
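The ``KmeansPlusPlusInitLightningModule`` above spreads the classic k-means++ seeding procedure across Lightning epochs so that it also works when the data is only available in mini-batches. For reference, here is a minimal, single-process sketch of the same D²-sampling scheme on an in-memory tensor. It is illustrative only: the function name is hypothetical, and details such as where distances are squared follow the textbook algorithm rather than the exact behavior of the samplers in ``metrics.py``.

```python
import math

import torch


def kmeans_plusplus_sketch(data: torch.Tensor, num_clusters: int) -> torch.Tensor:
    """Illustrative k-means++ seeding on an in-memory [num_datapoints, num_features] tensor."""
    # Epoch 0 in the Lightning module: pick the first centroid uniformly at random.
    centroids = [data[torch.randint(data.size(0), (1,)).item()]]
    num_candidates = 2 + int(math.log(num_clusters))

    for _ in range(num_clusters - 1):
        # Shortest distance of every datapoint to the centroids chosen so far.
        shortest = torch.cdist(data, torch.stack(centroids)).min(dim=1).values

        # "Sampling" (odd) epoch: draw candidates proportionally to the squared distance.
        probs = shortest.square()
        candidate_idx = torch.multinomial(probs / probs.sum(), num_candidates, replacement=True)
        candidates = data[candidate_idx]

        # "Selection" (even) epoch: keep the candidate yielding the lowest overall inertia.
        updated = torch.minimum(torch.cdist(data, candidates), shortest.unsqueeze(1))
        centroids.append(candidates[updated.square().sum(dim=0).argmin()])

    return torch.stack(centroids)
```

In the Lightning module, the two inner steps are accumulated over batches by ``DistanceSampler`` and ``BatchSummer`` respectively, which is what allows the procedure to run without ever materializing the full dataset.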
/pycave/bayes/gmm/estimator.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import logging
3 | from typing import Any, cast, List, Tuple
4 | import torch
5 | from lightkit import ConfigurableBaseEstimator
6 | from lightkit.data import collate_tensor, DataLoader, dataset_from_tensors, TensorLike
7 | from lightkit.estimator import PredictorMixin
8 | from pycave.bayes.core import CovarianceType
9 | from pycave.clustering import KMeans
10 | from .lightning_module import (
11 | GaussianMixtureKmeansInitLightningModule,
12 | GaussianMixtureLightningModule,
13 | GaussianMixtureRandomInitLightningModule,
14 | )
15 | from .model import GaussianMixtureModel, GaussianMixtureModelConfig
16 | from .types import GaussianMixtureInitStrategy
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class GaussianMixture(
22 | ConfigurableBaseEstimator[GaussianMixtureModel], # type: ignore
23 | PredictorMixin[TensorLike, torch.Tensor],
24 | ):
25 | """
26 | Probabilistic model assuming that data is generated from a mixture of Gaussians.
27 |
28 | The mixture is assumed to be composed of a fixed number of components with individual means
29 | and covariances. More information on Gaussian mixture models (GMMs) is available on
30 |     `Wikipedia <https://en.wikipedia.org/wiki/Mixture_model>`_.
31 |
32 | See also:
33 | .. currentmodule:: pycave.bayes.gmm
34 | .. autosummary::
35 | :nosignatures:
36 | :template: classes/pytorch_module.rst
37 |
38 | GaussianMixtureModel
39 | GaussianMixtureModelConfig
40 | """
41 |
42 | #: The fitted PyTorch module with all estimated parameters.
43 | model_: GaussianMixtureModel
44 | #: A boolean indicating whether the model converged during training.
45 | converged_: bool
46 | #: The number of iterations the model was fitted for, excluding initialization.
47 | num_iter_: int
48 | #: The average per-datapoint negative log-likelihood at the last training step.
49 | nll_: float
50 |
51 | def __init__(
52 | self,
53 | num_components: int = 1,
54 | *,
55 | covariance_type: CovarianceType = "diag",
56 | init_strategy: GaussianMixtureInitStrategy = "kmeans",
57 | init_means: torch.Tensor | None = None,
58 | convergence_tolerance: float = 1e-3,
59 | covariance_regularization: float = 1e-6,
60 | batch_size: int | None = None,
61 | trainer_params: dict[str, Any] | None = None,
62 | ):
63 | """
64 | Args:
65 | num_components: The number of components in the GMM. The dimensionality of each
66 | component is automatically inferred from the data.
67 | covariance_type: The type of covariance to assume for all Gaussian components.
68 | init_strategy: The strategy for initializing component means and covariances.
69 | init_means: An optional initial guess for the means of the components. If provided,
70 | must be a tensor of shape ``[num_components, num_features]``. If this is given,
71 | the ``init_strategy`` is ignored and the means are handled as if K-means
72 | initialization has been run.
73 | convergence_tolerance: The change in the per-datapoint negative log-likelihood which
74 | implies that training has converged.
75 | covariance_regularization: A small value which is added to the diagonal of the
76 | covariance matrix to ensure that it is positive semi-definite.
77 | batch_size: The batch size to use when fitting the model. If not provided, the full
78 | data will be used as a single batch. Set this if the full data does not fit into
79 | memory.
80 | num_workers: The number of workers to use for loading the data. Only used if a PyTorch
81 | dataset is passed to :meth:`fit` or related methods.
82 | trainer_params: Initialization parameters to use when initializing a PyTorch Lightning
83 | trainer. By default, it disables various stdout logs unless PyCave is configured to
84 | do verbose logging. Checkpointing and logging are disabled regardless of the log
85 | level. This estimator further sets the following overridable defaults:
86 |
87 | - ``max_epochs=100``
88 |
89 | Note:
90 |                 The number of epochs passed to the initializer only defines the number of optimization
91 |                 epochs. Prior to that, initialization is run, which may perform additional iterations
92 | through the data.
93 |
94 | Note:
95 | For batch training, the number of epochs run (i.e. the number of passes through the
96 |                 data) does not align with the number of epochs passed to the initializer. This is
97 | because the EM algorithm needs to be split up across two epochs. The actual number of
98 |                 minimum/maximum epochs is thus doubled. Nonetheless, :attr:`num_iter_` indicates how
99 | many EM iterations have been run.
100 | """
101 | super().__init__(
102 | default_params=dict(max_epochs=100),
103 | user_params=trainer_params,
104 | )
105 |
106 | self.num_components = num_components
107 | self.covariance_type = covariance_type
108 | self.init_strategy = init_strategy
109 | self.init_means = init_means
110 | self.convergence_tolerance = convergence_tolerance
111 | self.covariance_regularization = covariance_regularization
112 |
113 | self.batch_size = batch_size
114 |
115 | def fit(self, data: TensorLike) -> GaussianMixture:
116 | """
117 | Fits the Gaussian mixture on the provided data, estimating component priors, means and
118 | covariances. Parameters are estimated using the EM algorithm.
119 |
120 | Args:
121 | data: The tabular data to fit on. The dimensionality of the Gaussian mixture is
122 | automatically inferred from this data.
123 |
124 | Returns:
125 | The fitted Gaussian mixture.
126 | """
127 | # Initialize the model
128 | num_features = len(data[0])
129 | config = GaussianMixtureModelConfig(
130 | num_components=self.num_components,
131 | num_features=num_features,
132 | covariance_type=self.covariance_type, # type: ignore
133 | )
134 | self.model_ = GaussianMixtureModel(config)
135 |
136 | # Setup the data loading
137 | loader = DataLoader(
138 | dataset_from_tensors(data),
139 | batch_size=self.batch_size or len(data),
140 | collate_fn=collate_tensor,
141 | )
142 |         is_batch_training = self._num_batches_per_epoch(loader) > 1
143 |
144 | # Run k-means if required or copy means
145 | if self.init_means is not None:
146 | self.model_.means.copy_(self.init_means)
147 | elif self.init_strategy in ("kmeans", "kmeans++"):
148 | logger.info("Fitting K-means estimator...")
149 | params = self.trainer_params_user
150 | if self.init_strategy == "kmeans++":
151 | params = {**(params or {}), **dict(max_epochs=0)}
152 |
153 | estimator = KMeans(
154 | self.num_components,
155 | batch_size=self.batch_size,
156 | trainer_params=params,
157 | ).fit(data)
158 | self.model_.means.copy_(estimator.model_.centroids)
159 |
160 | # Run initialization
161 | logger.info("Running initialization...")
162 | if self.init_strategy in ("kmeans", "kmeans++") and self.init_means is None:
163 | module = GaussianMixtureKmeansInitLightningModule(
164 | self.model_,
165 | covariance_regularization=self.covariance_regularization,
166 | )
167 | self.trainer(max_epochs=1).fit(module, loader)
168 | else:
169 | module = GaussianMixtureRandomInitLightningModule(
170 | self.model_,
171 | covariance_regularization=self.covariance_regularization,
172 | is_batch_training=is_batch_training,
173 | use_model_means=self.init_means is not None,
174 | )
175 | self.trainer(max_epochs=1 + int(is_batch_training)).fit(module, loader)
176 |
177 | # Fit model
178 | logger.info("Fitting Gaussian mixture...")
179 | module = GaussianMixtureLightningModule(
180 | self.model_,
181 | convergence_tolerance=self.convergence_tolerance,
182 | covariance_regularization=self.covariance_regularization,
183 | is_batch_training=is_batch_training,
184 | )
185 | trainer = self.trainer(
186 | max_epochs=cast(int, self.trainer_params["max_epochs"]) * (1 + int(is_batch_training))
187 | )
188 | trainer.fit(module, loader)
189 |
190 | # Assign convergence properties
191 | self.num_iter_ = module.current_epoch
192 | if is_batch_training:
193 | self.num_iter_ //= 2
194 | self.converged_ = trainer.should_stop
195 | self.nll_ = cast(float, trainer.callback_metrics["nll"].item())
196 | return self
197 |
198 | def sample(self, num_datapoints: int) -> torch.Tensor:
199 | """
200 | Samples datapoints from the fitted Gaussian mixture.
201 |
202 | Args:
203 | num_datapoints: The number of datapoints to sample.
204 |
205 | Returns:
206 | A tensor of shape ``[num_datapoints, dim]`` providing the samples.
207 |
208 | Note:
209 | This method does not parallelize across multiple processes, i.e. performs no
210 | synchronization.
211 | """
212 | return self.model_.sample(num_datapoints)
213 |
214 | def score(self, data: TensorLike) -> float:
215 | """
216 | Computes the average negative log-likelihood (NLL) of the provided datapoints.
217 |
218 | Args:
219 | data: The datapoints for which to evaluate the NLL.
220 |
221 | Returns:
222 | The average NLL of all datapoints.
223 |
224 | Note:
225 | See :meth:`score_samples` to obtain NLL values for individual datapoints.
226 | """
227 | loader = DataLoader(
228 | dataset_from_tensors(data),
229 | batch_size=self.batch_size or len(data),
230 | collate_fn=collate_tensor,
231 | )
232 | result = self.trainer().test(
233 | GaussianMixtureLightningModule(self.model_), loader, verbose=False
234 | )
235 | return result[0]["nll"]
236 |
237 | def score_samples(self, data: TensorLike) -> torch.Tensor:
238 | """
239 | Computes the negative log-likelihood (NLL) of each of the provided datapoints.
240 |
241 | Args:
242 | data: The datapoints for which to compute the NLL.
243 |
244 | Returns:
245 | A tensor of shape ``[num_datapoints]`` with the NLL for each datapoint.
246 |
247 | Attention:
248 | When calling this function in a multi-process environment, each process receives only
249 | a subset of the predictions. If you want to aggregate predictions, make sure to gather
250 | the values returned from this method.
251 | """
252 | loader = DataLoader(
253 | dataset_from_tensors(data),
254 | batch_size=self.batch_size or len(data),
255 | collate_fn=collate_tensor,
256 | )
257 | result = self.trainer().predict(GaussianMixtureLightningModule(self.model_), loader)
258 |         return torch.cat([x[1] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)])
259 |
260 | def predict(self, data: TensorLike) -> torch.Tensor:
261 | """
262 | Computes the most likely components for each of the provided datapoints.
263 |
264 | Args:
265 | data: The datapoints for which to obtain the most likely components.
266 |
267 | Returns:
268 | A tensor of shape ``[num_datapoints]`` with the indices of the most likely components.
269 |
270 | Note:
271 | Use :meth:`predict_proba` to obtain probabilities for each component instead of the
272 | most likely component only.
273 |
274 | Attention:
275 | When calling this function in a multi-process environment, each process receives only
276 | a subset of the predictions. If you want to aggregate predictions, make sure to gather
277 | the values returned from this method.
278 | """
279 | return self.predict_proba(data).argmax(-1)
280 |
281 | def predict_proba(self, data: TensorLike) -> torch.Tensor:
282 | """
283 | Computes a distribution over the components for each of the provided datapoints.
284 |
285 | Args:
286 | data: The datapoints for which to compute the component assignment probabilities.
287 |
288 | Returns:
289 | A tensor of shape ``[num_datapoints, num_components]`` with the assignment
290 |             probabilities for each component and datapoint. Note that each row of the tensor sums
291 | to 1, i.e. the returned tensor provides a proper distribution over the components for
292 | each datapoint.
293 |
294 | Attention:
295 | When calling this function in a multi-process environment, each process receives only
296 | a subset of the predictions. If you want to aggregate predictions, make sure to gather
297 | the values returned from this method.
298 | """
299 | loader = DataLoader(
300 | dataset_from_tensors(data),
301 | batch_size=self.batch_size or len(data),
302 | collate_fn=collate_tensor,
303 | )
304 | result = self.trainer().predict(GaussianMixtureLightningModule(self.model_), loader)
305 | return torch.cat([x[0] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)])
306 |
--------------------------------------------------------------------------------
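As a quick orientation for how the estimator defined above is meant to be used, here is a small usage sketch. It imports the class via the module path shown in this file; in practice the class is typically re-exported from the package root (e.g. ``from pycave.bayes import GaussianMixture``), but that re-export lives in an ``__init__`` module not shown here.

```python
import torch

from pycave.bayes.gmm.estimator import GaussianMixture

# Two well-separated blobs in 2D as toy data.
data = torch.cat([
    torch.randn(500, 2) * 0.5 - 2.0,
    torch.randn(500, 2) * 0.5 + 2.0,
])

gmm = GaussianMixture(
    num_components=2,
    covariance_type="diag",
    trainer_params=dict(max_epochs=50),
)
gmm.fit(data)

print(gmm.converged_, gmm.num_iter_, gmm.nll_)  # convergence info populated by fit()
print(gmm.score(data))                          # average NLL over the data
print(gmm.predict(data[:5]))                    # most likely component per datapoint
samples = gmm.sample(10)                        # draw new datapoints from the fitted mixture
```

Passing ``batch_size`` switches the estimator to mini-batch training (with the doubled epoch count described in the docstring), and ``init_strategy="kmeans++"`` restricts the K-means step to the ++ seeding by running the ``KMeans`` estimator with ``max_epochs=0``.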
/pycave/bayes/gmm/lightning_module.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import pytorch_lightning as pl
3 | import torch
4 | from pytorch_lightning.callbacks import EarlyStopping
5 | from torchmetrics import MeanMetric
6 | from pycave.bayes.core import cholesky_precision
7 | from pycave.utils import NonparametricLightningModule
8 | from .metrics import CovarianceAggregator, MeanAggregator, PriorAggregator
9 | from .model import GaussianMixtureModel
10 |
11 | # -------------------------------------------------------------------------------------------------
12 | # TRAINING
13 |
14 |
15 | class GaussianMixtureLightningModule(NonparametricLightningModule):
16 | """
17 | Lightning module for training and evaluating a Gaussian mixture model.
18 | """
19 |
20 | def __init__(
21 | self,
22 | model: GaussianMixtureModel,
23 | convergence_tolerance: float = 1e-3,
24 | covariance_regularization: float = 1e-6,
25 | is_batch_training: bool = False,
26 | ):
27 | """
28 | Args:
29 | model: The Gaussian mixture model to use for training/evaluation.
30 | convergence_tolerance: The change in the per-datapoint negative log-likelihood which
31 | implies that training has converged.
32 | covariance_regularization: A small value which is added to the diagonal of the
33 | covariance matrix to ensure that it is positive semi-definite.
34 | is_batch_training: Whether training is performed on mini-batches instead of the entire
35 | data at once. In the case of batching, the EM-algorithm is "split" across two
36 | epochs.
37 | """
38 | super().__init__()
39 |
40 | self.model = model
41 | self.convergence_tolerance = convergence_tolerance
42 | self.is_batch_training = is_batch_training
43 |
44 | # For batch training, we store a model copy such that we can "replay" responsibilities
45 | if self.is_batch_training:
46 | self.model_copy = GaussianMixtureModel(self.model.config)
47 | self.model_copy.load_state_dict(self.model.state_dict())
48 |
49 | # Initialize aggregators
50 | self.prior_aggregator = PriorAggregator(
51 | num_components=self.model.config.num_components,
52 | dist_sync_fn=self.all_gather,
53 | )
54 | self.mean_aggregator = MeanAggregator(
55 | num_components=self.model.config.num_components,
56 | num_features=self.model.config.num_features,
57 | dist_sync_fn=self.all_gather,
58 | )
59 | self.covar_aggregator = CovarianceAggregator(
60 | num_components=self.model.config.num_components,
61 | num_features=self.model.config.num_features,
62 | covariance_type=self.model.config.covariance_type,
63 | reg=covariance_regularization,
64 | dist_sync_fn=self.all_gather,
65 | )
66 |
67 | # Initialize metrics
68 | self.metric_nll = MeanMetric(dist_sync_fn=self.all_gather)
69 |
70 | def configure_callbacks(self) -> list[pl.Callback]:
71 | if self.convergence_tolerance == 0:
72 | return []
73 | early_stopping = EarlyStopping(
74 | "nll",
75 | min_delta=self.convergence_tolerance,
76 | patience=2 if self.is_batch_training else 1,
77 | check_on_train_epoch_end=True,
78 |             strict=False,  # Allow the monitored NLL to be absent in some epochs (batch training)
79 | )
80 | return [early_stopping]
81 |
82 | def on_train_epoch_start(self) -> None:
83 | self.prior_aggregator.reset()
84 | self.mean_aggregator.reset()
85 | self.covar_aggregator.reset()
86 |
87 | def nonparametric_training_step(self, batch: torch.Tensor, _batch_idx: int) -> None:
88 | ### E-Step
89 | if self._computes_responsibilities_on_live_model:
90 | log_responsibilities, log_probs = self.model.forward(batch)
91 | else:
92 | log_responsibilities, log_probs = self.model_copy.forward(batch)
93 | responsibilities = log_responsibilities.exp()
94 |
95 | # Compute the NLL for early stopping
96 | if self._should_log_nll:
97 | self.metric_nll.update(-log_probs)
98 | self.log("nll", self.metric_nll, on_step=False, on_epoch=True, prog_bar=True)
99 |
100 | ### (Partial) M-Step
101 | if self._should_update_means:
102 | self.prior_aggregator.update(responsibilities)
103 | self.mean_aggregator.update(batch, responsibilities)
104 |         if self._should_update_covars:
105 |             # For batch training, the updated means were already written to the model at the end
106 |             # of the previous epoch; otherwise, use the means aggregated from the current batch.
107 |             means = self.model.means if self.is_batch_training else self.mean_aggregator.compute()
108 |             self.covar_aggregator.update(batch, responsibilities, means)
109 |
110 | def nonparametric_training_epoch_end(self) -> None:
111 | # Prior to updating the model, we might need to copy it in the case of batch training
112 | if self._requires_to_copy_live_model:
113 | self.model_copy.load_state_dict(self.model.state_dict())
114 |
115 | # Finalize the M-Step
116 | if self._should_update_means:
117 | priors = self.prior_aggregator.compute()
118 | self.model.component_probs.copy_(priors)
119 |
120 | means = self.mean_aggregator.compute()
121 | self.model.means.copy_(means)
122 |
123 | if self._should_update_covars:
124 | covars = self.covar_aggregator.compute()
125 | self.model.precisions_cholesky.copy_(
126 | cholesky_precision(covars, self.model.config.covariance_type)
127 | )
128 |
129 | def test_step(self, batch: torch.Tensor, _batch_idx: int) -> None:
130 | _, log_probs = self.model.forward(batch)
131 | self.metric_nll.update(-log_probs)
132 | self.log("nll", self.metric_nll)
133 |
134 | def predict_step(
135 | self, batch: torch.Tensor, batch_idx: int
136 | ) -> tuple[torch.Tensor, torch.Tensor]:
137 | log_responsibilities, log_probs = self.model.forward(batch)
138 | return log_responsibilities.exp(), -log_probs
139 |
140 | @property
141 | def _computes_responsibilities_on_live_model(self) -> bool:
142 | if not self.is_batch_training:
143 | return True
144 | return self.current_epoch % 2 == 0
145 |
146 | @property
147 | def _requires_to_copy_live_model(self) -> bool:
148 | if not self.is_batch_training:
149 | return False
150 | return self.current_epoch % 2 == 0
151 |
152 | @property
153 | def _should_log_nll(self) -> bool:
154 | if not self.is_batch_training:
155 | return True
156 | return self.current_epoch % 2 == 1
157 |
158 | @property
159 | def _should_update_means(self) -> bool:
160 | if not self.is_batch_training:
161 | return True
162 | return self.current_epoch % 2 == 0
163 |
164 | @property
165 | def _should_update_covars(self) -> bool:
166 | if not self.is_batch_training:
167 | return True
168 | return self.current_epoch % 2 == 1
169 |
170 |
171 | # -------------------------------------------------------------------------------------------------
172 | # INIT STRATEGIES
173 |
174 |
175 | class GaussianMixtureKmeansInitLightningModule(NonparametricLightningModule):
176 | """
177 | Lightning module for initializing a Gaussian mixture from centroids found via K-Means.
178 | """
179 |
180 | def __init__(self, model: GaussianMixtureModel, covariance_regularization: float):
181 | """
182 | Args:
183 | model: The model whose parameters to initialize.
184 | covariance_regularization: A small value which is added to the diagonal of the
185 | covariance matrix to ensure that it is positive semi-definite.
186 | """
187 | super().__init__()
188 |
189 | self.model = model
190 |
191 | self.prior_aggregator = PriorAggregator(
192 | num_components=self.model.config.num_components,
193 | dist_sync_fn=self.all_gather,
194 | )
195 | self.covar_aggregator = CovarianceAggregator(
196 | num_components=self.model.config.num_components,
197 | num_features=self.model.config.num_features,
198 | covariance_type=self.model.config.covariance_type,
199 | reg=covariance_regularization,
200 | dist_sync_fn=self.all_gather,
201 | )
202 |
203 | def on_train_epoch_start(self) -> None:
204 | self.prior_aggregator.reset()
205 | self.covar_aggregator.reset()
206 |
207 | def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
208 | # Just like for k-means, responsibilities are one-hot assignments to the clusters
209 | responsibilities = _one_hot_responsibilities(batch, self.model.means)
210 |
211 | # Then, we can update the aggregators
212 | self.prior_aggregator.update(responsibilities)
213 | self.covar_aggregator.update(batch, responsibilities, self.model.means)
214 |
215 | def nonparametric_training_epoch_end(self) -> None:
216 | priors = self.prior_aggregator.compute()
217 | self.model.component_probs.copy_(priors)
218 |
219 | covars = self.covar_aggregator.compute()
220 | self.model.precisions_cholesky.copy_(
221 | cholesky_precision(covars, self.model.config.covariance_type)
222 | )
223 |
224 |
225 | class GaussianMixtureRandomInitLightningModule(NonparametricLightningModule):
226 | """
227 |     Lightning module for initializing a Gaussian mixture randomly or from one-hot assignments
228 |     to arbitrary user-provided means (i.e. means that were not found via K-means).
229 |
230 |     For batch training, this requires two epochs; otherwise, it requires a single epoch.
231 | """
232 |
233 | def __init__(
234 | self,
235 | model: GaussianMixtureModel,
236 | covariance_regularization: float,
237 | is_batch_training: bool,
238 | use_model_means: bool,
239 | ):
240 | """
241 | Args:
242 | model: The model whose parameters to initialize.
243 | covariance_regularization: A small value which is added to the diagonal of the
244 | covariance matrix to ensure that it is positive semi-definite.
245 | is_batch_training: Whether training is performed on mini-batches instead of the entire
246 | data at once.
247 | use_model_means: Whether the model's means ought to be used for one-hot component
248 | assignments.
249 | """
250 | super().__init__()
251 |
252 | self.model = model
253 | self.is_batch_training = is_batch_training
254 | self.use_model_means = use_model_means
255 |
256 | self.prior_aggregator = PriorAggregator(
257 | num_components=self.model.config.num_components,
258 | dist_sync_fn=self.all_gather,
259 | )
260 | self.mean_aggregator = MeanAggregator(
261 | num_components=self.model.config.num_components,
262 | num_features=self.model.config.num_features,
263 | dist_sync_fn=self.all_gather,
264 | )
265 | self.covar_aggregator = CovarianceAggregator(
266 | num_components=self.model.config.num_components,
267 | num_features=self.model.config.num_features,
268 | covariance_type=self.model.config.covariance_type,
269 | reg=covariance_regularization,
270 | dist_sync_fn=self.all_gather,
271 | )
272 |
273 | # For batch training, we store a model copy such that we can "replay" responsibilities
274 | if self.is_batch_training and self.use_model_means:
275 | self.model_copy = GaussianMixtureModel(self.model.config)
276 | self.model_copy.load_state_dict(self.model.state_dict())
277 |
278 | def on_train_epoch_start(self) -> None:
279 | self.prior_aggregator.reset()
280 | self.mean_aggregator.reset()
281 | self.covar_aggregator.reset()
282 |
283 | def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
284 | if self.use_model_means:
285 | if self.current_epoch == 0:
286 | responsibilities = _one_hot_responsibilities(batch, self.model.means)
287 | else:
288 | responsibilities = _one_hot_responsibilities(batch, self.model_copy.means)
289 | else:
290 | responsibilities = torch.rand(
291 | batch.size(0),
292 | self.model.config.num_components,
293 | device=batch.device,
294 | dtype=batch.dtype,
295 | )
296 | responsibilities = responsibilities / responsibilities.sum(1, keepdim=True)
297 |
298 | if self.current_epoch == 0:
299 | self.prior_aggregator.update(responsibilities)
300 | self.mean_aggregator.update(batch, responsibilities)
301 | if not self.is_batch_training:
302 | means = self.mean_aggregator.compute()
303 | self.covar_aggregator.update(batch, responsibilities, means)
304 | else:
305 | # Only reached if batch training
306 | self.covar_aggregator.update(batch, responsibilities, self.model.means)
307 |
308 | def nonparametric_training_epoch_end(self) -> None:
309 | if self.current_epoch == 0 and self.is_batch_training:
310 | self.model_copy.load_state_dict(self.model.state_dict())
311 |
312 | if self.current_epoch == 0:
313 | priors = self.prior_aggregator.compute()
314 | self.model.component_probs.copy_(priors)
315 |
316 | means = self.mean_aggregator.compute()
317 | self.model.means.copy_(means)
318 |
319 | if (self.current_epoch == 0 and not self.is_batch_training) or self.current_epoch == 1:
320 | covars = self.covar_aggregator.compute()
321 | self.model.precisions_cholesky.copy_(
322 | cholesky_precision(covars, self.model.config.covariance_type)
323 | )
324 |
325 |
326 | def _one_hot_responsibilities(data: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor:
327 | distances = torch.cdist(data, centroids)
328 | assignments = distances.min(1).indices
329 | onehot = torch.eye(
330 | centroids.size(0),
331 | device=data.device,
332 | dtype=data.dtype,
333 | )
334 | return onehot[assignments]
335 |
--------------------------------------------------------------------------------
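The training module above streams the E- and M-steps over batches via its aggregators and, for mini-batch training, splits one EM iteration across two epochs so that the covariance update can use the already-updated means. The following is a compact, self-contained sketch of the same iteration for diagonal covariances on an in-memory dataset. It is illustrative only: the function name is hypothetical, and it stores plain variances rather than the Cholesky factors of the precisions that the model keeps via ``cholesky_precision``.

```python
import math

import torch


def em_step_diag(data, priors, means, variances, reg=1e-6):
    """One EM iteration for a GMM with diagonal covariances (illustrative sketch)."""
    # E-step: log of the joint density p(x, component) for every datapoint and component.
    log_prob = (
        priors.log()
        - 0.5 * (variances.log().sum(-1) + data.size(-1) * math.log(2 * math.pi))
        - 0.5 * (((data.unsqueeze(1) - means) ** 2) / variances).sum(-1)
    )
    log_resp = log_prob - log_prob.logsumexp(-1, keepdim=True)
    resp = log_resp.exp()  # responsibilities, shape [num_datapoints, num_components]

    # M-step: roughly what the prior/mean/covariance aggregators accumulate batch by batch.
    component_weights = resp.sum(0)
    new_priors = component_weights / data.size(0)
    new_means = (resp.T @ data) / component_weights.unsqueeze(-1)
    diff = data.unsqueeze(1) - new_means
    new_variances = (resp.unsqueeze(-1) * diff**2).sum(0) / component_weights.unsqueeze(-1) + reg

    nll = -log_prob.logsumexp(-1).mean()  # the quantity logged as "nll" for early stopping
    return new_priors, new_means, new_variances, nll
```

For mini-batch training, responsibilities cannot be recomputed with the updated means within the same pass over the data; this is why the module keeps ``model_copy`` and performs the covariance half of the update against responsibilities computed from the previous parameters in the following epoch.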