├── pycave ├── py.typed ├── clustering │ ├── __init__.py │ └── kmeans │ │ ├── __init__.py │ │ ├── types.py │ │ ├── model.py │ │ ├── metrics.py │ │ ├── estimator.py │ │ └── lightning_module.py ├── utils │ ├── __init__.py │ └── lightning_module.py ├── bayes │ ├── __init__.py │ ├── markov_chain │ │ ├── __init__.py │ │ ├── types.py │ │ ├── lightning_module.py │ │ ├── metrics.py │ │ ├── model.py │ │ └── estimator.py │ ├── gmm │ │ ├── __init__.py │ │ ├── types.py │ │ ├── model.py │ │ ├── metrics.py │ │ ├── estimator.py │ │ └── lightning_module.py │ └── core │ │ ├── __init__.py │ │ ├── types.py │ │ ├── utils.py │ │ ├── _jit.py │ │ └── normal.py └── __init__.py ├── tests ├── __init__.py ├── _data │ ├── __init__.py │ ├── gmm.py │ └── normal.py ├── bayes │ ├── gmm │ │ ├── test_gmm_model.py │ │ ├── test_gmm_estimator.py │ │ ├── test_gmm_metrics.py │ │ └── benchmark_gmm_estimator.py │ ├── core │ │ ├── benchmark_precision_cholesky.py │ │ ├── benchmark_log_normal.py │ │ └── test_normal.py │ └── markov_chain │ │ ├── test_markov_chain_model.py │ │ └── test_markov_chain_estimator.py └── clustering │ └── kmeans │ ├── test_kmeans_model.py │ ├── test_kmeans_estimator.py │ └── benchmark_kmeans_estimator.py ├── .github ├── CODEOWNERS ├── dependabot.yml └── workflows │ ├── deploy.yml │ ├── ci.yml │ └── docs.yml ├── .prettierignore ├── docs ├── _static │ ├── favicon.ico │ └── logo.svg ├── _templates │ ├── autosummary │ │ ├── method.rst │ │ └── class.rst │ └── classes │ │ ├── type_alias.rst │ │ └── pytorch_module.rst ├── spelling_wordlist.txt ├── sites │ ├── api.rst │ └── benchmark.rst ├── conf.py └── index.rst ├── Makefile ├── .prettierrc ├── .pre-commit-config.yaml ├── LICENSE ├── pyproject.toml ├── README.md └── .gitignore /pycave/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/_data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @borchero 2 | -------------------------------------------------------------------------------- /.prettierignore: -------------------------------------------------------------------------------- 1 | .pytest_cache/ 2 | build/ 3 | -------------------------------------------------------------------------------- /docs/_static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/borchero/pycave/HEAD/docs/_static/favicon.ico -------------------------------------------------------------------------------- /pycave/clustering/__init__.py: -------------------------------------------------------------------------------- 1 | from .kmeans import KMeans 2 | 3 | __all__ = [ 4 | "KMeans", 5 | ] 6 | -------------------------------------------------------------------------------- /pycave/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .lightning_module import NonparametricLightningModule 2 | 3 | __all__ = ["NonparametricLightningModule"] 4 | -------------------------------------------------------------------------------- 
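Note: the estimators exported by this package, such as pycave.clustering.KMeans above and pycave.bayes.GaussianMixture further below, follow a scikit-learn-style fit/score interface, which is exercised by the tests under tests/. The following is a minimal, illustrative usage sketch; the random data and the concrete parameter values are assumptions for demonstration only and are not part of the repository:

import torch
from pycave.bayes import GaussianMixture
from pycave.clustering import KMeans

# Illustrative data: 1,000 datapoints with 4 features each (assumed for this sketch).
data = torch.randn(1000, 4)

# Fit K-Means with two clusters; score() then evaluates the mean inertia on the given data.
kmeans = KMeans(2)
kmeans.fit(data)
inertia = kmeans.score(data)

# Fit a Gaussian mixture with diagonal covariances; score() evaluates the mean negative log-likelihood.
gmm = GaussianMixture(2, covariance_type="diag")
nll = gmm.fit(data).score(data)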
/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: docs 2 | 3 | docs: 4 | rm -rf build 5 | rm -rf docs/generated 6 | rm -rf docs/sites/generated 7 | sphinx-build -W -b html docs build 8 | -------------------------------------------------------------------------------- /.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "bracketSpacing": true, 3 | "endOfLine": "auto", 4 | "printWidth": 99, 5 | "proseWrap": "always", 6 | "singleQuote": false, 7 | "tabWidth": 2 8 | } 9 | -------------------------------------------------------------------------------- /pycave/bayes/__init__.py: -------------------------------------------------------------------------------- 1 | from .gmm import GaussianMixture 2 | from .markov_chain import MarkovChain 3 | 4 | __all__ = [ 5 | "GaussianMixture", 6 | "MarkovChain", 7 | ] 8 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/method.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | 3 | {{ (class + "." + name) | underline }} 4 | 5 | .. currentmodule:: {{ module }} 6 | 7 | .. automethod:: {{ class }}.{{ name }} 8 | -------------------------------------------------------------------------------- /docs/_templates/classes/type_alias.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. role:: hidden 4 | 5 | {{ name | underline }} 6 | 7 | .. currentmodule:: {{ module }} 8 | 9 | .. autodata:: {{ name }} 10 | -------------------------------------------------------------------------------- /pycave/clustering/kmeans/__init__.py: -------------------------------------------------------------------------------- 1 | from .estimator import KMeans 2 | from .model import KMeansModel, KMeansModelConfig 3 | 4 | __all__ = [ 5 | "KMeans", 6 | "KMeansModel", 7 | "KMeansModelConfig", 8 | ] 9 | -------------------------------------------------------------------------------- /pycave/bayes/markov_chain/__init__.py: -------------------------------------------------------------------------------- 1 | from .estimator import MarkovChain 2 | from .model import MarkovChainModel, MarkovChainModelConfig 3 | 4 | __all__ = [ 5 | "MarkovChain", 6 | "MarkovChainModel", 7 | "MarkovChainModelConfig", 8 | ] 9 | -------------------------------------------------------------------------------- /pycave/bayes/gmm/__init__.py: -------------------------------------------------------------------------------- 1 | from .estimator import GaussianMixture 2 | from .model import GaussianMixtureModel, GaussianMixtureModelConfig 3 | 4 | __all__ = [ 5 | "GaussianMixture", 6 | "GaussianMixtureModel", 7 | "GaussianMixtureModelConfig", 8 | ] 9 | -------------------------------------------------------------------------------- /tests/bayes/gmm/test_gmm_model.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | from torch import jit 3 | from pycave.bayes.gmm import GaussianMixtureModel, GaussianMixtureModelConfig 4 | 5 | 6 | def test_compile(): 7 | config = GaussianMixtureModelConfig(num_components=2, num_features=3, covariance_type="full") 8 | model = GaussianMixtureModel(config) 9 | jit.script(model) 10 | -------------------------------------------------------------------------------- /pycave/bayes/core/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .normal import cholesky_precision, covariance, log_normal, sample_normal 2 | from .types import CovarianceType 3 | from .utils import covariance_dim, covariance_shape 4 | 5 | __all__ = [ 6 | "cholesky_precision", 7 | "log_normal", 8 | "sample_normal", 9 | "covariance", 10 | "CovarianceType", 11 | "covariance_dim", 12 | "covariance_shape", 13 | ] 14 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Update GitHub actions 4 | - directory: / 5 | open-pull-requests-limit: 5 6 | package-ecosystem: github-actions 7 | schedule: 8 | interval: weekly 9 | day: saturday 10 | # Update Python dependencies 11 | - directory: / 12 | open-pull-requests-limit: 5 13 | package-ecosystem: pip 14 | schedule: 15 | interval: weekly 16 | day: saturday 17 | -------------------------------------------------------------------------------- /docs/_static/logo.svg: -------------------------------------------------------------------------------- 1 | [SVG markup not preserved in this dump: logo graphic containing the text "PyCave"] -------------------------------------------------------------------------------- /docs/spelling_wordlist.txt: -------------------------------------------------------------------------------- 1 | abc 2 | bayes 3 | boolean 4 | centroid 5 | Centroids 6 | centroids 7 | Checkpointing 8 | Cholesky 9 | contravariant 10 | covariances 11 | dataclass 12 | dataloader 13 | dataloaders 14 | datapoint 15 | datapoints 16 | dataset 17 | datasets 18 | diag 19 | dimensionality 20 | dtype 21 | Elkan 22 | Frobenius 23 | Gaussians 24 | GiB 25 | gmm 26 | init 27 | initializer 28 | iteratively 29 | KMeans 30 | kmeans 31 | learnt 32 | logits 33 | markov 34 | Mixin 35 | nn 36 | overridable 37 | parallelize 38 | params 39 | precisions 40 | pycave 41 | runtime 42 | scikit 43 | SequenceData 44 | Subclasses 45 | subclasses 46 | stdout 47 | TabularData 48 | tbd 49 | Utils 50 | Xeon 51 | -------------------------------------------------------------------------------- /.github/workflows/deploy.yml: -------------------------------------------------------------------------------- 1 | name: Deploy Package 2 | on: 3 | release: 4 | types: [published] 5 | 6 | jobs: 7 | build: 8 | name: Publish 9 | runs-on: ubuntu-latest 10 | container: 11 | image: python:3.8-buster 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@v3 15 | - name: Install poetry 16 | uses: snok/install-poetry@v1 17 | - name: Tag 18 | run: poetry version ${{ github.event.release.tag_name }} 19 | - name: Build Wheel 20 | run: poetry build 21 | - name: Publish to PyPi 22 | run: poetry publish --username $PYPI_USERNAME --password $PYPI_PASSWORD 23 | env: 24 | PYPI_USERNAME: ${{ secrets.PYPI_USERNAME }} 25 | PYPI_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 26 | -------------------------------------------------------------------------------- /tests/_data/gmm.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | from typing import Tuple 3 | import torch 4 | from pycave.bayes.core import CovarianceType 5 | from pycave.bayes.gmm import GaussianMixtureModel, GaussianMixtureModelConfig 6 | 7 | 8 | def sample_gmm( 9 | num_datapoints: int, num_features: int, num_components: int, covariance_type: CovarianceType 10 | ) -> Tuple[torch.Tensor, torch.Tensor]:
11 | config = GaussianMixtureModelConfig(num_components, num_features, covariance_type) 12 | model = GaussianMixtureModel(config) 13 | 14 | # Means and covariances can simply be scaled 15 | model.means.mul_(torch.rand(num_components).unsqueeze(-1) * 10).add_( 16 | torch.rand(num_components).unsqueeze(-1) * 10 17 | ) 18 | 19 | return model.sample(num_datapoints), model.means 20 | -------------------------------------------------------------------------------- /pycave/clustering/kmeans/types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Literal 3 | 4 | KMeansInitStrategy = Literal["random", "kmeans++"] 5 | KMeansInitStrategy.__doc__ = """ 6 | Strategy for initializing KMeans centroids. 7 | 8 | - **random**: Centroids are sampled randomly from the data. This has complexity ``O(n)`` for ``n`` 9 | datapoints. 10 | - **kmeans++**: Centroids are computed iteratively. The first centroid is sampled randomly from 11 | the data. Subsequently, centroids are sampled from the remaining datapoints with probability 12 | proportional to ``D(x)^2`` where ``D(x)`` is the distance of datapoint ``x`` to the closest 13 | centroid chosen so far. This has complexity ``O(kn)`` for ``k`` clusters and ``n`` datapoints. 14 | If done on mini-batches, the complexity increases to ``O(k^2 n)``. 15 | """ 16 | -------------------------------------------------------------------------------- /pycave/bayes/gmm/types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Literal 3 | 4 | GaussianMixtureInitStrategy = Literal["random", "kmeans", "kmeans++"] 5 | GaussianMixtureInitStrategy.__doc__ = """ 6 | Strategy for initializing the parameters of a Gaussian mixture model. 7 | 8 | - **random**: Samples responsibilities of datapoints at random and subsequently initializes means 9 | and covariances from these. 10 | - **kmeans**: Runs K-Means via :class:`pycave.clustering.KMeans` and uses the centroids as the 11 | initial component means. For computing the covariances, responsibilities are given as the 12 | one-hot cluster assignments. 13 | - **kmeans++**: Runs only the K-Means++ initialization procedure to sample means in a smart 14 | fashion. Might be more efficient than ``kmeans`` as it does not actually run clustering. For 15 | many clusters, this is, however, still slow. 16 | """ 17 | -------------------------------------------------------------------------------- /pycave/bayes/core/types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Literal 3 | 4 | CovarianceType = Literal["full", "tied", "diag", "spherical"] 5 | CovarianceType.__doc__ = """ 6 | The type of covariance to use for a set of multivariate Normal distributions. 7 | 8 | - **full**: Each distribution has a full covariance matrix. Covariance matrix is a tensor of shape 9 | ``[num_components, num_features, num_features]``. 10 | - **tied**: All distributions share the same full covariance matrix. Covariance matrix is a tensor 11 | of shape ``[num_features, num_features]``. 12 | - **diag**: Each distribution has a diagonal covariance matrix. Covariance matrix is a tensor of 13 | shape ``[num_components, num_features]``. 14 | - **spherical**: Each distribution has a diagonal covariance matrix which is a multiple of the 15 | identity matrix. 
Covariance matrix is a tensor of shape ``[num_components]``. 16 | """ 17 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autofix_prs: false 3 | skip: [pylint] 4 | 5 | repos: 6 | - repo: https://github.com/psf/black 7 | rev: 22.12.0 8 | hooks: 9 | - id: black 10 | - repo: https://github.com/PyCQA/pylint 11 | rev: v2.15.9 12 | hooks: 13 | - id: pylint 14 | language: system 15 | types: [python] 16 | args: [-rn, -sn] 17 | - repo: https://github.com/pre-commit/mirrors-mypy 18 | rev: v0.991 19 | hooks: 20 | - id: mypy 21 | - repo: https://github.com/PyCQA/isort 22 | rev: v5.11.3 23 | hooks: 24 | - id: isort 25 | - repo: https://github.com/PyCQA/docformatter 26 | rev: v1.5.1 27 | hooks: 28 | - id: docformatter 29 | additional_dependencies: [tomli] 30 | - repo: https://github.com/asottile/pyupgrade 31 | rev: v3.3.1 32 | hooks: 33 | - id: pyupgrade 34 | args: [--py38-plus] 35 | - repo: https://github.com/pre-commit/mirrors-prettier 36 | rev: v3.0.0-alpha.4 37 | hooks: 38 | - id: prettier 39 | -------------------------------------------------------------------------------- /tests/_data/normal.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | from typing import List 3 | import torch 4 | 5 | 6 | def sample_data(counts: List[int], dims: List[int]) -> List[torch.Tensor]: 7 | return [torch.randn(count, dim) for count, dim in zip(counts, dims)] 8 | 9 | 10 | def sample_means(counts: List[int], dims: List[int]) -> List[torch.Tensor]: 11 | return [torch.randn(count, dim) for count, dim in zip(counts, dims)] 12 | 13 | 14 | def sample_spherical_covars(counts: List[int]) -> List[torch.Tensor]: 15 | return [torch.rand(count) for count in counts] 16 | 17 | 18 | def sample_diag_covars(counts: List[int], dims: List[int]) -> List[torch.Tensor]: 19 | return [torch.rand(count, dim).squeeze() for count, dim in zip(counts, dims)] 20 | 21 | 22 | def sample_full_covars(counts: List[int], dims: List[int]) -> List[torch.Tensor]: 23 | result = [] 24 | for count, dim in zip(counts, dims): 25 | A = torch.rand(count, dim * 10, dim) 26 | result.append(A.permute(0, 2, 1).bmm(A).squeeze()) 27 | return result 28 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | 3 | {{ name | underline }} 4 | 5 | .. currentmodule:: {{ module }} 6 | 7 | .. autoclass:: {{ name }} 8 | :show-inheritance: 9 | 10 | {% if methods %} 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | :toctree: 15 | :nosignatures: 16 | 17 | {% for item in methods %} 18 | {%- if not item in inherited_members %} 19 | {%- if not item.startswith("_") %} 20 | ~{{ name }}.{{ item }} 21 | {%- endif %} 22 | {%- endif %} 23 | {%- endfor %} 24 | 25 | .. rubric:: Inherited Methods 26 | 27 | .. autosummary:: 28 | :toctree: 29 | :nosignatures: 30 | 31 | {% for item in methods %} 32 | {%- if item in inherited_members %} 33 | {%- if not item.startswith("_") %} 34 | ~{{ name }}.{{ item }} 35 | {%- endif %} 36 | {%- endif %} 37 | {%- endfor %} 38 | {%- endif %} 39 | 40 | {% if attributes %} 41 | .. rubric:: Attributes 42 | 43 | .. 
autosummary:: 44 | {% for item in attributes %} 45 | ~{{ name }}.{{ item }} 46 | {%- endfor %} 47 | {%- endif %} 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Oliver Borchert 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /tests/clustering/kmeans/test_kmeans_model.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | import torch 3 | from torch import jit 4 | from pycave.clustering.kmeans import KMeansModel, KMeansModelConfig 5 | 6 | 7 | def test_compile(): 8 | config = KMeansModelConfig(num_clusters=2, num_features=5) 9 | model = KMeansModel(config) 10 | jit.script(model) 11 | 12 | 13 | def test_forward(): 14 | config = KMeansModelConfig(num_clusters=2, num_features=2) 15 | model = KMeansModel(config) 16 | model.centroids.copy_(torch.as_tensor([[0.0, 0.0], [2.0, 2.0]])) 17 | 18 | X = torch.as_tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [-1.0, 4.0]]) 19 | distances, assignments, inertias = model.forward(X) 20 | 21 | expected_distances = torch.as_tensor([[0.0, 8.0], [2.0, 2.0], [8.0, 0.0], [17.0, 13.0]]).sqrt() 22 | expected_assignments = torch.as_tensor([0, 0, 1, 1]) 23 | expected_inertias = torch.as_tensor([0.0, 2.0, 0.0, 13.0]) 24 | 25 | assert torch.allclose(distances, expected_distances) 26 | assert torch.all(assignments == expected_assignments) 27 | assert torch.allclose(inertias, expected_inertias) 28 | -------------------------------------------------------------------------------- /docs/_templates/classes/pytorch_module.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. role:: hidden 4 | 5 | {{ name | underline }} 6 | 7 | .. currentmodule:: {{ module }} 8 | 9 | .. autoclass:: {{ name }} 10 | :show-inheritance: 11 | 12 | {% if methods and not name.endswith("Config") %} 13 | .. rubric:: Methods 14 | 15 | .. autosummary:: 16 | :toctree: 17 | :nosignatures: 18 | 19 | {% for item in methods %} 20 | {%- if not item in inherited_members %} 21 | {%- if not item.startswith("_") %} 22 | ~{{ name }}.{{ item }} 23 | {%- endif %} 24 | {%- endif %} 25 | {%- endfor %} 26 | {%- endif %} 27 | 28 | {% if methods and not name.endswith("Config") %} 29 | .. 
rubric:: Inherited Methods 30 | 31 | .. autosummary:: 32 | :toctree: 33 | :nosignatures: 34 | 35 | {% for item in methods %} 36 | {%- if item in ["load", "save"] %} 37 | ~{{ name }}.{{ item }} 38 | {%- endif %} 39 | {%- endfor %} 40 | {%- endif %} 41 | 42 | {% if attributes %} 43 | .. rubric:: Attributes 44 | 45 | .. autosummary:: 46 | {% for item in attributes %} 47 | {%- if not item in ["training", "T_destination", "dump_patches"] %} 48 | ~{{ name }}.{{ item }} 49 | {%- endif %} 50 | {%- endfor %} 51 | {%- endif %} 52 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | push: 4 | branches: [main] 5 | pull_request: 6 | 7 | jobs: 8 | pylint: 9 | name: Pylint Checks 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout 13 | uses: actions/checkout@v3 14 | - name: Setup Python 15 | uses: actions/setup-python@v4 16 | with: 17 | python-version: "3.10" 18 | - name: Install poetry 19 | uses: snok/install-poetry@v1 20 | - name: Install project 21 | run: poetry install --only main,pre-commit 22 | - name: Run pylint 23 | run: poetry run pylint **/*.py 24 | 25 | unit-tests: 26 | name: Unit Tests - Python ${{ matrix.python-version }} 27 | runs-on: ubuntu-latest 28 | strategy: 29 | matrix: 30 | python-version: ["3.8", "3.9", "3.10"] 31 | steps: 32 | - name: Checkout 33 | uses: actions/checkout@v3 34 | - name: Setup Python 35 | uses: actions/setup-python@v4 36 | with: 37 | python-version: ${{ matrix.python-version }} 38 | - name: Install poetry 39 | uses: snok/install-poetry@v1 40 | - name: Install project 41 | run: poetry install --only main,testing 42 | - name: Run Pytest 43 | run: poetry run pytest tests 44 | -------------------------------------------------------------------------------- /pycave/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | import lightkit 4 | 5 | # This is taken from PyTorch Lightning and ensures that logging for this package is enabled 6 | _root_logger = logging.getLogger() 7 | _logger = logging.getLogger(__name__) 8 | _logger.setLevel(logging.INFO) 9 | if not _root_logger.hasHandlers(): 10 | _logger.addHandler(logging.StreamHandler()) 11 | _logger.propagate = False 12 | 13 | # This disables most logs generated by PyTorch Lightning 14 | logging.getLogger("pytorch_lightning").setLevel(logging.WARNING) 15 | warnings.filterwarnings( 16 | action="ignore", message=".*Consider increasing the value of the `num_workers` argument.*" 17 | ) 18 | warnings.filterwarnings( 19 | action="ignore", message=".*`LightningModule.configure_optimizers` returned `None`.*" 20 | ) 21 | warnings.filterwarnings( 22 | action="ignore", message=".*`LoggerConnector.gpus_metrics` was deprecated in v1.5.*" 23 | ) 24 | 25 | # We also want to define a function which silences info logs 26 | def set_logging_level(level: int) -> None: 27 | """ 28 | Sets the logging level for the entire module. By default, the level is ``logging.INFO``. 29 | 30 | Args: 31 | level: The minimum level of log messages to emit, e.g. ``logging.WARNING`` to silence info logs.
32 | """ 33 | _logger.setLevel(level) 34 | lightkit.set_logging_level(level) 35 | 36 | 37 | # Export 38 | __all__ = ["set_logging_level"] 39 | -------------------------------------------------------------------------------- /docs/sites/api.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | Bayesian Models 5 | --------------- 6 | 7 | .. currentmodule:: pycave.bayes 8 | 9 | Gaussian Mixture 10 | ^^^^^^^^^^^^^^^^ 11 | 12 | .. autosummary:: 13 | :toctree: generated/bayes/gmm 14 | :nosignatures: 15 | :caption: Bayesian Models 16 | 17 | GaussianMixture 18 | 19 | :template: classes/pytorch_module.rst 20 | 21 | ~gmm.GaussianMixtureModel 22 | ~gmm.GaussianMixtureModelConfig 23 | 24 | 25 | Markov Chain 26 | ^^^^^^^^^^^^ 27 | 28 | .. autosummary:: 29 | :toctree: generated/bayes/markov_chain 30 | :nosignatures: 31 | 32 | MarkovChain 33 | 34 | :template: classes/pytorch_module.rst 35 | 36 | ~markov_chain.MarkovChainModel 37 | ~markov_chain.MarkovChainModelConfig 38 | 39 | 40 | Clustering Models 41 | ----------------- 42 | 43 | .. currentmodule:: pycave.clustering 44 | 45 | K-Means 46 | ^^^^^^^ 47 | 48 | .. autosummary:: 49 | :toctree: generated/clustering/kmeans 50 | :nosignatures: 51 | :caption: Clustering Models 52 | 53 | KMeans 54 | 55 | :template: classes/pytorch_module.rst 56 | 57 | ~kmeans.KMeansModel 58 | ~kmeans.KMeansModelConfig 59 | 60 | 61 | Utility Types 62 | ------------- 63 | 64 | .. currentmodule:: pycave 65 | .. autosummary:: 66 | :toctree: generated/types 67 | :nosignatures: 68 | :caption: Types 69 | :template: classes/type_alias.rst 70 | 71 | ~bayes.markov_chain.types.SequenceData 72 | ~bayes.core.CovarianceType 73 | ~bayes.gmm.types.GaussianMixtureInitStrategy 74 | ~clustering.kmeans.types.KMeansInitStrategy 75 | -------------------------------------------------------------------------------- /pycave/bayes/core/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .types import CovarianceType 3 | 4 | 5 | def covariance_dim(covariance_type: CovarianceType) -> int: 6 | """ 7 | Returns the number of dimensions of the covariance matrix for a set of components. 8 | 9 | Args: 10 | covariance_type: The type of covariance to obtain the dimension for. 11 | 12 | Returns: 13 | The number of dimensions. 14 | """ 15 | if covariance_type == "full": 16 | return 3 17 | if covariance_type in ("tied", "diag"): 18 | return 2 19 | return 1 20 | 21 | 22 | def covariance_shape( 23 | num_components: int, num_features: int, covariance_type: CovarianceType 24 | ) -> torch.Size: 25 | """ 26 | Returns the expected shape of the covariance matrix for the given number of components with the 27 | provided number of features based on the covariance type. 28 | 29 | Args: 30 | num_components: The number of Normal distributions to describe with the covariance. 31 | num_features: The dimensionality of the Normal distributions. 32 | covariance_type: The type of covariance to use. 33 | 34 | Returns: 35 | The expected size of the tensor representing the covariances.
36 | """ 37 | if covariance_type == "full": 38 | return torch.Size([num_components, num_features, num_features]) 39 | if covariance_type == "tied": 40 | return torch.Size([num_features, num_features]) 41 | if covariance_type == "diag": 42 | return torch.Size([num_components, num_features]) 43 | # covariance_type == "spherical" 44 | return torch.Size([num_components]) 45 | -------------------------------------------------------------------------------- /tests/bayes/core/benchmark_precision_cholesky.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | import numpy as np 3 | import torch 4 | from pytest_benchmark.fixture import BenchmarkFixture # type: ignore 5 | from sklearn.mixture._gaussian_mixture import _compute_precision_cholesky # type: ignore 6 | from pycave.bayes.core import cholesky_precision 7 | 8 | 9 | def test_cholesky_precision_spherical(benchmark: BenchmarkFixture): 10 | covars = torch.rand(50) 11 | benchmark(cholesky_precision, covars, "spherical") 12 | 13 | 14 | def test_numpy_cholesky_precision_spherical(benchmark: BenchmarkFixture): 15 | covars = np.random.rand(50) 16 | benchmark(_compute_precision_cholesky, covars, "spherical") # type: ignore 17 | 18 | 19 | # ------------------------------------------------------------------------------------------------- 20 | 21 | 22 | def test_cholesky_precision_tied(benchmark: BenchmarkFixture): 23 | A = torch.randn(10000, 100) 24 | covars = A.t().mm(A) 25 | benchmark(cholesky_precision, covars, "tied") 26 | 27 | 28 | def test_numpy_cholesky_precision_tied(benchmark: BenchmarkFixture): 29 | A = np.random.randn(10000, 100) 30 | covars = np.dot(A.T, A) 31 | benchmark(_compute_precision_cholesky, covars, "tied") # type: ignore 32 | 33 | 34 | # ------------------------------------------------------------------------------------------------- 35 | 36 | 37 | def test_cholesky_precision_full(benchmark: BenchmarkFixture): 38 | A = torch.randn(50, 10000, 100) 39 | covars = A.permute(0, 2, 1).bmm(A) 40 | benchmark(cholesky_precision, covars, "full") 41 | 42 | 43 | def test_numpy_cholesky_precision_full(benchmark: BenchmarkFixture): 44 | A = np.random.randn(50, 10000, 100) 45 | covars = np.matmul(np.transpose(A, (0, 2, 1)), A) 46 | benchmark(_compute_precision_cholesky, covars, "full") # type: ignore 47 | -------------------------------------------------------------------------------- /pycave/utils/lightning_module.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List 3 | import pytorch_lightning as pl 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class NonparametricLightningModule(pl.LightningModule, ABC): 9 | """ 10 | A lightning module which sets some defaults for training models with no parameters (i.e. only 11 | buffers that are optimized differently than via gradient descent). 
12 | """ 13 | 14 | def __init__(self): 15 | super().__init__() 16 | self.automatic_optimization = False 17 | 18 | # Required parameter to make DDP training work 19 | self.register_parameter("__ddp_dummy__", nn.Parameter(torch.empty(1))) 20 | 21 | def configure_optimizers(self) -> None: 22 | return None 23 | 24 | def training_step(self, batch: torch.Tensor, batch_idx: int) -> None: 25 | self.nonparametric_training_step(batch, batch_idx) 26 | 27 | def training_epoch_end(self, outputs: List[torch.Tensor]) -> None: 28 | self.nonparametric_training_epoch_end() 29 | 30 | @abstractmethod 31 | def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None: 32 | """ 33 | Training step that is not allowed to return any value. 34 | """ 35 | 36 | def nonparametric_training_epoch_end(self) -> None: 37 | """ 38 | Training epoch end that is not passed any outputs. 39 | 40 | Does nothing by default. 41 | """ 42 | 43 | def all_gather_first(self, x: torch.Tensor) -> torch.Tensor: 44 | """ 45 | Gathers the provided tensor from all processes. 46 | 47 | If more than one process is available, chooses the value of the first process in every 48 | process. 49 | """ 50 | gathered = self.all_gather(x) 51 | if gathered.dim() > x.dim(): 52 | return gathered[0] 53 | return x 54 | -------------------------------------------------------------------------------- /pycave/bayes/markov_chain/types.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | import numpy as np 3 | import numpy.typing as npt 4 | import torch 5 | from torch.nn.utils.rnn import pack_sequence, PackedSequence 6 | from torch.utils.data import Dataset 7 | 8 | SequenceData = Union[ 9 | npt.NDArray[np.float32], 10 | torch.Tensor, 11 | Dataset[torch.Tensor], 12 | ] 13 | SequenceData.__doc__ = """ 14 | Data that may be passed to estimators expecting 1-D sequences. Data may be provided in multiple 15 | formats: 16 | 17 | - NumPy array of shape ``[num_sequences, sequence_length]``. 18 | - PyTorch tensor of shape ``[num_sequences, sequence_length]``. 19 | - PyTorch dataset yielding items of shape ``[sequence_length]`` where the sequence length may 20 | differ for different indices. 21 | """ 22 | 23 | 24 | def collate_sequences_same_length(data: Tuple[torch.Tensor]) -> PackedSequence: 25 | """ 26 | Collates the provided sequences into a packed sequence. Each sequence has to have the same 27 | length. 28 | 29 | Args: 30 | data: A single tensor of shape ``[num_sequences, sequence_length]`` where each item has the 31 | same length. 32 | 33 | Returns: 34 | A packed sequence containing all sequences. 35 | """ 36 | (sequences,) = data 37 | num_sequences, sequence_length = sequences.size() 38 | batch_sizes = torch.ones(sequence_length, dtype=torch.long) * num_sequences 39 | return PackedSequence(sequences.t().flatten(), batch_sizes) 40 | 41 | 42 | def collate_sequences(sequences: List[torch.Tensor]) -> PackedSequence: 43 | """ 44 | Collates the sequences provided as a list into a packed sequence. The sequences are not 45 | required to be sorted by their lengths. 46 | 47 | Args: 48 | sequences: A list of one-dimensional tensors to batch. 49 | 50 | Returns: 51 | A packed sequence with all the data provided. 
52 | """ 53 | return pack_sequence(sequences, enforce_sorted=False) 54 | -------------------------------------------------------------------------------- /pycave/bayes/markov_chain/lightning_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import PackedSequence 3 | from torchmetrics import MeanMetric 4 | from pycave.bayes.markov_chain.metrics import StateCountAggregator 5 | from pycave.utils import NonparametricLightningModule 6 | from .model import MarkovChainModel 7 | 8 | 9 | class MarkovChainLightningModule(NonparametricLightningModule): 10 | """ 11 | Lightning module for training and evaluating a Markov chain. 12 | """ 13 | 14 | def __init__(self, model: MarkovChainModel, symmetric: bool = False): 15 | """ 16 | Args: 17 | model: The model to train or evaluate. 18 | symmetric: Whether transition probabilities should be symmetric. 19 | """ 20 | super().__init__() 21 | 22 | self.model = model 23 | self.symmetric = symmetric 24 | 25 | self.aggregator = StateCountAggregator( 26 | num_states=self.model.config.num_states, 27 | symmetric=self.symmetric, 28 | dist_sync_fn=self.all_gather, 29 | ) 30 | self.metric_nll = MeanMetric(dist_sync_fn=self.all_gather) 31 | 32 | def on_train_epoch_start(self) -> None: 33 | self.aggregator.reset() 34 | 35 | def nonparametric_training_step(self, batch: PackedSequence, _batch_idx: int) -> None: 36 | self.aggregator.update(batch) 37 | 38 | def nonparametric_training_epoch_end(self) -> None: 39 | initial_probs, transition_probs = self.aggregator.compute() 40 | self.model.initial_probs.copy_(initial_probs) 41 | self.model.transition_probs.copy_(transition_probs) 42 | 43 | def test_step(self, batch: PackedSequence, _batch_idx: int) -> None: 44 | log_probs = self.model.forward(batch) 45 | self.metric_nll.update(-log_probs) 46 | self.log("nll", self.metric_nll) 47 | 48 | def predict_step( # pylint: disable=signature-differs 49 | self, batch: PackedSequence, batch_idx: int 50 | ) -> torch.Tensor: 51 | return -self.model(batch) 52 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Generate Documentation 2 | on: 3 | release: 4 | types: [published] 5 | push: 6 | branches: [main] 7 | pull_request: 8 | 9 | jobs: 10 | build: 11 | name: Build 12 | runs-on: ubuntu-latest 13 | container: 14 | image: python:3.8-buster 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v3 18 | - name: Install Enchant 19 | run: apt-get update && apt-get install -y enchant 20 | - name: Install poetry 21 | uses: snok/install-poetry@v1 22 | - name: Install Dependencies 23 | run: poetry install --only main,docs 24 | - name: Fix Distutils 25 | run: poetry run pip install setuptools==59.5.0 26 | - name: Check Spelling 27 | run: poetry run sphinx-build -W -b spelling docs build 28 | - name: Generate HTML 29 | run: poetry run sphinx-build -W -b html docs build 30 | - name: Store Artifacts 31 | uses: actions/upload-artifact@v3 32 | with: 33 | name: html 34 | path: build 35 | 36 | deploy: 37 | name: Publish 38 | runs-on: ubuntu-latest 39 | if: ${{ github.event_name == 'release' }} 40 | needs: 41 | - build 42 | steps: 43 | - name: Retrieve Artifacts 44 | uses: actions/download-artifact@v3 45 | with: 46 | name: html 47 | path: build 48 | - name: Configure AWS Credentials 49 | uses: aws-actions/configure-aws-credentials@v1 50 | with: 51 | aws-access-key-id: ${{
secrets.AWS_ACCESS_KEY_ID }} 52 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 53 | aws-region: ${{ secrets.AWS_REGION }} 54 | - name: Deploy to S3 55 | run: aws s3 sync build s3://pycave.borchero.com --delete --acl public-read 56 | - name: Invalidate Cloudfront 57 | run: | 58 | aws cloudfront create-invalidation \ 59 | --distribution-id ${AWS_CLOUDFRONT_DISTRIBUTION} --paths "/*" 60 | env: 61 | AWS_CLOUDFRONT_DISTRIBUTION: ${{ secrets.AWS_CLOUDFRONT_DISTRIBUTION }} 62 | -------------------------------------------------------------------------------- /tests/bayes/markov_chain/test_markov_chain_model.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | import math 3 | import torch 4 | from torch import jit 5 | from torch.nn.utils.rnn import pack_padded_sequence 6 | from pycave.bayes.markov_chain import MarkovChainModel, MarkovChainModelConfig 7 | 8 | 9 | def test_compile(): 10 | config = MarkovChainModelConfig(num_states=2) 11 | model = MarkovChainModel(config) 12 | jit.script(model) 13 | 14 | 15 | def test_forward_tensor(): 16 | model = _get_default_model() 17 | sequences = torch.as_tensor([[1, 0, 0, 1], [0, 1, 1, 1]]) 18 | expected = torch.as_tensor( 19 | [ 20 | math.log(0.8) + math.log(0.1) + math.log(0.5) + math.log(0.5), 21 | math.log(0.2) + math.log(0.5) + math.log(0.9) + math.log(0.9), 22 | ] 23 | ) 24 | assert torch.allclose(expected, model(sequences)) 25 | 26 | 27 | def test_forward_packed_sequence(): 28 | model = _get_default_model() 29 | sequences = torch.as_tensor([[1, 0, 0, 1], [0, 1, 1, -1]]) 30 | packed_sequences = pack_padded_sequence(sequences.t(), torch.Tensor([4, 3])) 31 | expected = torch.as_tensor( 32 | [ 33 | math.log(0.8) + math.log(0.1) + math.log(0.5) + math.log(0.5), 34 | math.log(0.2) + math.log(0.5) + math.log(0.9), 35 | ] 36 | ) 37 | assert torch.allclose(expected, model(packed_sequences)) 38 | 39 | 40 | def test_sample(): 41 | torch.manual_seed(42) 42 | model = _get_default_model() 43 | n = 100000 44 | samples = model.sample(n, 3) 45 | assert math.isclose((samples[:, 0] == 0).sum() / n, 0.2, abs_tol=0.01) 46 | 47 | 48 | def test_stationary_distribution(): 49 | model = _get_default_model() 50 | tol = 1e-7 51 | sd = model.stationary_distribution(tol=tol) 52 | assert math.isclose(sd[0].item(), 1 / 6, abs_tol=tol) 53 | assert math.isclose(sd[1].item(), 5 / 6, abs_tol=tol) 54 | 55 | 56 | # ------------------------------------------------------------------------------------------------- 57 | 58 | 59 | def _get_default_model() -> MarkovChainModel: 60 | config = MarkovChainModelConfig(num_states=2) 61 | model = MarkovChainModel(config) 62 | 63 | model.initial_probs.copy_(torch.as_tensor([0.2, 0.8])) 64 | model.transition_probs.copy_(torch.as_tensor([[0.5, 0.5], [0.1, 0.9]])) 65 | return model 66 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=all 2 | from __future__ import annotations 3 | import datetime 4 | import os 5 | import sys 6 | from typing import Any 7 | 8 | filepath = os.path.abspath(os.path.dirname(__file__)) 9 | sys.path.insert(0, os.path.join(filepath, "..")) 10 | 11 | # ------------------------------------------------------------------------------------------------- 12 | # BASICS 13 | 14 | project = "PyCave" 15 | copyright = f"{datetime.datetime.now().year}, Oliver Borchert" 16 | 17 | # 
------------------------------------------------------------------------------------------------- 18 | # PLUGINS 19 | 20 | extensions = [ 21 | "sphinx.ext.autodoc", 22 | "sphinx.ext.autosummary", 23 | "sphinx.ext.intersphinx", 24 | "sphinx.ext.napoleon", 25 | "sphinx.ext.viewcode", 26 | "sphinx_autodoc_typehints", 27 | "sphinx_automodapi.smart_resolver", 28 | "sphinx_copybutton", 29 | ] 30 | if os.uname().machine != "arm64": 31 | extensions.append("sphinxcontrib.spelling") 32 | templates_path = ["_templates"] 33 | 34 | # ------------------------------------------------------------------------------------------------- 35 | # CONFIGURATION 36 | 37 | html_theme = "pydata_sphinx_theme" 38 | pygments_style = "lovelace" 39 | html_theme_options = { 40 | "show_prev_next": False, 41 | "github_url": "https://github.com/borchero/pycave", 42 | } 43 | html_logo = "_static/logo.svg" 44 | html_favicon = "_static/favicon.ico" 45 | html_permalinks = True 46 | 47 | autosummary_generate = True 48 | autosummary_imported_members = True 49 | autodoc_member_order = "groupwise" 50 | autodoc_type_aliases = { 51 | "CovarianceType": ":class:`~pycave.bayes.core.CovarianceType`", 52 | "SequenceData": ":class:`~pycave.data.SequenceData`", 53 | "TabularData": ":class:`~pycave.data.TabularData`", 54 | "GaussianMixtureInitStrategy": ":class:`~pycave.bayes.gmm.types.GaussianMixtureInitStrategy`", 55 | "KMeansInitStrategy": ":class:`~pycave.clustering.kmeans.types.KMeansInitStrategy`", 56 | } 57 | autoclass_content = "both" 58 | 59 | simplify_optional_unions = False 60 | 61 | spelling_lang = "en_US" 62 | spelling_word_list_filename = "spelling_wordlist.txt" 63 | 64 | intersphinx_mapping = { 65 | "python": ("https://docs.python.org/3", None), 66 | "torch": ("https://pytorch.org/docs/stable/", None), 67 | "numpy": ("https://numpy.org/doc/stable/", None), 68 | "pytorch_lightning": ("https://pytorch-lightning.readthedocs.io/en/stable/", None), 69 | } 70 | -------------------------------------------------------------------------------- /tests/bayes/markov_chain/test_markov_chain_estimator.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | import math 3 | from typing import Tuple 4 | import pytest 5 | import torch 6 | from pycave.bayes import MarkovChain 7 | 8 | 9 | def test_fit_automatic_config(): 10 | chain = MarkovChain() 11 | data = torch.randint(50, size=(100, 20)) 12 | chain.fit(data) 13 | assert chain.model_.config.num_states == 50 14 | 15 | 16 | @pytest.mark.flaky(max_runs=3, min_passes=1) 17 | def test_sample_and_fit(): 18 | chain = MarkovChain(2) 19 | initial_probs, transition_probs = _set_probs(chain) 20 | sample = chain.sample(1000000, 10) 21 | 22 | new = MarkovChain(2) 23 | new.fit(sample) 24 | 25 | assert torch.allclose(initial_probs, new.model_.initial_probs, atol=1e-3) 26 | assert torch.allclose(transition_probs, new.model_.transition_probs, atol=1e-3) 27 | 28 | 29 | def test_score(): 30 | chain = MarkovChain(2) 31 | test_data, expected = _set_sample_data(chain) 32 | actual = chain.score(test_data) 33 | assert math.isclose(actual, -expected.mean()) 34 | 35 | 36 | def test_score_samples(): 37 | chain = MarkovChain(2) 38 | test_data, expected = _set_sample_data(chain) 39 | actual = chain.score_samples(test_data) 40 | assert torch.allclose(actual, -expected) 41 | 42 | 43 | # ------------------------------------------------------------------------------------------------- 44 | 45 | 46 | def _set_probs(chain: MarkovChain) -> 
Tuple[torch.Tensor, torch.Tensor]: 47 | data = torch.randint(2, size=(100, 20)) 48 | chain.fit(data) 49 | 50 | initial_probs = torch.as_tensor([0.8, 0.2]) 51 | chain.model_.initial_probs.copy_(initial_probs) 52 | 53 | transition_probs = torch.as_tensor([[0.5, 0.5], [0.1, 0.9]]) 54 | chain.model_.transition_probs.copy_(transition_probs) 55 | 56 | return initial_probs, transition_probs 57 | 58 | 59 | def _set_sample_data(chain: MarkovChain) -> Tuple[torch.Tensor, torch.Tensor]: 60 | _set_probs(chain) 61 | 62 | test_data = torch.as_tensor( 63 | [ 64 | [1, 1, 0, 1], 65 | [0, 1, 0, 1], 66 | [0, 0, 1, 1], 67 | ] 68 | ) 69 | expected = torch.as_tensor( 70 | [ 71 | math.log(0.2) + math.log(0.9) + math.log(0.1) + math.log(0.5), 72 | math.log(0.8) + math.log(0.5) + math.log(0.1) + math.log(0.5), 73 | math.log(0.8) + math.log(0.5) + math.log(0.5) + math.log(0.9), 74 | ] 75 | ) 76 | return test_data, expected 77 | -------------------------------------------------------------------------------- /pycave/bayes/markov_chain/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, cast, Optional, Tuple 2 | import torch 3 | from torch.nn.utils.rnn import PackedSequence 4 | from torchmetrics import Metric 5 | 6 | 7 | class StateCountAggregator(Metric): 8 | """ 9 | The state count aggregator aggregates initial states and transitions between states. 10 | """ 11 | 12 | full_state_update = False 13 | 14 | def __init__( 15 | self, 16 | num_states: int, 17 | symmetric: bool, 18 | *, 19 | dist_sync_fn: Optional[Callable[[Any], Any]] = None, 20 | ): 21 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore 22 | 23 | self.num_states = num_states 24 | self.symmetric = symmetric 25 | 26 | self.initial_counts: torch.Tensor 27 | self.add_state("initial_counts", torch.zeros(num_states), dist_reduce_fx="sum") 28 | 29 | self.transition_counts: torch.Tensor 30 | self.add_state( 31 | "transition_counts", torch.zeros(num_states, num_states).view(-1), dist_reduce_fx="sum" 32 | ) 33 | 34 | def update(self, sequences: PackedSequence) -> None: 35 | batch_sizes = cast(torch.Tensor, sequences.batch_sizes) 36 | num_sequences = batch_sizes[0] 37 | data = cast(torch.Tensor, sequences.data) 38 | 39 | # First, we count the initial states 40 | initial_counts = torch.bincount(data[:num_sequences], minlength=self.num_states).float() 41 | self.initial_counts.add_(initial_counts) 42 | 43 | # Then, we count the transitions 44 | offset = 0 45 | for prev_size, size in zip(batch_sizes, batch_sizes[1:]): 46 | sources = data[offset : offset + size] 47 | targets = data[offset + prev_size : offset + prev_size + size] 48 | transitions = sources * self.num_states + targets 49 | values = torch.ones_like(transitions, dtype=torch.float) 50 | self.transition_counts.scatter_add_(0, transitions, values) 51 | offset += prev_size 52 | 53 | def compute(self) -> Tuple[torch.Tensor, torch.Tensor]: 54 | initial_probs = self.initial_counts / self.initial_counts.sum() 55 | transition_counts = self.transition_counts.view(self.num_states, self.num_states) 56 | 57 | if self.symmetric: 58 | self.transition_counts.add_(transition_counts.t()) 59 | transition_probs = transition_counts / transition_counts.sum(1, keepdim=True) 60 | 61 | return initial_probs, transition_probs 62 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | authors = ["Oliver 
Borchert "] 3 | classifiers = [ 4 | "Development Status :: 5 - Production/Stable", 5 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 6 | ] 7 | description = "Traditional Machine Learning Models in PyTorch." 8 | documentation = "https://pycave.borchero.com" 9 | license = "MIT" 10 | name = "pycave" 11 | readme = "README.md" 12 | repository = "https://github.com/borchero/pycave" 13 | version = "0.0.0" 14 | 15 | [tool.poetry.dependencies] 16 | lightkit = "^0.5.0" 17 | numpy = "^1.20.3" 18 | python = ">=3.8,<3.11" 19 | pytorch-lightning = "^1.6.0" 20 | torch = "^1.8.0" 21 | torchmetrics = ">=0.6,<0.12" 22 | 23 | [tool.poetry.group.pre-commit.dependencies] 24 | black = "^22.12.0" 25 | docformatter = "^1.5.0" 26 | isort = "^5.10.1" 27 | mypy = "^0.991" 28 | pylint = "^2.12.2" 29 | pyupgrade = "^3.3.1" 30 | 31 | [tool.poetry.group.docs.dependencies] 32 | Sphinx = "^5.0.0" 33 | pydata-sphinx-theme = "^0.7.2" 34 | scanpydoc = "^0.7.1" 35 | sphinx-autodoc-typehints = "^1.12.0" 36 | sphinx-automodapi = "^0.13" 37 | sphinx-copybutton = "^0.3.3" 38 | sphinxcontrib-spelling = "^7.2.1" 39 | 40 | [tool.poetry.group.testing.dependencies] 41 | flaky = "^3.7.0" 42 | pytest = "^6.2.4" 43 | pytest-benchmark = "^3.4.1" 44 | scikit-learn = "^0.24.2" 45 | 46 | [tool.poetry.group.dev.dependencies] 47 | jupyter = "^1.0.0" 48 | 49 | [build-system] 50 | build-backend = "poetry.core.masonry.api" 51 | requires = ["poetry-core>=1.0.0"] 52 | 53 | [tool.pylint.messages_control] 54 | disable = [ 55 | "arguments-differ", 56 | "duplicate-code", 57 | "missing-module-docstring", 58 | "invalid-name", 59 | "too-few-public-methods", 60 | "too-many-ancestors", 61 | "too-many-arguments", 62 | "too-many-branches", 63 | "too-many-locals", 64 | "too-many-instance-attributes", 65 | ] 66 | 67 | [tool.pylint.typecheck] 68 | generated-members = [ 69 | "torch.*", 70 | ] 71 | 72 | [tool.black] 73 | line-length = 99 74 | target-version = ["py38", "py39", "py310"] 75 | 76 | [tool.isort] 77 | force_alphabetical_sort_within_sections = true 78 | include_trailing_comma = true 79 | known_first_party = "pycave,tests" 80 | line_length = 99 81 | lines_between_sections = 0 82 | profile = "black" 83 | skip_gitignore = true 84 | 85 | [tool.docformatter] 86 | make-summary-multi-line = true 87 | pre-summary-newline = true 88 | recursive = true 89 | wrap-descriptions = 99 90 | wrap-summaries = 99 91 | 92 | [tool.pytest.ini_options] 93 | filterwarnings = [ 94 | "ignore:.*Create unlinked descriptors is going to go away.*:DeprecationWarning", 95 | "ignore:.*this fit will run with no optimizer.*", 96 | "ignore:.*Consider increasing the value of the `num_workers` argument.*", 97 | ] 98 | -------------------------------------------------------------------------------- /pycave/clustering/kmeans/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Tuple 3 | import torch 4 | from lightkit.nn import Configurable 5 | from torch import jit, nn 6 | 7 | 8 | @dataclass 9 | class KMeansModelConfig: 10 | """ 11 | Configuration class for a K-Means model. 12 | 13 | See also: 14 | :class:`KMeansModel` 15 | """ 16 | 17 | #: The number of clusters. 18 | num_clusters: int 19 | #: The number of features of each cluster. 20 | num_features: int 21 | 22 | 23 | class KMeansModel(Configurable[KMeansModelConfig], nn.Module): 24 | """ 25 | PyTorch module for the K-Means model. 26 | 27 | The centroids managed by this model are non-trainable parameters. 
28 | """ 29 | 30 | def __init__(self, config: KMeansModelConfig): 31 | """ 32 | Args: 33 | config: The configuration to use for initializing the module's buffers. 34 | """ 35 | super().__init__(config) 36 | 37 | #: The centers of all clusters, buffer of shape ``[num_clusters, num_features].`` 38 | self.centroids: torch.Tensor 39 | self.register_buffer("centroids", torch.empty(config.num_clusters, config.num_features)) 40 | 41 | self.reset_parameters() 42 | 43 | @jit.unused 44 | def reset_parameters(self) -> None: 45 | """ 46 | Resets the parameters of the KMeans model. 47 | 48 | It samples all cluster centers from a standard Normal. 49 | """ 50 | nn.init.normal_(self.centroids) 51 | 52 | def forward(self, data: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 53 | """ 54 | Computes the distance of each datapoint to each centroid as well as the "inertia", the 55 | squared distance of each datapoint to its closest centroid. 56 | 57 | Args: 58 | data: A tensor of shape ``[num_datapoints, num_features]`` for which to compute the 59 | distances and inertia. 60 | 61 | Returns: 62 | - A tensor of shape ``[num_datapoints, num_centroids]`` with the distance from each 63 | datapoint to each centroid. 64 | - A tensor of shape ``[num_datapoints]`` with the assignments, i.e. the indices of 65 | each datapoint's closest centroid. 66 | - A tensor of shape ``[num_datapoints]`` with the inertia (squared distance to the 67 | closest centroid) of each datapoint. 68 | """ 69 | distances = torch.cdist(data, self.centroids) 70 | assignments = distances.min(1, keepdim=True).indices 71 | inertias = distances.gather(1, assignments).square() 72 | return distances, assignments.squeeze(1), inertias.squeeze(1) 73 | -------------------------------------------------------------------------------- /tests/clustering/kmeans/test_kmeans_estimator.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | import math 3 | from typing import Optional 4 | import pytest 5 | import torch 6 | from sklearn.cluster import KMeans as SklearnKMeans # type: ignore 7 | from pycave.clustering import KMeans 8 | from tests._data.gmm import sample_gmm 9 | 10 | 11 | def test_fit_automatic_config(): 12 | estimator = KMeans(4) 13 | data = torch.cat([torch.randn(1000, 3) * 0.1 - 100, torch.randn(1000, 3) * 0.1 + 100]) 14 | estimator.fit(data) 15 | assert estimator.model_.config.num_clusters == 4 16 | assert estimator.model_.config.num_features == 3 17 | 18 | 19 | def test_fit_num_iter(): 20 | # The k-means++ iterations should find the centroids. Afterwards, it should only take a single 21 | # epoch until the centroids do not change anymore. 
22 | data = torch.cat([torch.randn(1000, 4) * 0.1 - 100, torch.randn(1000, 4) * 0.1 + 100]) 23 | 24 | estimator = KMeans(2) 25 | estimator.fit(data) 26 | 27 | assert estimator.num_iter_ == 1 28 | 29 | 30 | @pytest.mark.flaky(max_runs=2, min_passes=1) 31 | @pytest.mark.parametrize( 32 | ("num_epochs", "converged"), 33 | [(100, True), (1, False)], 34 | ) 35 | def test_fit_converged(num_epochs: int, converged: bool): 36 | data, _ = sample_gmm( 37 | num_datapoints=10000, 38 | num_features=8, 39 | num_components=4, 40 | covariance_type="spherical", 41 | ) 42 | 43 | estimator = KMeans(4, trainer_params=dict(max_epochs=num_epochs)) 44 | estimator.fit(data) 45 | 46 | assert estimator.converged_ == converged 47 | 48 | 49 | @pytest.mark.flaky(max_runs=5, min_passes=1) 50 | @pytest.mark.parametrize( 51 | ("num_datapoints", "batch_size", "num_features", "num_centroids"), 52 | [ 53 | (10000, None, 8, 4), 54 | (10000, 1000, 8, 4), 55 | ], 56 | ) 57 | def test_fit_inertia( 58 | num_datapoints: int, 59 | batch_size: Optional[int], 60 | num_features: int, 61 | num_centroids: int, 62 | ): 63 | data, _ = sample_gmm( 64 | num_datapoints=num_datapoints, 65 | num_features=num_features, 66 | num_components=num_centroids, 67 | covariance_type="spherical", 68 | ) 69 | 70 | # Ours 71 | estimator = KMeans( 72 | num_centroids, 73 | batch_size=batch_size, 74 | trainer_params=dict(precision=64), 75 | ) 76 | ours_inertia = float("inf") 77 | for _ in range(10): 78 | ours_inertia = min(ours_inertia, estimator.fit(data).score(data)) 79 | 80 | # Sklearn 81 | gmm = SklearnKMeans(num_centroids, n_init=10) 82 | sklearn_inertia = gmm.fit(data.numpy()).score(data.numpy()) 83 | 84 | assert math.isclose(ours_inertia, -sklearn_inertia / data.size(0), rel_tol=0.01, abs_tol=0.01) 85 | -------------------------------------------------------------------------------- /tests/bayes/gmm/test_gmm_estimator.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | import math 3 | from typing import Optional 4 | import pytest 5 | import torch 6 | from sklearn.mixture import GaussianMixture as SklearnGaussianMixture # type: ignore 7 | from pycave.bayes import GaussianMixture 8 | from pycave.bayes.core import CovarianceType 9 | from tests._data.gmm import sample_gmm 10 | 11 | 12 | def test_fit_model_config(): 13 | estimator = GaussianMixture() 14 | data = torch.randn(1000, 4) 15 | estimator.fit(data) 16 | 17 | assert estimator.model_.config.num_components == 1 18 | assert estimator.model_.config.num_features == 4 19 | 20 | 21 | @pytest.mark.parametrize("batch_size", [2, None]) 22 | def test_fit_num_iter(batch_size: Optional[int]): 23 | # For the following data, K-means will find centroids [0.5, 3.5]. The estimator first computes 24 | # the NLL (first iteration), afterwards there is no improvement in the NLL (second iteration).
25 | data = torch.as_tensor([[0.0], [1.0], [3.0], [4.0]]) 26 | estimator = GaussianMixture( 27 | 2, 28 | batch_size=batch_size, 29 | trainer_params=dict(precision=64), 30 | ) 31 | estimator.fit(data) 32 | 33 | assert estimator.num_iter_ == 2 34 | 35 | 36 | @pytest.mark.flaky(max_runs=3, min_passes=1) 37 | @pytest.mark.parametrize( 38 | ("batch_size", "max_epochs", "converged"), 39 | [(2, 1, False), (2, 3, True), (None, 1, False), (None, 3, True)], 40 | ) 41 | def test_fit_converged(batch_size: Optional[int], max_epochs: int, converged: bool): 42 | data = torch.as_tensor([[0.0], [1.0], [3.0], [4.0]]) 43 | 44 | estimator = GaussianMixture( 45 | 2, 46 | batch_size=batch_size, 47 | trainer_params=dict(precision=64, max_epochs=max_epochs), 48 | ) 49 | estimator.fit(data) 50 | assert estimator.converged_ == converged 51 | 52 | 53 | @pytest.mark.flaky(max_runs=25, min_passes=1) 54 | @pytest.mark.parametrize( 55 | ("num_datapoints", "batch_size", "num_features", "num_components", "covariance_type"), 56 | [ 57 | (10000, 10000, 4, 4, "spherical"), 58 | (10000, 10000, 4, 4, "diag"), 59 | (10000, 10000, 4, 4, "tied"), 60 | (10000, 10000, 4, 4, "full"), 61 | (10000, 1000, 4, 4, "spherical"), 62 | (10000, 1000, 4, 4, "diag"), 63 | (10000, 1000, 4, 4, "tied"), 64 | (10000, 1000, 4, 4, "full"), 65 | ], 66 | ) 67 | def test_fit_nll( 68 | num_datapoints: int, 69 | batch_size: int, 70 | num_features: int, 71 | num_components: int, 72 | covariance_type: CovarianceType, 73 | ): 74 | data, _ = sample_gmm( 75 | num_datapoints=num_datapoints, 76 | num_features=num_features, 77 | num_components=num_components, 78 | covariance_type=covariance_type, 79 | ) 80 | 81 | # Ours 82 | estimator = GaussianMixture( 83 | num_components, 84 | covariance_type=covariance_type, 85 | batch_size=batch_size, 86 | trainer_params=dict(precision=64), 87 | ) 88 | ours_nll = estimator.fit(data).score(data) 89 | 90 | # Sklearn 91 | gmm = SklearnGaussianMixture(num_components, covariance_type=covariance_type) 92 | sklearn_nll = gmm.fit(data.numpy()).score(data.numpy()) 93 | 94 | assert math.isclose(ours_nll, -sklearn_nll, rel_tol=0.01, abs_tol=0.01) 95 | -------------------------------------------------------------------------------- /pycave/bayes/core/_jit.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | import math 3 | import torch 4 | 5 | 6 | def jit_log_normal( 7 | x: torch.Tensor, 8 | means: torch.Tensor, 9 | precisions_cholesky: torch.Tensor, 10 | covariance_type: str, 11 | ) -> torch.Tensor: 12 | if covariance_type == "full": 13 | # Precision shape is `[num_components, dim, dim]`. 14 | log_prob = x.new_empty((x.size(0), means.size(0))) 15 | # We loop here to not blow up the size of intermediate matrices 16 | for k, (mu, prec_chol) in enumerate(zip(means, precisions_cholesky)): 17 | inner = x.matmul(prec_chol) - mu.matmul(prec_chol) 18 | log_prob[:, k] = inner.square().sum(1) 19 | elif covariance_type == "tied": 20 | # Precision shape is `[dim, dim]`. 21 | a = x.matmul(precisions_cholesky) # [N, D] 22 | b = means.matmul(precisions_cholesky) # [K, D] 23 | log_prob = (a.unsqueeze(1) - b).square().sum(-1) 24 | else: 25 | precisions = precisions_cholesky.square() 26 | if covariance_type == "diag": 27 | # Precision shape is `[num_components, dim]`. 
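# The quadratic form (x - mu_k)^T P_k (x - mu_k) is expanded into
# x^T P_k x - 2 x^T P_k mu_k + mu_k^T P_k mu_k so that the three terms below can be
# evaluated for all datapoints and components at once via batched matrix products.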
28 | x_prob = torch.matmul(x * x, precisions.t()) 29 | m_prob = torch.einsum("ij,ij,ij->i", means, means, precisions) 30 | xm_prob = torch.matmul(x, (means * precisions).t()) 31 | else: # covariance_type == "spherical" 32 | # Precision shape is `[num_components]` 33 | x_prob = torch.ger(torch.einsum("ij,ij->i", x, x), precisions) 34 | m_prob = torch.einsum("ij,ij->i", means, means) * precisions 35 | xm_prob = torch.matmul(x, means.t() * precisions) 36 | 37 | log_prob = x_prob - 2 * xm_prob + m_prob 38 | 39 | num_features = x.size(1) 40 | logdet = _cholesky_logdet(num_features, precisions_cholesky, covariance_type) 41 | constant = math.log(2 * math.pi) * num_features 42 | return logdet - 0.5 * (constant + log_prob) 43 | 44 | 45 | def _cholesky_logdet( 46 | num_features: int, 47 | precisions_cholesky: torch.Tensor, 48 | covariance_type: str, 49 | ) -> torch.Tensor: 50 | if covariance_type == "full": 51 | return precisions_cholesky.diagonal(dim1=-2, dim2=-1).log().sum(-1) 52 | if covariance_type == "tied": 53 | return precisions_cholesky.diagonal().log().sum(-1) 54 | if covariance_type == "diag": 55 | return precisions_cholesky.log().sum(1) 56 | # covariance_type == "spherical" 57 | return precisions_cholesky.log() * num_features 58 | 59 | 60 | # ------------------------------------------------------------------------------------------------- 61 | 62 | 63 | def jit_sample_normal( 64 | num: int, 65 | mean: torch.Tensor, 66 | cholesky_precisions: torch.Tensor, 67 | covariance_type: str, 68 | ) -> torch.Tensor: 69 | samples = torch.randn(num, mean.size(0), dtype=mean.dtype, device=mean.device) 70 | chol_covariance = _cholesky_covariance(cholesky_precisions, covariance_type) 71 | 72 | if covariance_type in ("tied", "full"): 73 | scale = chol_covariance.matmul(samples.unsqueeze(-1)).squeeze(-1) 74 | else: 75 | scale = chol_covariance * samples 76 | 77 | return mean + scale 78 | 79 | 80 | def _cholesky_covariance(chol_precision: torch.Tensor, covariance_type: str) -> torch.Tensor: 81 | # For complex covariance types, invert the 82 | if covariance_type in ("tied", "full"): 83 | num_features = chol_precision.size(-1) 84 | target = torch.eye(num_features, dtype=chol_precision.dtype, device=chol_precision.device) 85 | return torch.linalg.solve_triangular(chol_precision, target, upper=True).t() 86 | 87 | # Simple covariance type 88 | return chol_precision.reciprocal() 89 | -------------------------------------------------------------------------------- /tests/bayes/gmm/test_gmm_metrics.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=protected-access,missing-function-docstring 2 | from typing import Any, Callable 3 | import numpy as np 4 | import sklearn.mixture._gaussian_mixture as skgmm # type: ignore 5 | import torch 6 | from pycave.bayes.core import CovarianceType 7 | from pycave.bayes.gmm.metrics import CovarianceAggregator, MeanAggregator, PriorAggregator 8 | 9 | 10 | def test_prior_aggregator(): 11 | aggregator = PriorAggregator(3) 12 | aggregator.reset() 13 | 14 | # Step 1: single batch 15 | responsibilities1 = torch.tensor([[0.3, 0.3, 0.4], [0.8, 0.1, 0.1], [0.4, 0.5, 0.1]]) 16 | actual = aggregator.forward(responsibilities1) 17 | expected = torch.tensor([0.5, 0.3, 0.2]) 18 | assert torch.allclose(actual, expected) 19 | 20 | # Step 2: batch aggregation 21 | responsibilities2 = torch.tensor([[0.7, 0.2, 0.1], [0.5, 0.4, 0.1]]) 22 | aggregator.update(responsibilities2) 23 | actual = aggregator.compute() 24 | expected = torch.tensor([0.54, 
0.3, 0.16]) 25 | assert torch.allclose(actual, expected) 26 | 27 | 28 | def test_mean_aggregator(): 29 | aggregator = MeanAggregator(3, 2) 30 | aggregator.reset() 31 | 32 | # Step 1: single batch 33 | data1 = torch.tensor([[5.0, 2.0], [3.0, 4.0], [1.0, 0.0]]) 34 | responsibilities1 = torch.tensor([[0.3, 0.3, 0.4], [0.8, 0.1, 0.1], [0.4, 0.5, 0.1]]) 35 | actual = aggregator.forward(data1, responsibilities1) 36 | expected = torch.tensor([[2.8667, 2.5333], [2.5556, 1.1111], [4.0, 2.0]]) 37 | assert torch.allclose(actual, expected, atol=1e-4) 38 | 39 | # Step 2: batch aggregation 40 | data2 = torch.tensor([[8.0, 2.5], [1.5, 4.0]]) 41 | responsibilities2 = torch.tensor([[0.7, 0.2, 0.1], [0.5, 0.4, 0.1]]) 42 | aggregator.update(data2, responsibilities2) 43 | actual = aggregator.compute() 44 | expected = torch.tensor([[3.9444, 2.7963], [3.0, 2.0667], [4.1875, 2.3125]]) 45 | assert torch.allclose(actual, expected, atol=1e-4) 46 | 47 | 48 | def test_covariance_aggregator_spherical(): 49 | _test_covariance("spherical", skgmm._estimate_gaussian_covariances_spherical) # type: ignore 50 | 51 | 52 | def test_covariance_aggregator_diag(): 53 | _test_covariance("diag", skgmm._estimate_gaussian_covariances_diag) # type: ignore 54 | 55 | 56 | def test_covariance_aggregator_tied(): 57 | _test_covariance("tied", skgmm._estimate_gaussian_covariances_tied) # type: ignore 58 | 59 | 60 | def test_covariance_aggregator_full(): 61 | _test_covariance("full", skgmm._estimate_gaussian_covariances_full) # type: ignore 62 | 63 | 64 | def _test_covariance( 65 | covariance_type: CovarianceType, 66 | sk_aggregator: Callable[[Any, Any, Any, Any, Any], Any], 67 | ): 68 | reg = 1e-5 69 | aggregator = CovarianceAggregator(3, 2, covariance_type, reg=reg) 70 | aggregator.reset() 71 | means = torch.tensor([[3.0, 2.5], [2.5, 1.0], [4.0, 2.0]]) 72 | 73 | # Step 1: single batch 74 | data1 = torch.tensor([[5.0, 2.0], [3.0, 4.0], [1.0, 0.0]]) 75 | responsibilities1 = torch.tensor([[0.3, 0.3, 0.4], [0.8, 0.1, 0.1], [0.4, 0.5, 0.1]]) 76 | actual = aggregator.forward(data1, responsibilities1, means) 77 | expected = sk_aggregator( # type: ignore 78 | responsibilities1.numpy(), 79 | data1.numpy(), 80 | responsibilities1.sum(0).numpy(), 81 | means.numpy(), 82 | reg, 83 | ).astype(np.float32) 84 | assert torch.allclose(actual, torch.from_numpy(expected)) 85 | 86 | # Step 2: batch aggregation 87 | data2 = torch.tensor([[8.0, 2.5], [1.5, 4.0]]) 88 | responsibilities2 = torch.tensor([[0.7, 0.2, 0.1], [0.5, 0.4, 0.1]]) 89 | aggregator.update(data2, responsibilities2, means) 90 | actual = aggregator.compute() 91 | expected = sk_aggregator( # type: ignore 92 | torch.cat([responsibilities1, responsibilities2]).numpy(), 93 | torch.cat([data1, data2]).numpy(), 94 | (responsibilities1.sum(0) + responsibilities2.sum(0)).numpy(), 95 | means.numpy(), 96 | reg, 97 | ).astype(np.float32) 98 | assert torch.allclose(actual, torch.from_numpy(expected)) 99 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | PyCave Documentation 2 | ==================== 3 | 4 | PyCave allows you to run traditional machine learning models on CPU, GPU, and even on multiple nodes. All models are implemented in `PyTorch `_ and provide an ``Estimator`` API that is fully compatible with `scikit-learn `_. 5 | 6 | .. image:: https://img.shields.io/pypi/v/pycave?label=version 7 | .. 
image:: https://img.shields.io/pypi/l/pycave 8 | 9 | 10 | Features 11 | -------- 12 | 13 | - Support for GPU and multi-node training by implementing models in PyTorch and relying on `PyTorch Lightning `_ 14 | - Mini-batch training for all models such that they can be used on huge datasets 15 | - Well-structured implementation of models 16 | 17 | - High-level ``Estimator`` API allows for easy usage such that models feel and behave like in 18 | scikit-learn 19 | - Medium-level ``LightningModule`` implements the training algorithm 20 | - Low-level PyTorch ``Module`` manages the model parameters 21 | 22 | 23 | Installation 24 | ------------ 25 | 26 | PyCave is available via ``pip``: 27 | 28 | .. code-block:: bash 29 | 30 | pip install pycave 31 | 32 | If you are using `Poetry `_: 33 | 34 | .. code-block:: bash 35 | 36 | poetry add pycave 37 | 38 | 39 | Usage 40 | ----- 41 | 42 | If you've ever used scikit-learn, you'll feel right at home when using PyCave. First, let's create 43 | some artificial data to work with: 44 | 45 | .. code-block:: python 46 | 47 | import torch 48 | 49 | X = torch.cat([ 50 | torch.randn(10000, 8) - 5, 51 | torch.randn(10000, 8), 52 | torch.randn(10000, 8) + 5, 53 | ]) 54 | 55 | This dataset consists of three clusters with 8-dimensional datapoints. If you want to fit a K-Means 56 | model to find the clusters' centroids, it's as easy as: 57 | 58 | .. code-block:: python 59 | 60 | from pycave.clustering import KMeans 61 | 62 | estimator = KMeans(3) 63 | estimator.fit(X) 64 | 65 | # Once the estimator is fitted, it provides various properties. One of them is 66 | # the `model_` property which yields the PyTorch module with the fitted parameters. 67 | print("Centroids are:") 68 | print(estimator.model_.centroids) 69 | 70 | Due to the high-level estimator API, the usage for all machine learning models is similar. The API 71 | documentation provides more detailed information about parameters that can be passed to estimators 72 | and which methods are available. 73 | 74 | GPU and Multi-Node training 75 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 76 | 77 | For GPU- and multi-node training, PyCave leverages PyTorch Lightning. The hardware that training 78 | runs on is determined by the :class:`pytorch_lightning.trainer.Trainer` class. Its 79 | :meth:`~pytorch_lightning.trainer.Trainer.__init__` method provides various configuration options. 80 | 81 | If you want to run K-Means with a GPU, you can pass the options ``accelerator='gpu'`` and ``devices=1`` to the estimator's 82 | initializer: 83 | 84 | .. code-block:: python 85 | 86 | estimator = KMeans(3, trainer_params=dict(accelerator='gpu', devices=1)) 87 | 88 | Similarly, if you want to train on 4 nodes simultaneously where each node has one GPU available, 89 | you can specify this as follows: 90 | 91 | .. code-block:: python 92 | 93 | estimator = KMeans(3, trainer_params=dict(num_nodes=4, accelerator='gpu', devices=1)) 94 | 95 | In fact, **you do not need to change anything else in your code**. 96 | 97 | 98 | Implemented Models 99 | ^^^^^^^^^^^^^^^^^^ 100 | 101 | Currently, PyCave implements three different models. Some of these models are also available in 102 | scikit-learn. In this case, we benchmark our implementation against theirs (see 103 | :doc:`here `). 104 | 105 | .. currentmodule:: pycave 106 | 107 | .. autosummary:: 108 | :nosignatures: 109 | 110 | ~bayes.GaussianMixture 111 | ~bayes.MarkovChain 112 | ~clustering.KMeans 113 | 114 | Reference 115 | --------- 116 | 117 | .. 
toctree:: 118 | :maxdepth: 2 119 | 120 | sites/benchmark 121 | sites/api 122 | 123 | Index 124 | ^^^^^ 125 | 126 | - :ref:`genindex` 127 | -------------------------------------------------------------------------------- /tests/clustering/kmeans/benchmark_kmeans_estimator.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | from typing import Optional 3 | import pytest 4 | import pytorch_lightning as pl 5 | import torch 6 | from pytest_benchmark.fixture import BenchmarkFixture # type: ignore 7 | from sklearn.cluster import KMeans as SklearnKMeans # type: ignore 8 | from pycave.clustering import KMeans 9 | from pycave.clustering.kmeans.types import KMeansInitStrategy 10 | from tests._data.gmm import sample_gmm 11 | 12 | 13 | @pytest.mark.parametrize( 14 | ("num_datapoints", "num_features", "num_centroids", "init_strategy"), 15 | [ 16 | (10000, 8, 4, "k-means++"), 17 | (100000, 32, 16, "k-means++"), 18 | (1000000, 64, 64, "k-means++"), 19 | (10000000, 128, 128, "k-means++"), 20 | (10000, 8, 4, "random"), 21 | (100000, 32, 16, "random"), 22 | (1000000, 64, 64, "random"), 23 | (10000000, 128, 128, "random"), 24 | ], 25 | ) 26 | def test_sklearn( 27 | benchmark: BenchmarkFixture, 28 | num_datapoints: int, 29 | num_features: int, 30 | num_centroids: int, 31 | init_strategy: str, 32 | ): 33 | pl.seed_everything(0) 34 | data, _ = sample_gmm(num_datapoints, num_features, num_centroids, "spherical") 35 | 36 | estimator = SklearnKMeans( 37 | num_centroids, 38 | algorithm="full", 39 | n_init=1, 40 | max_iter=100, 41 | tol=0, 42 | init=init_strategy, 43 | ) 44 | benchmark(estimator.fit, data.numpy()) 45 | 46 | 47 | @pytest.mark.parametrize( 48 | ("num_datapoints", "batch_size", "num_features", "num_centroids", "init_strategy"), 49 | [ 50 | (10000, None, 8, 4, "kmeans++"), 51 | (10000, 1000, 8, 4, "kmeans++"), 52 | (100000, None, 32, 16, "kmeans++"), 53 | (100000, 10000, 32, 16, "kmeans++"), 54 | (1000000, None, 64, 64, "kmeans++"), 55 | (1000000, 100000, 64, 64, "kmeans++"), 56 | (10000, None, 8, 4, "random"), 57 | (10000, 1000, 8, 4, "random"), 58 | (100000, None, 32, 16, "random"), 59 | (100000, 10000, 32, 16, "random"), 60 | (1000000, None, 64, 64, "random"), 61 | (1000000, 100000, 64, 64, "random"), 62 | ], 63 | ) 64 | def test_pycave( 65 | benchmark: BenchmarkFixture, 66 | num_datapoints: int, 67 | batch_size: Optional[int], 68 | num_features: int, 69 | num_centroids: int, 70 | init_strategy: KMeansInitStrategy, 71 | ): 72 | pl.seed_everything(0) 73 | data, _ = sample_gmm(num_datapoints, num_features, num_centroids, "spherical") 74 | 75 | estimator = KMeans( 76 | num_centroids, 77 | init_strategy=init_strategy, 78 | batch_size=batch_size, 79 | trainer_params=dict(max_epochs=100), 80 | ) 81 | benchmark(estimator.fit, data) 82 | 83 | 84 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") 85 | @pytest.mark.parametrize( 86 | ("num_datapoints", "batch_size", "num_features", "num_centroids", "init_strategy"), 87 | [ 88 | (10000, None, 8, 4, "kmeans++"), 89 | (10000, 1000, 8, 4, "kmeans++"), 90 | (100000, None, 32, 16, "kmeans++"), 91 | (100000, 10000, 32, 16, "kmeans++"), 92 | (1000000, None, 64, 64, "kmeans++"), 93 | (1000000, 100000, 64, 64, "kmeans++"), 94 | (10000000, 1000000, 128, 128, "kmeans++"), 95 | (10000, None, 8, 4, "random"), 96 | (10000, 1000, 8, 4, "random"), 97 | (100000, None, 32, 16, "random"), 98 | (100000, 10000, 32, 16, "random"), 99 | (1000000, None, 64, 64, 
"random"), 100 | (1000000, 100000, 64, 64, "random"), 101 | (10000000, 1000000, 128, 128, "random"), 102 | ], 103 | ) 104 | def test_pycave_gpu( 105 | benchmark: BenchmarkFixture, 106 | num_datapoints: int, 107 | batch_size: Optional[int], 108 | num_features: int, 109 | num_centroids: int, 110 | init_strategy: KMeansInitStrategy, 111 | ): 112 | # Initialize GPU 113 | torch.empty(1, device="cuda:0") 114 | 115 | pl.seed_everything(0) 116 | data, _ = sample_gmm(num_datapoints, num_features, num_centroids, "spherical") 117 | 118 | estimator = KMeans( 119 | num_centroids, 120 | init_strategy=init_strategy, 121 | batch_size=batch_size, 122 | convergence_tolerance=0, 123 | trainer_params=dict(max_epochs=100, accelerator="gpu", devices=1), 124 | ) 125 | benchmark(estimator.fit, data) 126 | -------------------------------------------------------------------------------- /tests/bayes/gmm/benchmark_gmm_estimator.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | from typing import Optional 3 | import pytest 4 | import pytorch_lightning as pl 5 | import torch 6 | from pytest_benchmark.fixture import BenchmarkFixture # type: ignore 7 | from sklearn.mixture import GaussianMixture as SklearnGaussianMixture # type: ignore 8 | from pycave.bayes import GaussianMixture 9 | from pycave.bayes.core.types import CovarianceType 10 | from tests._data.gmm import sample_gmm 11 | 12 | 13 | @pytest.mark.parametrize( 14 | ("num_datapoints", "num_features", "num_components", "covariance_type"), 15 | [ 16 | (10000, 8, 4, "diag"), 17 | (10000, 8, 4, "tied"), 18 | (10000, 8, 4, "full"), 19 | (100000, 32, 16, "diag"), 20 | (100000, 32, 16, "tied"), 21 | (100000, 32, 16, "full"), 22 | (1000000, 64, 64, "diag"), 23 | ], 24 | ) 25 | def test_sklearn( 26 | benchmark: BenchmarkFixture, 27 | num_datapoints: int, 28 | num_features: int, 29 | num_components: int, 30 | covariance_type: CovarianceType, 31 | ): 32 | pl.seed_everything(0) 33 | data, means = sample_gmm(num_datapoints, num_features, num_components, covariance_type) 34 | 35 | estimator = SklearnGaussianMixture( 36 | num_components, 37 | covariance_type=covariance_type, 38 | tol=0, 39 | n_init=1, 40 | max_iter=100, 41 | reg_covar=1e-3, 42 | init_params="random", 43 | means_init=means.numpy(), 44 | ) 45 | benchmark(estimator.fit, data.numpy()) 46 | 47 | 48 | @pytest.mark.parametrize( 49 | ("num_datapoints", "num_features", "num_components", "covariance_type", "batch_size"), 50 | [ 51 | (10000, 8, 4, "diag", None), 52 | (10000, 8, 4, "tied", None), 53 | (10000, 8, 4, "full", None), 54 | (100000, 32, 16, "diag", None), 55 | (100000, 32, 16, "tied", None), 56 | (100000, 32, 16, "full", None), 57 | (1000000, 64, 64, "diag", None), 58 | (10000, 8, 4, "diag", 1000), 59 | (10000, 8, 4, "tied", 1000), 60 | (10000, 8, 4, "full", 1000), 61 | (100000, 32, 16, "diag", 10000), 62 | (100000, 32, 16, "tied", 10000), 63 | (100000, 32, 16, "full", 10000), 64 | (1000000, 64, 64, "diag", 100000), 65 | ], 66 | ) 67 | def test_pycave( 68 | benchmark: BenchmarkFixture, 69 | num_datapoints: int, 70 | num_features: int, 71 | num_components: int, 72 | covariance_type: CovarianceType, 73 | batch_size: Optional[int], 74 | ): 75 | pl.seed_everything(0) 76 | data, means = sample_gmm(num_datapoints, num_features, num_components, covariance_type) 77 | 78 | estimator = GaussianMixture( 79 | num_components, 80 | covariance_type=covariance_type, 81 | init_means=means, 82 | convergence_tolerance=0, 83 | 
covariance_regularization=1e-3, 84 | batch_size=batch_size, 85 | trainer_params=dict(max_epochs=100), 86 | ) 87 | benchmark(estimator.fit, data) 88 | 89 | 90 | @pytest.mark.parametrize( 91 | ("num_datapoints", "num_features", "num_components", "covariance_type", "batch_size"), 92 | [ 93 | (10000, 8, 4, "diag", None), 94 | (10000, 8, 4, "tied", None), 95 | (10000, 8, 4, "full", None), 96 | (100000, 32, 16, "diag", None), 97 | (100000, 32, 16, "tied", None), 98 | (100000, 32, 16, "full", None), 99 | (1000000, 64, 64, "diag", None), 100 | (10000, 8, 4, "diag", 1000), 101 | (10000, 8, 4, "tied", 1000), 102 | (10000, 8, 4, "full", 1000), 103 | (100000, 32, 16, "diag", 10000), 104 | (100000, 32, 16, "tied", 10000), 105 | (100000, 32, 16, "full", 10000), 106 | (1000000, 64, 64, "diag", 100000), 107 | (1000000, 64, 64, "tied", 100000), 108 | ], 109 | ) 110 | def test_pycave_gpu( 111 | benchmark: BenchmarkFixture, 112 | num_datapoints: int, 113 | num_features: int, 114 | num_components: int, 115 | covariance_type: CovarianceType, 116 | batch_size: Optional[int], 117 | ): 118 | # Initialize GPU 119 | torch.empty(1, device="cuda:0") 120 | 121 | pl.seed_everything(0) 122 | data, means = sample_gmm(num_datapoints, num_features, num_components, covariance_type) 123 | 124 | estimator = GaussianMixture( 125 | num_components, 126 | covariance_type=covariance_type, 127 | init_means=means, 128 | convergence_tolerance=0, 129 | covariance_regularization=1e-3, 130 | batch_size=batch_size, 131 | trainer_params=dict(max_epochs=100, accelerator="gpu", devices=1), 132 | ) 133 | benchmark(estimator.fit, data) 134 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyCave 2 | 3 | ![PyPi](https://img.shields.io/pypi/v/pycave?label=version) 4 | ![License](https://img.shields.io/pypi/l/pycave) 5 | 6 | PyCave allows you to run traditional machine learning models on CPU, GPU, and even on multiple 7 | nodes. All models are implemented in [PyTorch](https://pytorch.org/) and provide an `Estimator` API 8 | that is fully compatible with [scikit-learn](https://scikit-learn.org/stable/). 9 | 10 | For Gaussian mixture models, PyCave allows for 100x speed ups when using a GPU and enables training 11 | on markedly larger datasets via mini-batch training. The full suite of benchmarks run to compare 12 | PyCave models against scikit-learn models is available on the 13 | [documentation website](https://pycave.borchero.com/sites/benchmark.html). 14 | 15 | _PyCave version 3 is a complete rewrite of PyCave which is tested much more rigorously, depends on 16 | well-maintained libraries and is tuned for better performance. 
While you are, thus, highly 17 | encouraged to upgrade, refer to [pycave-v2.borchero.com](https://pycave-v2.borchero.com) for 18 | documentation on PyCave 2._ 19 | 20 | ## Features 21 | 22 | - Support for GPU and multi-node training by implementing models in PyTorch and relying on 23 | [PyTorch Lightning](https://www.pytorchlightning.ai/) 24 | - Mini-batch training for all models such that they can be used on huge datasets 25 | - Well-structured implementation of models 26 | 27 | - High-level `Estimator` API allows for easy usage such that models feel and behave like in 28 | scikit-learn 29 | - Medium-level `LightningModule` implements the training algorithm 30 | - Low-level PyTorch `Module` manages the model parameters 31 | 32 | ## Installation 33 | 34 | PyCave is available via `pip`: 35 | 36 | ```bash 37 | pip install pycave 38 | ``` 39 | 40 | If you are using [Poetry](https://python-poetry.org/): 41 | 42 | ```bash 43 | poetry add pycave 44 | ``` 45 | 46 | ## Usage 47 | 48 | If you've ever used scikit-learn, you'll feel right at home when using PyCave. First, let's create 49 | some artificial data to work with: 50 | 51 | ```python 52 | import torch 53 | 54 | X = torch.cat([ 55 | torch.randn(10000, 8) - 5, 56 | torch.randn(10000, 8), 57 | torch.randn(10000, 8) + 5, 58 | ]) 59 | ``` 60 | 61 | This dataset consists of three clusters with 8-dimensional datapoints. If you want to fit a K-Means 62 | model to find the clusters' centroids, it's as easy as: 63 | 64 | ```python 65 | from pycave.clustering import KMeans 66 | 67 | estimator = KMeans(3) 68 | estimator.fit(X) 69 | 70 | # Once the estimator is fitted, it provides various properties. One of them is 71 | # the `model_` property which yields the PyTorch module with the fitted parameters. 72 | print("Centroids are:") 73 | print(estimator.model_.centroids) 74 | ``` 75 | 76 | Due to the high-level estimator API, the usage for all machine learning models is similar. The API 77 | documentation provides more detailed information about parameters that can be passed to estimators 78 | and which methods are available. 79 | 80 | ### GPU and Multi-Node training 81 | 82 | For GPU- and multi-node training, PyCave leverages PyTorch Lightning. The hardware that training 83 | runs on is determined by the 84 | [Trainer](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.trainer.html#pytorch_lightning.trainer.trainer.Trainer) 85 | class. Its 86 | [**init**](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.trainer.html#pytorch_lightning.trainer.trainer.Trainer.__init__) 87 | method provides various configuration options. 88 | 89 | If you want to run K-Means with a GPU, you can pass the options `accelerator='gpu'` and `devices=1` 90 | to the estimator's initializer: 91 | 92 | ```python 93 | estimator = KMeans(3, trainer_params=dict(accelerator='gpu', devices=1)) 94 | ``` 95 | 96 | Similarly, if you want to train on 4 nodes simultaneously where each node has one GPU available, 97 | you can specify this as follows: 98 | 99 | ```python 100 | estimator = KMeans(3, trainer_params=dict(num_nodes=4, accelerator='gpu', devices=1)) 101 | ``` 102 | 103 | In fact, **you do not need to change anything else in your code**. 
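### Mini-Batch Training

All estimators also support mini-batch training for datasets that do not fit into memory: simply pass a `batch_size` to the estimator's initializer. The snippet below is a minimal sketch; the batch size of 10,000 is an arbitrary choice and should be tuned to your dataset and hardware:

```python
# Process the data in chunks of 10,000 datapoints instead of a single large batch.
estimator = KMeans(3, batch_size=10000)
estimator.fit(X)
```

Mini-batch training can be combined freely with the `trainer_params` shown above, e.g. to train on a GPU.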
104 | 105 | ### Implemented Models 106 | 107 | Currently, PyCave implements three different models: 108 | 109 | - [GaussianMixture](https://pycave.borchero.com/sites/generated/bayes/gmm/pycave.bayes.GaussianMixture.html) 110 | - [MarkovChain](https://pycave.borchero.com/sites/generated/bayes/markov_chain/pycave.bayes.MarkovChain.html) 111 | - [K-Means](https://pycave.borchero.com/sites/generated/clustering/kmeans/pycave.clustering.KMeans.html) 112 | 113 | ## License 114 | 115 | PyCave is licensed under the [MIT License](https://github.com/borchero/pycave/blob/main/LICENSE). 116 | -------------------------------------------------------------------------------- /pycave/bayes/core/normal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ._jit import jit_log_normal, jit_sample_normal 3 | from .types import CovarianceType 4 | 5 | 6 | def cholesky_precision(covariances: torch.Tensor, covariance_type: CovarianceType) -> torch.Tensor: 7 | """ 8 | Computes the Cholesky decompositions of the precision matrices induced by the provided 9 | covariance matrices. 10 | 11 | Args: 12 | covariances: A tensor of shape ``[num_components, dim, dim]``, ``[dim, dim]``, 13 | ``[num_components, dim]``, ``[dim]`` or ``[num_components]`` depending on the 14 | ``covariance_type``. These are the covariance matrices of multivariate Normal 15 | distributions. 16 | covariance_type: The type of covariance for the covariance matrices given. 17 | 18 | Returns: 19 | A tensor of the same shape as ``covariances``, providing the lower-triangular Cholesky 20 | decompositions of the precision matrices. 21 | """ 22 | if covariance_type in ("tied", "full"): 23 | # Compute Cholesky decomposition 24 | cholesky = torch.linalg.cholesky(covariances) 25 | # Invert 26 | num_features = covariances.size(-1) 27 | target = torch.eye(num_features, dtype=covariances.dtype, device=covariances.device) 28 | if covariance_type == "full": 29 | num_components = covariances.size(0) 30 | target = target.unsqueeze(0).expand(num_components, -1, -1) 31 | return torch.linalg.solve_triangular(cholesky, target, upper=False).transpose(-2, -1) 32 | 33 | # "Simple" kind of covariance 34 | return covariances.sqrt().reciprocal() 35 | 36 | 37 | def covariance(cholesky_precisions: torch.Tensor, covariance_type: CovarianceType) -> torch.Tensor: 38 | """ 39 | Computes the covariances matrices of the provided Cholesky decompositions of the precision 40 | matrices. This function is the inverse of :meth:`cholesky_precision`. 41 | 42 | Args: 43 | cholesky_precisions: A tensor of shape ``[num_components, dim, dim]``, ``[dim, dim]``, 44 | ``[num_components, dim]``, ``[dim]`` or ``[num_components]`` depending on the 45 | ``covariance_type``. These are the Cholesky decompositions of the precisions of 46 | multivariate Normal distributions. 47 | covariance_type: The type of covariance for the covariance matrices given. 48 | 49 | Returns: 50 | A tensor of the same shape as ``cholesky_precisions``, providing the covariance matrices 51 | corresponding to the given Cholesky-decomposed precision matrices. 
52 | """ 53 | if covariance_type in ("tied", "full"): 54 | choleksy_covars = torch.linalg.inv(cholesky_precisions) 55 | if covariance_type == "tied": 56 | return torch.matmul(choleksy_covars.T, choleksy_covars) 57 | return torch.bmm(choleksy_covars.transpose(1, 2), choleksy_covars) 58 | 59 | # "Simple" kind of covariance 60 | return (cholesky_precisions**2).reciprocal() 61 | 62 | 63 | def log_normal( 64 | x: torch.Tensor, 65 | means: torch.Tensor, 66 | precisions_cholesky: torch.Tensor, 67 | covariance_type: CovarianceType, 68 | ) -> torch.Tensor: 69 | """ 70 | Computes the log-probability of the given data for multiple multivariate Normal distributions 71 | defined by their means and covariances. 72 | 73 | Args: 74 | x: A tensor of shape ``[num_datapoints, dim]``. This is the data to compute the 75 | log-probability for. 76 | means: A tensor of shape ``[num_components, dim]``. These are the means of the multivariate 77 | Normal distributions. 78 | precisions_cholesky: A tensor of shape ``[num_components, dim, dim]``, ``[dim, dim]``, 79 | ``[num_components, dim]``, ``[dim]`` or ``[num_components]`` depending on the 80 | ``covariance_type``. These are the upper-triangular Cholesky matrices for the inverse 81 | covariance matrices (i.e. precision matrices) of the multivariate Normal distributions. 82 | covariance_type: The type of covariance for the covariance matrices given. 83 | 84 | Returns: 85 | A tensor of shape ``[num_datapoints, num_components]`` with the log-probabilities for each 86 | datapoint and each multivariate Normal distribution. 87 | """ 88 | return jit_log_normal(x, means, precisions_cholesky, covariance_type) 89 | 90 | 91 | def sample_normal( 92 | num: int, 93 | mean: torch.Tensor, 94 | cholesky_precisions: torch.Tensor, 95 | covariance_type: CovarianceType, 96 | ) -> torch.Tensor: 97 | """ 98 | Samples the given number of times from the multivariate Normal distribution described by the 99 | mean and Cholesky precision. 100 | 101 | Args: 102 | num: The number of times to sample. 103 | means: A tensor of shape ``[dim]`` with the mean of the distribution to sample from. 104 | choleksy_precisions: A tensor of shape ``[dim, dim]``, ``[dim]``, ``[dim]`` or ``[1]`` 105 | depending on the ``covariance_type``. This is the corresponding Cholesky precision 106 | matrix for the mean. 107 | covariance_type: The type of covariance for the covariance matrix given. 108 | 109 | Returns: 110 | A tensor of shape ``[num_samples, dim]`` with the samples from the Normal distribution. 
111 | """ 112 | return jit_sample_normal(num, mean, cholesky_precisions, covariance_type) 113 | -------------------------------------------------------------------------------- /tests/bayes/core/benchmark_log_normal.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | import numpy as np 3 | import torch 4 | from pytest_benchmark.fixture import BenchmarkFixture # type: ignore 5 | from sklearn.mixture._gaussian_mixture import _compute_precision_cholesky # type: ignore 6 | from sklearn.mixture._gaussian_mixture import _estimate_log_gaussian_prob # type: ignore 7 | from torch.distributions import MultivariateNormal 8 | from pycave.bayes.core import cholesky_precision, log_normal 9 | 10 | 11 | def test_log_normal_spherical(benchmark: BenchmarkFixture): 12 | data = torch.randn(10000, 100) 13 | means = torch.randn(50, 100) 14 | precisions = cholesky_precision(torch.rand(50), "spherical") 15 | benchmark(log_normal, data, means, precisions, covariance_type="spherical") 16 | 17 | 18 | def test_torch_log_normal_spherical(benchmark: BenchmarkFixture): 19 | data = torch.randn(10000, 100) 20 | means = torch.randn(50, 100) 21 | covars = torch.rand(50) 22 | covar_matrices = torch.stack([torch.eye(means.size(-1)) * c for c in covars]) 23 | 24 | cholesky = torch.linalg.cholesky(covar_matrices) 25 | distribution = MultivariateNormal(means, scale_tril=cholesky, validate_args=False) 26 | benchmark(distribution.log_prob, data.unsqueeze(1)) 27 | 28 | 29 | def test_numpy_log_normal_spherical(benchmark: BenchmarkFixture): 30 | data = np.random.randn(10000, 100) 31 | means = np.random.randn(50, 100) 32 | covars = np.random.rand(50) 33 | benchmark(_estimate_log_gaussian_prob, data, means, covars, "spherical") # type: ignore 34 | 35 | 36 | # ------------------------------------------------------------------------------------------------- 37 | 38 | 39 | def test_log_normal_diag(benchmark: BenchmarkFixture): 40 | data = torch.randn(10000, 100) 41 | means = torch.randn(50, 100) 42 | precisions = cholesky_precision(torch.rand(50, 100), "diag") 43 | benchmark(log_normal, data, means, precisions, covariance_type="diag") 44 | 45 | 46 | def test_torch_log_normal_diag(benchmark: BenchmarkFixture): 47 | data = torch.randn(10000, 100) 48 | means = torch.randn(50, 100) 49 | covars = torch.rand(50, 100) 50 | covar_matrices = torch.stack([torch.diag(c) for c in covars]) 51 | 52 | cholesky = torch.linalg.cholesky(covar_matrices) 53 | distribution = MultivariateNormal(means, scale_tril=cholesky, validate_args=False) 54 | benchmark(distribution.log_prob, data.unsqueeze(1)) 55 | 56 | 57 | def test_numpy_log_normal_diag(benchmark: BenchmarkFixture): 58 | data = np.random.randn(10000, 100) 59 | means = np.random.randn(50, 100) 60 | covars = np.random.rand(50, 100) 61 | benchmark(_estimate_log_gaussian_prob, data, means, covars, "diag") # type: ignore 62 | 63 | 64 | # ------------------------------------------------------------------------------------------------- 65 | 66 | 67 | def test_log_normal_full(benchmark: BenchmarkFixture): 68 | data = torch.randn(10000, 100) 69 | means = torch.randn(50, 100) 70 | A = torch.randn(50, 1000, 100) 71 | covars = A.permute(0, 2, 1).bmm(A) 72 | precisions = cholesky_precision(covars, "full") 73 | benchmark(log_normal, data, means, precisions, covariance_type="full") 74 | 75 | 76 | def test_torch_log_normal_full(benchmark: BenchmarkFixture): 77 | data = torch.randn(10000, 100) 78 | means = torch.randn(50, 100) 79 | A = 
torch.randn(50, 1000, 100) 80 | covars = A.permute(0, 2, 1).bmm(A) 81 | 82 | cholesky = torch.linalg.cholesky(covars) 83 | distribution = MultivariateNormal(means, scale_tril=cholesky, validate_args=False) 84 | benchmark(distribution.log_prob, data.unsqueeze(1)) 85 | 86 | 87 | def test_numpy_log_normal_full(benchmark: BenchmarkFixture): 88 | data = np.random.randn(10000, 100) 89 | means = np.random.randn(50, 100) 90 | A = np.random.randn(50, 1000, 100) 91 | covars = np.matmul(np.transpose(A, (0, 2, 1)), A) 92 | 93 | precisions = _compute_precision_cholesky(covars, "full") # type: ignore 94 | benchmark( 95 | _estimate_log_gaussian_prob, # type: ignore 96 | data, 97 | means, 98 | precisions, 99 | covariance_type="full", 100 | ) 101 | 102 | 103 | # ------------------------------------------------------------------------------------------------- 104 | 105 | 106 | def test_log_normal_tied(benchmark: BenchmarkFixture): 107 | data = torch.randn(10000, 100) 108 | means = torch.randn(50, 100) 109 | A = torch.randn(1000, 100) 110 | covars = A.t().mm(A) 111 | precisions = cholesky_precision(covars, "tied") 112 | benchmark(log_normal, data, means, precisions, covariance_type="tied") 113 | 114 | 115 | def test_torch_log_normal_tied(benchmark: BenchmarkFixture): 116 | data = torch.randn(10000, 100) 117 | means = torch.randn(50, 100) 118 | A = torch.randn(1000, 100) 119 | covars = A.t().mm(A) 120 | 121 | cholesky = torch.linalg.cholesky(covars) 122 | distribution = MultivariateNormal(means, scale_tril=cholesky, validate_args=False) 123 | benchmark(distribution.log_prob, data.unsqueeze(1)) 124 | 125 | 126 | def test_numpy_log_normal_tied(benchmark: BenchmarkFixture): 127 | data = np.random.randn(10000, 100) 128 | means = np.random.randn(50, 100) 129 | A = np.random.randn(1000, 100) 130 | covars = A.T.dot(A) 131 | 132 | precisions = _compute_precision_cholesky(covars, "tied") # type: ignore 133 | benchmark( 134 | _estimate_log_gaussian_prob, # type: ignore 135 | data, 136 | means, 137 | precisions, 138 | covariance_type="tied", 139 | ) 140 | -------------------------------------------------------------------------------- /pycave/bayes/gmm/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Tuple 3 | import numpy as np 4 | import torch 5 | from lightkit.nn import Configurable 6 | from torch import jit, nn 7 | from pycave.bayes.core import covariance, covariance_shape, CovarianceType 8 | from pycave.bayes.core._jit import jit_log_normal, jit_sample_normal 9 | 10 | 11 | @dataclass 12 | class GaussianMixtureModelConfig: 13 | """ 14 | Configuration class for a Gaussian mixture model. 15 | 16 | See also: 17 | :class:`GaussianMixtureModel` 18 | """ 19 | 20 | #: The number of components in the GMM. 21 | num_components: int 22 | #: The number of features for the GMM's components. 23 | num_features: int 24 | #: The type of covariance to use for the components. 25 | covariance_type: CovarianceType 26 | 27 | 28 | class GaussianMixtureModel(Configurable[GaussianMixtureModelConfig], nn.Module): 29 | """ 30 | PyTorch module for a Gaussian mixture model. 31 | 32 | Covariances are represented via their Cholesky decomposition for computational efficiency. The 33 | model does not have trainable parameters. 34 | """ 35 | 36 | #: The probabilities of each component, buffer of shape ``[num_components]``. 37 | component_probs: torch.Tensor 38 | #: The means of each component, buffer of shape ``[num_components, num_features]``. 
39 | means: torch.Tensor 40 | #: The precision matrices for the components' covariances, buffer with a shape dependent 41 | #: on the covariance type, see :class:`CovarianceType`. 42 | precisions_cholesky: torch.Tensor 43 | 44 | def __init__(self, config: GaussianMixtureModelConfig): 45 | """ 46 | Args: 47 | config: The configuration to use for initializing the module's buffers. 48 | """ 49 | super().__init__(config) 50 | 51 | self.covariance_type = config.covariance_type 52 | 53 | self.register_buffer("component_probs", torch.empty(config.num_components)) 54 | self.register_buffer("means", torch.empty(config.num_components, config.num_features)) 55 | 56 | shape = covariance_shape( 57 | config.num_components, config.num_features, config.covariance_type 58 | ) 59 | self.register_buffer("precisions_cholesky", torch.empty(shape)) 60 | 61 | self.reset_parameters() 62 | 63 | @jit.unused # type: ignore 64 | @property 65 | def covariances(self) -> torch.Tensor: 66 | """ 67 | The covariance matrices learnt for the GMM's components. 68 | 69 | The shape of the tensor depends on the covariance type, see :class:`CovarianceType`. 70 | """ 71 | return covariance(self.precisions_cholesky, self.covariance_type) # type: ignore 72 | 73 | @jit.unused 74 | def reset_parameters(self) -> None: 75 | """ 76 | Resets the parameters of the GMM. 77 | 78 | - Component probabilities are initialized via uniform sampling and normalization. 79 | - Means are initialized randomly from a Standard Normal. 80 | - Cholesky precisions are initialized randomly based on the covariance type. For all 81 | covariance types, it is based on uniform sampling. 82 | """ 83 | nn.init.uniform_(self.component_probs) 84 | self.component_probs.div_(self.component_probs.sum()) 85 | 86 | nn.init.normal_(self.means) 87 | 88 | nn.init.uniform_(self.precisions_cholesky) 89 | if self.covariance_type in ("full", "tied"): 90 | self.precisions_cholesky.tril_() 91 | 92 | def forward(self, data: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 93 | """ 94 | Computes the log-probability of observing each of the provided datapoints for each of the 95 | GMM's components. 96 | 97 | Args: 98 | data: A tensor of shape ``[num_datapoints, num_features]`` for which to compute the 99 | log-probabilities. 100 | 101 | Returns: 102 | - A tensor of shape ``[num_datapoints, num_components]`` with the log-responsibilities 103 | for each datapoint and components. These are the logits of the Categorical 104 | distribution over the parameters. 105 | - A tensor of shape ``[num_datapoints]`` with the log-likelihood of each datapoint. 106 | """ 107 | log_probabilities = jit_log_normal( 108 | data, self.means, self.precisions_cholesky, self.covariance_type 109 | ) 110 | log_responsibilities = log_probabilities + self.component_probs.log() 111 | log_prob = log_responsibilities.logsumexp(1, keepdim=True) 112 | return log_responsibilities - log_prob, log_prob.squeeze(1) 113 | 114 | def sample(self, num_datapoints: int) -> torch.Tensor: 115 | """ 116 | Samples the provided number of datapoints from the GMM. 117 | 118 | Args: 119 | num_datapoints: The number of datapoints to sample. 120 | 121 | Returns: 122 | A tensor of shape ``[num_datapoints, num_features]`` with the random samples. 123 | 124 | Attention: 125 | This method does not automatically perform batching. If you need to sample many 126 | datapoints, call this method multiple times. 
127 | """ 128 | # First, we sample counts for each 129 | component_counts = np.random.multinomial(num_datapoints, self.component_probs.numpy()) 130 | 131 | # Then, we generate datapoints for each components 132 | result = [] 133 | for i, count in enumerate(component_counts): 134 | sample = jit_sample_normal( 135 | count.item(), 136 | self.means[i], 137 | self._get_component_precision(i), 138 | self.covariance_type, 139 | ) 140 | result.append(sample) 141 | 142 | return torch.cat(result, dim=0) 143 | 144 | def _get_component_precision(self, component: int) -> torch.Tensor: 145 | if self.covariance_type == "tied": 146 | return self.precisions_cholesky 147 | return self.precisions_cholesky[component] 148 | -------------------------------------------------------------------------------- /pycave/bayes/gmm/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | import torch 3 | from torchmetrics import Metric 4 | from pycave.bayes.core import covariance_shape, CovarianceType 5 | 6 | 7 | class PriorAggregator(Metric): 8 | """ 9 | The prior aggregator aggregates component probabilities over batches and process. 10 | """ 11 | 12 | full_state_update = False 13 | 14 | def __init__( 15 | self, 16 | num_components: int, 17 | *, 18 | dist_sync_fn: Optional[Callable[[Any], Any]] = None, 19 | ): 20 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore 21 | 22 | self.responsibilities: torch.Tensor 23 | self.add_state("responsibilities", torch.zeros(num_components), dist_reduce_fx="sum") 24 | 25 | def update(self, responsibilities: torch.Tensor) -> None: 26 | # Responsibilities have shape [N, K] 27 | self.responsibilities.add_(responsibilities.sum(0)) 28 | 29 | def compute(self) -> torch.Tensor: 30 | return self.responsibilities / self.responsibilities.sum() 31 | 32 | 33 | class MeanAggregator(Metric): 34 | """ 35 | The mean aggregator aggregates component means over batches and processes. 36 | """ 37 | 38 | full_state_update = False 39 | 40 | def __init__( 41 | self, 42 | num_components: int, 43 | num_features: int, 44 | *, 45 | dist_sync_fn: Optional[Callable[[Any], Any]] = None, 46 | ): 47 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore 48 | 49 | self.mean_sum: torch.Tensor 50 | self.add_state("mean_sum", torch.zeros(num_components, num_features), dist_reduce_fx="sum") 51 | 52 | self.component_weights: torch.Tensor 53 | self.add_state("component_weights", torch.zeros(num_components), dist_reduce_fx="sum") 54 | 55 | def update(self, data: torch.Tensor, responsibilities: torch.Tensor) -> None: 56 | # Data has shape [N, D] 57 | # Responsibilities have shape [N, K] 58 | self.mean_sum.add_(responsibilities.t().matmul(data)) 59 | self.component_weights.add_(responsibilities.sum(0)) 60 | 61 | def compute(self) -> torch.Tensor: 62 | return self.mean_sum / self.component_weights.unsqueeze(1) 63 | 64 | 65 | class CovarianceAggregator(Metric): 66 | """ 67 | The covariance aggregator aggregates component covariances over batches and processes. 
68 | """ 69 | 70 | full_state_update = False 71 | 72 | def __init__( 73 | self, 74 | num_components: int, 75 | num_features: int, 76 | covariance_type: CovarianceType, 77 | reg: float, 78 | *, 79 | dist_sync_fn: Optional[Callable[[Any], Any]] = None, 80 | ): 81 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore 82 | 83 | self.num_components = num_components 84 | self.num_features = num_features 85 | self.covariance_type = covariance_type 86 | self.reg = reg 87 | 88 | self.covariance_sum: torch.Tensor 89 | self.add_state( 90 | "covariance_sum", 91 | torch.zeros(covariance_shape(num_components, num_features, covariance_type)), 92 | dist_reduce_fx="sum", 93 | ) 94 | 95 | self.component_weights: torch.Tensor 96 | self.add_state("component_weights", torch.zeros(num_components), dist_reduce_fx="sum") 97 | 98 | def update( 99 | self, data: torch.Tensor, responsibilities: torch.Tensor, means: torch.Tensor 100 | ) -> None: 101 | data_component_weights = responsibilities.sum(0) 102 | self.component_weights.add_(data_component_weights) 103 | 104 | if self.covariance_type in ("spherical", "diag"): 105 | x_prob = torch.matmul(responsibilities.t(), data.square()) 106 | m_prob = data_component_weights.unsqueeze(-1) * means.square() 107 | xm_prob = means * torch.matmul(responsibilities.t(), data) 108 | covars = x_prob - 2 * xm_prob + m_prob 109 | if self.covariance_type == "diag": 110 | self.covariance_sum.add_(covars) 111 | else: # covariance_type == "spherical" 112 | self.covariance_sum.add_(covars.mean(1)) 113 | elif self.covariance_type == "tied": 114 | # This is taken from https://github.com/scikit-learn/scikit-learn/blob/ 115 | # 844b4be24d20fc42cc13b957374c718956a0db39/sklearn/mixture/_gaussian_mixture.py#L183 116 | x_sq = data.T.matmul(data) 117 | mean_sq = (data_component_weights * means.T).matmul(means) 118 | self.covariance_sum.add_(x_sq - mean_sq) 119 | else: # covariance_type == "full": 120 | # We iterate over each component since this is typically faster... 
121 | for i in range(self.num_components): 122 | component_diff = data - means[i] 123 | covars = (responsibilities[:, i].unsqueeze(1) * component_diff).T.matmul( 124 | component_diff 125 | ) 126 | self.covariance_sum[i].add_(covars) 127 | 128 | def compute(self) -> torch.Tensor: 129 | if self.covariance_type == "diag": 130 | return self.covariance_sum / self.component_weights.unsqueeze(-1) + self.reg 131 | if self.covariance_type == "spherical": 132 | return self.covariance_sum / self.component_weights + self.reg * self.num_features 133 | if self.covariance_type == "tied": 134 | result = self.covariance_sum / self.component_weights.sum() 135 | shape = result.size() 136 | result = result.flatten() 137 | result[:: self.num_features + 1].add_(self.reg) 138 | return result.view(shape) 139 | # covariance_type == "full" 140 | result = self.covariance_sum / self.component_weights.unsqueeze(-1).unsqueeze(-1) 141 | diag_mask = ( 142 | torch.eye(self.num_features, device=result.device, dtype=result.dtype) 143 | .bool() 144 | .unsqueeze(0) 145 | .expand(self.num_components, -1, -1) 146 | ) 147 | result[diag_mask] += self.reg 148 | return result 149 | -------------------------------------------------------------------------------- /pycave/bayes/markov_chain/model.py: -------------------------------------------------------------------------------- 1 | # pyright: reportPrivateUsage=false, reportUnknownParameterType=false 2 | from dataclasses import dataclass 3 | from typing import overload 4 | import torch 5 | import torch._jit_internal as _jit 6 | from lightkit.nn import Configurable 7 | from torch import jit, nn 8 | from torch.nn.utils.rnn import PackedSequence 9 | 10 | 11 | @dataclass 12 | class MarkovChainModelConfig: 13 | """ 14 | Configuration class for a Markov chain model. 15 | 16 | See also: 17 | :class:`MarkovChainModel` 18 | """ 19 | 20 | #: The number of states that are managed by the Markov chain. 21 | num_states: int 22 | 23 | 24 | class MarkovChainModel(Configurable[MarkovChainModelConfig], nn.Module): 25 | """ 26 | PyTorch module for a Markov chain. 27 | 28 | The initial state probabilities as well as the transition probabilities are non-trainable 29 | parameters. 30 | """ 31 | 32 | def __init__(self, config: MarkovChainModelConfig): 33 | """ 34 | Args: 35 | config: The configuration to use for initializing the module's buffers. 36 | """ 37 | super().__init__(config) 38 | 39 | #: The probabilities for the initial states, buffer of shape ``[num_states]``. 40 | self.initial_probs: torch.Tensor 41 | self.register_buffer("initial_probs", torch.empty(config.num_states)) 42 | 43 | #: The transition probabilities between all states, buffer of shape 44 | #: ``[num_states, num_states]``. 45 | self.transition_probs: torch.Tensor 46 | self.register_buffer("transition_probs", torch.empty(config.num_states, config.num_states)) 47 | 48 | self.reset_parameters() 49 | 50 | @jit.unused 51 | def reset_parameters(self) -> None: 52 | """ 53 | Resets the parameters of the Markov model. 54 | 55 | Initial and transition probabilities are sampled uniformly. 56 | """ 57 | nn.init.uniform_(self.initial_probs) 58 | self.initial_probs.div_(self.initial_probs.sum()) 59 | 60 | nn.init.uniform_(self.transition_probs) 61 | self.transition_probs.div_(self.transition_probs.sum(1, keepdim=True)) 62 | 63 | @overload 64 | @_jit._overload_method # pylint: disable=protected-access 65 | def forward(self, sequences: torch.Tensor) -> torch.Tensor: 66 | ... 
67 | 68 | @overload 69 | @_jit._overload_method # pylint: disable=protected-access 70 | def forward(self, sequences: PackedSequence) -> torch.Tensor: # type: ignore 71 | ... 72 | 73 | def forward(self, sequences) -> torch.Tensor: # type: ignore 74 | """ 75 | Computes the log-probability of observing each of the provided sequences. 76 | 77 | Args: 78 | sequences: Tensor of shape ``[num_sequences, sequence_length]`` or a packed sequence. 79 | Packed sequences should be used whenever the sequence lengths differ. All 80 | sequences must contain state indices of dtype ``long``. 81 | 82 | Returns: 83 | A tensor of shape ``[num_sequences]``, providing the log-probability of each 84 | sequence. 85 | """ 86 | if isinstance(sequences, torch.Tensor): 87 | log_probs = self.initial_probs[sequences[:, 0]].log() 88 | sources = sequences[:, :-1] 89 | targets = sequences[:, 1:].unsqueeze(-1) 90 | transition_probs = self.transition_probs[sources].gather(-1, targets).squeeze(-1) 91 | return log_probs + transition_probs.log().sum(-1) 92 | if isinstance(sequences, PackedSequence): 93 | data = sequences.data 94 | batch_sizes = sequences.batch_sizes 95 | 96 | log_probs = self.initial_probs[data[: batch_sizes[0]]].log() 97 | offset = 0 98 | for prev_size, curr_size in zip(batch_sizes, batch_sizes[1:]): 99 | log_probs[:curr_size] += self.transition_probs[ 100 | data[offset : offset + curr_size], 101 | data[offset + prev_size : offset + prev_size + curr_size], 102 | ].log() 103 | offset += prev_size 104 | 105 | if sequences.unsorted_indices is not None: 106 | return log_probs[sequences.unsorted_indices] 107 | return log_probs 108 | raise ValueError("unsupported input type") 109 | 110 | def sample(self, num_sequences: int, sequence_length: int) -> torch.Tensor: 111 | """ 112 | Samples random sequences from the Markov chain. 113 | 114 | Args: 115 | num_sequences: The number of sequences to sample. 116 | sequence_length: The length of all sequences to sample. 117 | 118 | Returns: 119 | Tensor of shape ``[num_sequences, sequence_length]`` with dtype ``long``, providing the 120 | sampled states. 121 | """ 122 | samples = torch.empty( 123 | num_sequences, sequence_length, device=self.transition_probs.device, dtype=torch.long 124 | ) 125 | samples[:, 0] = self.initial_probs.multinomial(num_sequences, replacement=True) 126 | for i in range(1, sequence_length): 127 | samples[:, i] = self.transition_probs[samples[:, i - 1]].multinomial(1).squeeze(-1) 128 | return samples 129 | 130 | def stationary_distribution( 131 | self, tol: float = 1e-7, max_iterations: int = 1000 132 | ) -> torch.Tensor: 133 | """ 134 | Computes the stationary distribution of the Markov chain using power iteration. 135 | 136 | Args: 137 | tol: The tolerance to use when checking if the power iteration has converged. As soon 138 | as the norm between the vectors of two successive iterations is below this value, 139 | the iteration is stopped. 140 | max_iterations: The maximum number of iterations to run if the tolerance does not 141 | indicate convergence. 142 | 143 | Returns: 144 | A tensor of shape ``[num_states]`` with the stationary distribution (i.e. the 145 | eigenvector corresponding to the largest eigenvalue of the transition matrix, 146 | normalized to describe a probability distribution). 
147 | """ 148 | A = self.transition_probs.t() 149 | v = torch.rand(A.size(0), device=A.device, dtype=A.dtype) 150 | 151 | for _ in range(max_iterations): 152 | v_old = v 153 | v = A.mv(v) 154 | v = v / v.norm() 155 | if (v - v_old).norm() < tol: 156 | break 157 | 158 | return v / v.sum() 159 | -------------------------------------------------------------------------------- /docs/sites/benchmark.rst: -------------------------------------------------------------------------------- 1 | Benchmarks 2 | ========== 3 | 4 | In order to evaluate the runtime performance of PyCave, we run an exhaustive set of experiments to 5 | compare against the implementation found in scikit-learn. Evaluations are run at varying dataset 6 | sizes. 7 | 8 | All benchmarks are run on an instance with a Intel Xeon E5-2630 v4 CPU (2.2 GHz). We use at most 4 9 | cores and 60 GiB of memory. Also, there is a single GeForce GTX 1080 Ti GPU (11 GiB memory) 10 | available. For the performance measures, each benchmark is run at least 5 times. 11 | 12 | Gaussian Mixture 13 | ---------------- 14 | 15 | Setup 16 | ^^^^^ 17 | 18 | For measuring the performance of fitting a Gaussian mixture model, we fix the number of iterations 19 | after initialization to 100 to not measure any variances in the convergence criterion. For 20 | initialization, we further set the known means that were used to generate data to not run into 21 | issues of degenerate covariance matrices. Thus, all benchmarks essentially measure the performance 22 | after K-means initialization has been run. Benchmarks for K-means itself are listed below. 23 | 24 | Results 25 | ^^^^^^^ 26 | 27 | .. list-table:: Training Duration for Diagonal Covariance (``[num_datapoints, num_features] -> num_components``) 28 | :header-rows: 1 29 | :stub-columns: 1 30 | :widths: 3 2 2 2 2 2 31 | 32 | * - 33 | - Scikit-Learn 34 | - PyCave CPU (full) 35 | - PyCave CPU (batches) 36 | - PyCave GPU (full) 37 | - PyCave GPU (batches) 38 | * - ``[10k, 8] -> 4`` 39 | - **352 ms** 40 | - 649 ms 41 | - 3.9 s 42 | - 358 ms 43 | - 3.6 s 44 | * - ``[100k, 32] -> 16`` 45 | - 18.4 s 46 | - 4.3 s 47 | - 10.0 s 48 | - **527 ms** 49 | - 3.9 s 50 | * - ``[1M, 64] -> 64`` 51 | - 730 s 52 | - 196 s 53 | - 284 s 54 | - **7.7 s** 55 | - 15.3 s 56 | 57 | .. list-table:: Training Duration for Tied Covariance (``[num_datapoints, num_features] -> num_components``) 58 | :header-rows: 1 59 | :stub-columns: 1 60 | :widths: 3 2 2 2 2 2 61 | 62 | * - 63 | - Scikit-Learn 64 | - PyCave CPU (full) 65 | - PyCave CPU (batches) 66 | - PyCave GPU (full) 67 | - PyCave GPU (batches) 68 | * - ``[10k, 8] -> 4`` 69 | - 699 ms 70 | - 570 ms 71 | - 3.6 s 72 | - **356 ms** 73 | - 3.3 s 74 | * - ``[100k, 32] -> 16`` 75 | - 72.2 s 76 | - 12.1 s 77 | - 16.1 s 78 | - **919 ms** 79 | - 3.8 s 80 | * - ``[1M, 64] -> 64`` 81 | - -- 82 | - -- 83 | - -- 84 | - -- 85 | - **63.4 s** 86 | 87 | .. 
list-table:: Training Duration for Full Covariance (``[num_datapoints, num_features] -> num_components``) 88 | :header-rows: 1 89 | :stub-columns: 1 90 | :widths: 3 2 2 2 2 2 91 | 92 | * - 93 | - Scikit-Learn 94 | - PyCave CPU (full) 95 | - PyCave CPU (batches) 96 | - PyCave GPU (full) 97 | - PyCave GPU (batches) 98 | * - ``[10k, 8] -> 4`` 99 | - 1.1 s 100 | - 679 ms 101 | - 4.1 s 102 | - **648 ms** 103 | - 4.4 s 104 | * - ``[100k, 32] -> 16`` 105 | - 110 s 106 | - 13.5 s 107 | - 21.2 s 108 | - **2.4 s** 109 | - 7.8 s 110 | 111 | Summary 112 | ^^^^^^^ 113 | 114 | PyCave's implementation of the Gaussian mixture model is markedly more efficient than the one found 115 | in scikit-learn. Even on the CPU, PyCave already outperforms scikit-learn significantly at 100k 116 | datapoints. When moving to the GPU, PyCave unfolds its full potential and yields speed-ups of 117 | around 100x. For larger datasets, mini-batch training is the only alternative. PyCave fully 118 | supports it, although training takes approximately twice as long as when training on the full 119 | data. The reason is that the M-step of the EM algorithm needs to be split across two epochs, 120 | which, in turn, requires replaying the E-step. 121 | 122 | 123 | K-Means 124 | ------- 125 | 126 | Setup 127 | ^^^^^ 128 | 129 | For the scikit-learn implementation, we use Lloyd's algorithm instead of Elkan's algorithm to have 130 | a useful comparison with PyCave (which implements Lloyd's algorithm). 131 | 132 | Further, we fix the number of iterations after initialization to 100 to not measure any variances 133 | in the convergence criterion. 134 | 135 | Results 136 | ^^^^^^^ 137 | 138 | .. list-table:: Training Duration for Random Initialization (``[num_datapoints, num_features] -> num_clusters``) 139 | :header-rows: 1 140 | :stub-columns: 1 141 | :widths: 3 2 2 2 2 2 142 | 143 | * - 144 | - Scikit-Learn 145 | - PyCave CPU (full) 146 | - PyCave CPU (batches) 147 | - PyCave GPU (full) 148 | - PyCave GPU (batches) 149 | * - ``[10k, 8] -> 4`` 150 | - **13 ms** 151 | - 412 ms 152 | - 797 ms 153 | - 387 ms 154 | - 2.1 s 155 | * - ``[100k, 32] -> 16`` 156 | - **311 ms** 157 | - 2.1 s 158 | - 3.4 s 159 | - 707 ms 160 | - 2.5 s 161 | * - ``[1M, 64] -> 64`` 162 | - 10.0 s 163 | - 73.6 s 164 | - 58.1 s 165 | - **8.2 s** 166 | - 10.0 s 167 | * - ``[10M, 128] -> 128`` 168 | - 254 s 169 | - -- 170 | - -- 171 | - -- 172 | - **133 s** 173 | 174 | .. list-table:: Training Duration for K-Means++ Initialization (``[num_datapoints, num_features] -> num_clusters``) 175 | :header-rows: 1 176 | :stub-columns: 1 177 | :widths: 3 2 2 2 2 2 178 | 179 | * - 180 | - Scikit-Learn 181 | - PyCave CPU (full) 182 | - PyCave CPU (batches) 183 | - PyCave GPU (full) 184 | - PyCave GPU (batches) 185 | * - ``[10k, 8] -> 4`` 186 | - **15 ms** 187 | - 170 ms 188 | - 930 ms 189 | - 431 ms 190 | - 2.4 s 191 | * - ``[100k, 32] -> 16`` 192 | - **542 ms** 193 | - 2.3 s 194 | - 4.3 s 195 | - 840 ms 196 | - 3.2 s 197 | * - ``[1M, 64] -> 64`` 198 | - 25.3 s 199 | - 93.4 s 200 | - 83.7 s 201 | - **13.1 s** 202 | - 17.1 s 203 | * - ``[10M, 128] -> 128`` 204 | - 827 s 205 | - -- 206 | - -- 207 | - -- 208 | - **369 s** 209 | 210 | Summary 211 | ^^^^^^^ 212 | 213 | As it turns out, it is really hard to outperform the implementation found in scikit-learn. 214 | Especially if little data is available, the overhead of PyTorch and PyTorch Lightning renders 215 | PyCave comparatively slow. 
However, as more data is available, PyCave starts to become relatively 216 | faster and, when leveraging the GPU, it finally outperforms scikit-learn for a dataset size of 1M 217 | datapoints. Nonetheless, the improvement is marginal. 218 | -------------------------------------------------------------------------------- /pycave/bayes/markov_chain/estimator.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import logging 3 | from typing import Any, cast, List 4 | import numpy as np 5 | import torch 6 | from lightkit import ConfigurableBaseEstimator 7 | from lightkit.data import DataLoader, dataset_from_tensors 8 | from torch.nn.utils.rnn import PackedSequence 9 | from torch.utils.data import Dataset 10 | from .lightning_module import MarkovChainLightningModule 11 | from .model import MarkovChainModel, MarkovChainModelConfig 12 | from .types import collate_sequences, collate_sequences_same_length, SequenceData 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class MarkovChain(ConfigurableBaseEstimator[MarkovChainModel]): # type: ignore 18 | """ 19 | Probabilistic model for observed state transitions. The Markov chain is similar to the hidden 20 | Markov model, only that the hidden states are known. More information on the Markov chain is 21 | available on `Wikipedia `_. 22 | 23 | See also: 24 | .. currentmodule:: pycave.bayes.markov_chain 25 | .. autosummary:: 26 | :nosignatures: 27 | :template: classes/pytorch_module.rst 28 | 29 | MarkovChainModel 30 | MarkovChainModelConfig 31 | """ 32 | 33 | #: The fitted PyTorch module with all estimated parameters. 34 | model_: MarkovChainModel 35 | 36 | def __init__( 37 | self, 38 | num_states: int | None = None, 39 | *, 40 | symmetric: bool = False, 41 | batch_size: int | None = None, 42 | trainer_params: dict[str, Any] | None = None, 43 | ): 44 | """ 45 | Args: 46 | num_states: The number of states that the Markov chain has. If not provided, it will 47 | be derived automatically when calling :meth:`fit`. Note that this requires a pass 48 | through the data. Consider setting this option explicitly if you're fitting a lot 49 | of data. 50 | symmetric: Whether the transitions between states should be considered symmetric. 51 | batch_size: The batch size to use when fitting the model. If not provided, the full 52 | data will be used as a single batch. Set this if the full data does not fit into 53 | memory. 54 | num_workers: The number of workers to use for loading the data. Only used if a PyTorch 55 | dataset is passed to :meth:`fit` or related methods. 56 | trainer_params: Initialization parameters to use when initializing a PyTorch Lightning 57 | trainer. By default, it disables various stdout logs unless PyCave is configured to 58 | do verbose logging. Checkpointing and logging are disabled regardless of the log 59 | level. This estimator further enforces the following parameters: 60 | 61 | - ``max_epochs=1`` 62 | """ 63 | super().__init__( 64 | user_params=trainer_params, 65 | overwrite_params=dict(max_epochs=1), 66 | ) 67 | 68 | self.num_states = num_states 69 | self.symmetric = symmetric 70 | self.batch_size = batch_size 71 | 72 | def fit(self, sequences: SequenceData) -> MarkovChain: 73 | """ 74 | Fits the Markov chain on the provided data and returns the fitted estimator. 75 | 76 | Args: 77 | sequences: The sequences to fit the Markov chain on. 78 | 79 | Returns: 80 | The fitted Markov chain. 
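        Example:
            A minimal usage sketch; the import path and the toy data below are assumptions for
            illustration only::

                import torch
                from pycave.bayes import MarkovChain

                # two sequences of equal length, states encoded as ``long`` indices
                sequences = torch.tensor([[0, 1, 1, 0], [1, 0, 0, 1]])
                chain = MarkovChain(num_states=2).fit(sequences)
                chain.model_.transition_probs  # [2, 2] matrix of estimated transition probabilities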
81 | """ 82 | config = MarkovChainModelConfig( 83 | num_states=self.num_states or _get_num_states(sequences), 84 | ) 85 | self.model_ = MarkovChainModel(config) 86 | 87 | logger.info("Fitting Markov chain...") 88 | self.trainer().fit( 89 | MarkovChainLightningModule(self.model_, self.symmetric), 90 | self._init_data_loader(sequences), 91 | ) 92 | return self 93 | 94 | def sample(self, num_sequences: int, sequence_length: int) -> torch.Tensor: 95 | """ 96 | Samples state sequences from the fitted Markov chain. 97 | 98 | Args: 99 | num_sequences: The number of sequences to sample. 100 | sequence_length: The length of the sequences to sample. 101 | 102 | Returns: 103 | The sampled sequences as a tensor of shape ``[num_sequences, sequence_length]``. 104 | 105 | Note: 106 | This method does not parallelize across multiple processes, i.e. performs no 107 | synchronization. 108 | """ 109 | return self.model_.sample(num_sequences, sequence_length) 110 | 111 | def score(self, sequences: SequenceData) -> float: 112 | """ 113 | Computes the average negative log-likelihood (NLL) of observing the provided sequences. If 114 | you want to have NLLs for each individual sequence, use :meth:`score_samples` instead. 115 | 116 | Args: 117 | sequences: The sequences for which to compute the average log-probability. 118 | 119 | Returns: 120 | The average NLL for all sequences. 121 | 122 | Note: 123 | See :meth:`score_samples` to obtain the NLL values for individual sequences. 124 | """ 125 | result = self.trainer().test( 126 | MarkovChainLightningModule(self.model_), 127 | self._init_data_loader(sequences), 128 | verbose=False, 129 | ) 130 | return result[0]["nll"] 131 | 132 | def score_samples(self, sequences: SequenceData) -> torch.Tensor: 133 | """ 134 | Computes the average negative log-likelihood (NLL) of observing the provided sequences. 135 | 136 | Args: 137 | sequences: The sequences for which to compute the NLL. 138 | 139 | Returns: 140 | A tensor of shape ``[num_sequences]`` with the NLLs for each individual sequence. 141 | 142 | Attention: 143 | When calling this function in a multi-process environment, each process receives only 144 | a subset of the predictions. If you want to aggregate predictions, make sure to gather 145 | the values returned from this method. 
146 | """ 147 | result = self.trainer().predict( 148 | MarkovChainLightningModule(self.model_), 149 | self._init_data_loader(sequences), 150 | return_predictions=True, 151 | ) 152 | return torch.stack(cast(List[torch.Tensor], result)) 153 | 154 | def _init_data_loader(self, sequences: SequenceData) -> DataLoader[PackedSequence]: 155 | if isinstance(sequences, Dataset): 156 | return DataLoader( 157 | sequences, 158 | batch_size=self.batch_size or len(sequences), # type: ignore 159 | collate_fn=collate_sequences, # type: ignore 160 | ) 161 | 162 | return DataLoader( # type: ignore 163 | dataset_from_tensors(sequences), 164 | batch_size=self.batch_size or len(sequences), 165 | collate_fn=collate_sequences_same_length, 166 | ) 167 | 168 | 169 | def _get_num_states(data: SequenceData) -> int: 170 | if isinstance(data, np.ndarray): 171 | assert data.dtype == np.int64, "array states must have type `np.int64`" 172 | return int(data.max() + 1) 173 | if isinstance(data, torch.Tensor): 174 | assert data.dtype == torch.long, "tensor states must have type `torch.long`" 175 | return int(data.max().item() + 1) 176 | return max(_get_num_states(entry) for entry in data) 177 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom ignore 2 | docs/sites/generated/ 3 | lightning_logs/ 4 | 5 | # Created by https://www.toptal.com/developers/gitignore/api/macos,visualstudiocode,pycharm+all,python 6 | # Edit at https://www.toptal.com/developers/gitignore?templates=macos,visualstudiocode,pycharm+all,python 7 | 8 | ### macOS ### 9 | # General 10 | .DS_Store 11 | .AppleDouble 12 | .LSOverride 13 | 14 | # Icon must end with two \r 15 | Icon 16 | 17 | 18 | # Thumbnails 19 | ._* 20 | 21 | # Files that might appear in the root of a volume 22 | .DocumentRevisions-V100 23 | .fseventsd 24 | .Spotlight-V100 25 | .TemporaryItems 26 | .Trashes 27 | .VolumeIcon.icns 28 | .com.apple.timemachine.donotpresent 29 | 30 | # Directories potentially created on remote AFP share 31 | .AppleDB 32 | .AppleDesktop 33 | Network Trash Folder 34 | Temporary Items 35 | .apdisk 36 | 37 | ### macOS Patch ### 38 | # iCloud generated files 39 | *.icloud 40 | 41 | ### PyCharm+all ### 42 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 43 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 44 | 45 | # User-specific stuff 46 | .idea/**/workspace.xml 47 | .idea/**/tasks.xml 48 | .idea/**/usage.statistics.xml 49 | .idea/**/dictionaries 50 | .idea/**/shelf 51 | 52 | # AWS User-specific 53 | .idea/**/aws.xml 54 | 55 | # Generated files 56 | .idea/**/contentModel.xml 57 | 58 | # Sensitive or high-churn files 59 | .idea/**/dataSources/ 60 | .idea/**/dataSources.ids 61 | .idea/**/dataSources.local.xml 62 | .idea/**/sqlDataSources.xml 63 | .idea/**/dynamic.xml 64 | .idea/**/uiDesigner.xml 65 | .idea/**/dbnavigator.xml 66 | 67 | # Gradle 68 | .idea/**/gradle.xml 69 | .idea/**/libraries 70 | 71 | # Gradle and Maven with auto-import 72 | # When using Gradle or Maven with auto-import, you should exclude module files, 73 | # since they will be recreated, and may cause churn. Uncomment if using 74 | # auto-import. 
75 | # .idea/artifacts 76 | # .idea/compiler.xml 77 | # .idea/jarRepositories.xml 78 | # .idea/modules.xml 79 | # .idea/*.iml 80 | # .idea/modules 81 | # *.iml 82 | # *.ipr 83 | 84 | # CMake 85 | cmake-build-*/ 86 | 87 | # Mongo Explorer plugin 88 | .idea/**/mongoSettings.xml 89 | 90 | # File-based project format 91 | *.iws 92 | 93 | # IntelliJ 94 | out/ 95 | 96 | # mpeltonen/sbt-idea plugin 97 | .idea_modules/ 98 | 99 | # JIRA plugin 100 | atlassian-ide-plugin.xml 101 | 102 | # Cursive Clojure plugin 103 | .idea/replstate.xml 104 | 105 | # SonarLint plugin 106 | .idea/sonarlint/ 107 | 108 | # Crashlytics plugin (for Android Studio and IntelliJ) 109 | com_crashlytics_export_strings.xml 110 | crashlytics.properties 111 | crashlytics-build.properties 112 | fabric.properties 113 | 114 | # Editor-based Rest Client 115 | .idea/httpRequests 116 | 117 | # Android studio 3.1+ serialized cache file 118 | .idea/caches/build_file_checksums.ser 119 | 120 | ### PyCharm+all Patch ### 121 | # Ignore everything but code style settings and run configurations 122 | # that are supposed to be shared within teams. 123 | 124 | .idea/* 125 | 126 | !.idea/codeStyles 127 | !.idea/runConfigurations 128 | 129 | ### Python ### 130 | # Byte-compiled / optimized / DLL files 131 | __pycache__/ 132 | *.py[cod] 133 | *$py.class 134 | 135 | # C extensions 136 | *.so 137 | 138 | # Distribution / packaging 139 | .Python 140 | build/ 141 | develop-eggs/ 142 | dist/ 143 | downloads/ 144 | eggs/ 145 | .eggs/ 146 | lib/ 147 | lib64/ 148 | parts/ 149 | sdist/ 150 | var/ 151 | wheels/ 152 | share/python-wheels/ 153 | *.egg-info/ 154 | .installed.cfg 155 | *.egg 156 | MANIFEST 157 | 158 | # PyInstaller 159 | # Usually these files are written by a python script from a template 160 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 161 | *.manifest 162 | *.spec 163 | 164 | # Installer logs 165 | pip-log.txt 166 | pip-delete-this-directory.txt 167 | 168 | # Unit test / coverage reports 169 | htmlcov/ 170 | .tox/ 171 | .nox/ 172 | .coverage 173 | .coverage.* 174 | .cache 175 | nosetests.xml 176 | coverage.xml 177 | *.cover 178 | *.py,cover 179 | .hypothesis/ 180 | .pytest_cache/ 181 | cover/ 182 | 183 | # Translations 184 | *.mo 185 | *.pot 186 | 187 | # Django stuff: 188 | *.log 189 | local_settings.py 190 | db.sqlite3 191 | db.sqlite3-journal 192 | 193 | # Flask stuff: 194 | instance/ 195 | .webassets-cache 196 | 197 | # Scrapy stuff: 198 | .scrapy 199 | 200 | # Sphinx documentation 201 | docs/_build/ 202 | 203 | # PyBuilder 204 | .pybuilder/ 205 | target/ 206 | 207 | # Jupyter Notebook 208 | .ipynb_checkpoints 209 | 210 | # IPython 211 | profile_default/ 212 | ipython_config.py 213 | 214 | # pyenv 215 | # For a library or package, you might want to ignore these files since the code is 216 | # intended to run in multiple environments; otherwise, check them in: 217 | .python-version 218 | 219 | # pipenv 220 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 221 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 222 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 223 | # install all needed dependencies. 224 | #Pipfile.lock 225 | 226 | # poetry 227 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
228 | # This is especially recommended for binary packages to ensure reproducibility, and is more 229 | # commonly ignored for libraries. 230 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 231 | #poetry.lock 232 | 233 | # pdm 234 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 235 | #pdm.lock 236 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 237 | # in version control. 238 | # https://pdm.fming.dev/#use-with-ide 239 | .pdm.toml 240 | 241 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 242 | __pypackages__/ 243 | 244 | # Celery stuff 245 | celerybeat-schedule 246 | celerybeat.pid 247 | 248 | # SageMath parsed files 249 | *.sage.py 250 | 251 | # Environments 252 | .env 253 | .venv 254 | env/ 255 | venv/ 256 | ENV/ 257 | env.bak/ 258 | venv.bak/ 259 | 260 | # Spyder project settings 261 | .spyderproject 262 | .spyproject 263 | 264 | # Rope project settings 265 | .ropeproject 266 | 267 | # mkdocs documentation 268 | /site 269 | 270 | # mypy 271 | .mypy_cache/ 272 | .dmypy.json 273 | dmypy.json 274 | 275 | # Pyre type checker 276 | .pyre/ 277 | 278 | # pytype static type analyzer 279 | .pytype/ 280 | 281 | # Cython debug symbols 282 | cython_debug/ 283 | 284 | # PyCharm 285 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 286 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 287 | # and can be added to the global gitignore or merged into this file. For a more nuclear 288 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 289 | #.idea/ 290 | 291 | ### Python Patch ### 292 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 293 | poetry.toml 294 | 295 | 296 | ### VisualStudioCode ### 297 | .vscode/* 298 | !.vscode/settings.json 299 | !.vscode/tasks.json 300 | !.vscode/launch.json 301 | !.vscode/extensions.json 302 | !.vscode/*.code-snippets 303 | 304 | # Local History for Visual Studio Code 305 | .history/ 306 | 307 | # Built Visual Studio Code Extensions 308 | *.vsix 309 | 310 | ### VisualStudioCode Patch ### 311 | # Ignore all local history of files 312 | .history 313 | .ionide 314 | 315 | # End of https://www.toptal.com/developers/gitignore/api/macos,visualstudiocode,pycharm+all,python 316 | -------------------------------------------------------------------------------- /pycave/clustering/kmeans/metrics.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Any, Callable, Optional 3 | import torch 4 | from torchmetrics import Metric 5 | 6 | 7 | class CentroidAggregator(Metric): 8 | """ 9 | The centroid aggregator aggregates kmeans centroids over batches and processes. 
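    Example:
        A minimal sketch of the update/compute cycle with hand-picked toy tensors::

            import torch

            aggregator = CentroidAggregator(num_clusters=2, num_features=2)
            data = torch.tensor([[0.0, 0.0], [2.0, 2.0], [4.0, 4.0]])
            assignments = torch.tensor([0, 1, 1])
            aggregator.update(data, assignments)
            aggregator.compute()  # tensor([[0., 0.], [3., 3.]])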
10 | """ 11 | 12 | full_state_update = False 13 | 14 | def __init__( 15 | self, 16 | num_clusters: int, 17 | num_features: int, 18 | *, 19 | dist_sync_fn: Optional[Callable[[Any], Any]] = None, 20 | ): 21 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore 22 | 23 | self.num_clusters = num_clusters 24 | self.num_features = num_features 25 | 26 | self.centroids: torch.Tensor 27 | self.add_state("centroids", torch.zeros(num_clusters, num_features), dist_reduce_fx="sum") 28 | 29 | self.cluster_counts: torch.Tensor 30 | self.add_state("cluster_counts", torch.zeros(num_clusters), dist_reduce_fx="sum") 31 | 32 | def update(self, data: torch.Tensor, assignments: torch.Tensor) -> None: 33 | indices = assignments.unsqueeze(1).expand(-1, self.num_features) 34 | self.centroids.scatter_add_(0, indices, data) 35 | 36 | counts = assignments.bincount(minlength=self.num_clusters).float() 37 | self.cluster_counts.add_(counts) 38 | 39 | def compute(self) -> torch.Tensor: 40 | return self.centroids / self.cluster_counts.unsqueeze(-1) 41 | 42 | 43 | class UniformSampler(Metric): 44 | """ 45 | The uniform sampler randomly samples a specified number of datapoints uniformly from all 46 | datapoints. 47 | 48 | The idea is the following: sample the number of choices from each batch and track the number of 49 | datapoints that was already sampled from. When sampling from the union of existing choices and 50 | a new batch, more weight is put on the existing choices (according to the number of datapoints 51 | they were already sampled from). 52 | """ 53 | 54 | full_state_update = False 55 | 56 | def __init__( 57 | self, 58 | num_choices: int, 59 | num_features: int, 60 | *, 61 | dist_sync_fn: Optional[Callable[[Any], Any]] = None, 62 | ): 63 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore 64 | 65 | self.num_choices = num_choices 66 | 67 | self.choices: torch.Tensor 68 | self.add_state("choices", torch.empty(num_choices, num_features), dist_reduce_fx="cat") 69 | 70 | self.choice_weights: torch.Tensor 71 | self.add_state("choice_weights", torch.zeros(num_choices), dist_reduce_fx="cat") 72 | 73 | def update(self, data: torch.Tensor) -> None: 74 | if self.num_choices == 1: 75 | # If there is only one choice, the fastest thing is to use the `random` package. The 76 | # cumulative weight of the data is its size, the cumulative weight of the current 77 | # choice is some value. 
78 | cum_weight = data.size(0) + self.choice_weights.item() 79 | if random.random() * cum_weight < data.size(0): 80 | # Use some item from the data, else keep the current choice 81 | self.choices.copy_(data[random.randrange(data.size(0))]) 82 | else: 83 | # The choices are computed from scratch every time, weighting the current choices by 84 | # the cumulative weight put on them 85 | weights = torch.cat( 86 | [ 87 | torch.ones(data.size(0), device=data.device, dtype=data.dtype), 88 | self.choice_weights, 89 | ] 90 | ) 91 | pool = torch.cat([data, self.choices]) 92 | samples = weights.multinomial(self.num_choices) 93 | self.choices.copy_(pool[samples]) 94 | 95 | # The weights are the cumulative counts, divided by the number of choices 96 | self.choice_weights.add_(data.size(0) / self.num_choices) 97 | 98 | def compute(self) -> torch.Tensor: 99 | # In the ddp setting, there are "too many" choices, so we sample 100 | if self.choices.size(0) > self.num_choices: 101 | samples = self.choice_weights.multinomial(self.num_choices) 102 | return self.choices[samples] 103 | return self.choices 104 | 105 | 106 | class DistanceSampler(Metric): 107 | """ 108 | The distance sampler may be used for kmeans++ initialization, to iteratively select centroids 109 | according to their squared distances to existing choices. 110 | 111 | Computing the distance to existing choices is not part of this sampler. Within each "cycle", it 112 | computes a given number of candidates. Candidates are sampled independently and may be 113 | duplicates. 114 | """ 115 | 116 | full_state_update = False 117 | 118 | def __init__( 119 | self, 120 | num_choices: int, 121 | num_features: int, 122 | *, 123 | dist_sync_fn: Optional[Callable[[Any], Any]] = None, 124 | ): 125 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore 126 | 127 | self.num_choices = num_choices 128 | self.num_features = num_features 129 | 130 | self.choices: torch.Tensor 131 | self.add_state("choices", torch.empty(num_choices, num_features), dist_reduce_fx="cat") 132 | 133 | # Cumulative distance is the same for all choices 134 | self.cumulative_squared_distance: torch.Tensor 135 | self.add_state("cumulative_squared_distance", torch.zeros(1), dist_reduce_fx="cat") 136 | 137 | def update(self, data: torch.Tensor, shortest_distances: torch.Tensor) -> None: 138 | eps = torch.finfo(data.dtype).eps 139 | squared_distances = shortest_distances.square() 140 | 141 | # For all choices, check if we should use a sample from the data or the existing choice 142 | data_dist = squared_distances.sum() 143 | cum_dist = data_dist + eps + self.cumulative_squared_distance 144 | use_choice_from_data = ( 145 | torch.rand(self.num_choices, device=data.device, dtype=data.dtype) * cum_dist 146 | < data_dist + eps 147 | ) 148 | 149 | # Then, we sample from the data `num_choices` times and replace if needed 150 | choices = (squared_distances + eps).multinomial(self.num_choices, replacement=True) 151 | self.choices.masked_scatter_( 152 | use_choice_from_data.unsqueeze(1), data[choices[use_choice_from_data]] 153 | ) 154 | 155 | # In any case, the cumulative distances are updated 156 | self.cumulative_squared_distance.add_(data_dist) 157 | 158 | def compute(self) -> torch.Tensor: 159 | # Upon computation, we sample if there is more than one choice (ddp setting) 160 | if self.choices.size(0) > self.num_choices: 161 | # choices now have shape [num_choices, num_processes, num_features] 162 | choices = self.choices.reshape(-1, self.num_choices, self.num_features).transpose(0, 1) 163 | # 
For each choice, we sample across processes 164 | choice_indices = torch.arange(self.num_choices, device=self.choices.device) 165 | process_indices = self.cumulative_squared_distance.multinomial( 166 | self.num_choices, replacement=True 167 | ) 168 | return choices[choice_indices, process_indices] 169 | # Otherwise, we can return the choices 170 | return self.choices 171 | 172 | 173 | class BatchSummer(Metric): 174 | """ 175 | Sums the values for a batch of items independently. 176 | """ 177 | 178 | full_state_update = True 179 | 180 | def __init__(self, num_values: int, *, dist_sync_fn: Optional[Callable[[Any], Any]] = None): 181 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore 182 | 183 | self.sums: torch.Tensor 184 | self.add_state("sums", torch.zeros(num_values), dist_reduce_fx="sum") 185 | 186 | def update(self, values: torch.Tensor) -> None: 187 | self.sums.add_(values.sum(0)) 188 | 189 | def compute(self) -> torch.Tensor: 190 | return self.sums 191 | 192 | 193 | class BatchAverager(Metric): 194 | """ 195 | Averages the values for a batch of items independently. 196 | """ 197 | 198 | full_state_update = False 199 | 200 | def __init__( 201 | self, 202 | num_values: int, 203 | for_variance: bool, 204 | *, 205 | dist_sync_fn: Optional[Callable[[Any], Any]] = None, 206 | ): 207 | super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore 208 | 209 | self.for_variance = for_variance 210 | 211 | self.sums: torch.Tensor 212 | self.add_state("sums", torch.zeros(num_values), dist_reduce_fx="sum") 213 | 214 | self.counts: torch.Tensor 215 | self.add_state("counts", torch.zeros(num_values), dist_reduce_fx="sum") 216 | 217 | def update(self, values: torch.Tensor) -> None: 218 | self.sums.add_(values.sum(0)) 219 | self.counts.add_(values.size(0)) 220 | 221 | def compute(self) -> torch.Tensor: 222 | return self.sums / (self.counts - 1 if self.for_variance else self.counts) 223 | -------------------------------------------------------------------------------- /pycave/clustering/kmeans/estimator.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import logging 3 | from typing import Any, cast, List 4 | import torch 5 | from lightkit import ConfigurableBaseEstimator 6 | from lightkit.data import collate_tensor, DataLoader, dataset_from_tensors, TensorLike 7 | from lightkit.estimator import PredictorMixin, TransformerMixin 8 | from .lightning_module import ( 9 | FeatureVarianceLightningModule, 10 | KMeansLightningModule, 11 | KmeansPlusPlusInitLightningModule, 12 | KmeansRandomInitLightningModule, 13 | ) 14 | from .model import KMeansModel, KMeansModelConfig 15 | from .types import KMeansInitStrategy 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class KMeans( 21 | ConfigurableBaseEstimator[KMeansModel], # type: ignore 22 | TransformerMixin[TensorLike, torch.Tensor], 23 | PredictorMixin[TensorLike, torch.Tensor], 24 | ): 25 | """ 26 | Model for clustering data into a predefined number of clusters. More information on K-means 27 | clustering is available on `Wikipedia `_. 28 | 29 | See also: 30 | .. currentmodule:: pycave.clustering.kmeans 31 | .. autosummary:: 32 | :nosignatures: 33 | :template: classes/pytorch_module.rst 34 | 35 | KMeansModel 36 | KMeansModelConfig 37 | """ 38 | 39 | #: The fitted PyTorch module with all estimated parameters. 40 | model_: KMeansModel 41 | #: A boolean indicating whether the model converged during training. 
42 | converged_: bool 43 | #: The number of iterations the model was fitted for, excluding initialization. 44 | num_iter_: int 45 | #: The mean squared distance of all datapoints to their closest cluster centers. 46 | inertia_: float 47 | 48 | def __init__( 49 | self, 50 | num_clusters: int = 1, 51 | *, 52 | init_strategy: KMeansInitStrategy = "kmeans++", 53 | convergence_tolerance: float = 1e-4, 54 | batch_size: int | None = None, 55 | trainer_params: dict[str, Any] | None = None, 56 | ): 57 | """ 58 | Args: 59 | num_clusters: The number of clusters. 60 | init_strategy: The strategy for initializing centroids. 61 | convergence_tolerance: Training is conducted until the Frobenius norm of the change 62 | between cluster centroids falls below this threshold. The tolerance is multiplied 63 | by the average variance of the features. 64 | batch_size: The batch size to use when fitting the model. If not provided, the full 65 | data will be used as a single batch. Set this if the full data does not fit into 66 | memory. 67 | trainer_params: Initialization parameters to use when initializing a PyTorch Lightning 68 | trainer. By default, it disables various stdout logs unless PyCave is configured to 69 | do verbose logging. Checkpointing and logging are disabled regardless of the log 70 | level. This estimator further sets the following overridable defaults: 71 | 72 | - ``max_epochs=300`` 73 | 74 | Note: 75 | The number of epochs passed to the initializer only define the number of optimization 76 | epochs. Prior to that, initialization is run which may perform additional iterations 77 | through the data. 78 | """ 79 | super().__init__( 80 | default_params=dict(max_epochs=300), 81 | user_params=trainer_params, 82 | ) 83 | 84 | # Assign other properties 85 | self.batch_size = batch_size 86 | self.num_clusters = num_clusters 87 | self.init_strategy = init_strategy 88 | self.convergence_tolerance = convergence_tolerance 89 | 90 | def fit(self, data: TensorLike) -> KMeans: 91 | """ 92 | Fits the KMeans model on the provided data by running Lloyd's algorithm. 93 | 94 | Args: 95 | data: The tabular data to fit on. The dimensionality of the KMeans model is 96 | automatically inferred from this data. 97 | 98 | Returns: 99 | The fitted KMeans model. 100 | """ 101 | # Initialize model 102 | num_features = len(data[0]) 103 | config = KMeansModelConfig( 104 | num_clusters=self.num_clusters, 105 | num_features=num_features, 106 | ) 107 | self.model_ = KMeansModel(config) 108 | 109 | # Setup the data loading 110 | loader = DataLoader( 111 | dataset_from_tensors(data), 112 | batch_size=self.batch_size or len(data), 113 | collate_fn=collate_tensor, 114 | ) 115 | is_batch_training = self._num_batches_per_epoch(loader) > 1 116 | 117 | # First, initialize the centroids 118 | if self.init_strategy == "random": 119 | module = KmeansRandomInitLightningModule(self.model_) 120 | num_epochs = 1 121 | else: 122 | module = KmeansPlusPlusInitLightningModule( 123 | self.model_, 124 | is_batch_training=is_batch_training, 125 | ) 126 | num_epochs = 2 * config.num_clusters - 1 127 | 128 | logger.info("Running initialization...") 129 | self.trainer(max_epochs=num_epochs).fit(module, loader) 130 | 131 | # Then, in order to find the right convergence tolerance, we need to compute the variance 132 | # of the data. 
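        # (The user-supplied tolerance is relative: multiplying it by the mean feature variance
        # below makes the convergence check independent of the scale of the data.)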
133 | if self.convergence_tolerance != 0: 134 | variances = torch.empty(config.num_features) 135 | module = FeatureVarianceLightningModule(variances) 136 | self.trainer().fit(module, loader) 137 | 138 | tolerance_multiplier = cast(float, variances.mean().item()) 139 | convergence_tolerance = self.convergence_tolerance * tolerance_multiplier 140 | else: 141 | convergence_tolerance = 0 142 | 143 | # Then, we can fit the actual model. We need a new trainer for that 144 | logger.info("Fitting K-Means...") 145 | trainer = self.trainer() 146 | module = KMeansLightningModule( 147 | self.model_, 148 | convergence_tolerance=convergence_tolerance, 149 | ) 150 | trainer.fit(module, loader) 151 | 152 | # Assign convergence properties 153 | self.num_iter_ = module.current_epoch 154 | self.converged_ = module.current_epoch < trainer.max_epochs 155 | if "inertia" in trainer.callback_metrics: 156 | self.inertia_ = cast(float, trainer.callback_metrics["inertia"].item()) 157 | return self 158 | 159 | def predict(self, data: TensorLike) -> torch.Tensor: 160 | """ 161 | Predicts the closest cluster for each item provided. 162 | 163 | Args: 164 | data: The datapoints for which to predict the clusters. 165 | 166 | Returns: 167 | Tensor of shape ``[num_datapoints]`` with the index of the closest cluster for each 168 | datapoint. 169 | 170 | Attention: 171 | When calling this function in a multi-process environment, each process receives only 172 | a subset of the predictions. If you want to aggregate predictions, make sure to gather 173 | the values returned from this method. 174 | """ 175 | loader = DataLoader( 176 | dataset_from_tensors(data), 177 | batch_size=self.batch_size or len(data), 178 | collate_fn=collate_tensor, 179 | ) 180 | result = self.trainer().predict( 181 | KMeansLightningModule(self.model_, predict_target="assignments"), loader 182 | ) 183 | return torch.cat(cast(List[torch.Tensor], result)) 184 | 185 | def score(self, data: TensorLike) -> float: 186 | """ 187 | Computes the average inertia of all the provided datapoints. That is, it computes the mean 188 | squared distance to each datapoint's closest centroid. 189 | 190 | Args: 191 | data: The data for which to compute the average inertia. 192 | 193 | Returns: 194 | The average inertia. 195 | 196 | Note: 197 | See :meth:`score_samples` to obtain the inertia for individual sequences. 198 | """ 199 | loader = DataLoader( 200 | dataset_from_tensors(data), 201 | batch_size=self.batch_size or len(data), 202 | collate_fn=collate_tensor, 203 | ) 204 | result = self.trainer().test(KMeansLightningModule(self.model_), loader, verbose=False) 205 | return result[0]["inertia"] 206 | 207 | def score_samples(self, data: TensorLike) -> torch.Tensor: 208 | """ 209 | Computes the inertia for each of the the provided datapoints. That is, it computes the mean 210 | squared distance of each datapoint to its closest centroid. 211 | 212 | Args: 213 | data: The data for which to compute the inertia values. 214 | 215 | Returns: 216 | A tensor of shape ``[num_datapoints]`` with the inertia of each datapoint. 217 | 218 | Attention: 219 | When calling this function in a multi-process environment, each process receives only 220 | a subset of the predictions. If you want to aggregate predictions, make sure to gather 221 | the values returned from this method. 
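        Example:
            A minimal sketch with illustrative toy data::

                import torch
                from pycave.clustering import KMeans

                data = torch.randn(100, 3)
                estimator = KMeans(num_clusters=4).fit(data)
                inertias = estimator.score_samples(data)  # shape [100]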
222 | """ 223 | loader = DataLoader( 224 | dataset_from_tensors(data), 225 | batch_size=self.batch_size or len(data), 226 | collate_fn=collate_tensor, 227 | ) 228 | result = self.trainer().predict( 229 | KMeansLightningModule(self.model_, predict_target="inertias"), loader 230 | ) 231 | return torch.cat(cast(List[torch.Tensor], result)) 232 | 233 | def transform(self, data: TensorLike) -> torch.Tensor: 234 | """ 235 | Transforms the provided data into the cluster-distance space. That is, it returns the 236 | distance of each datapoint to each cluster centroid. 237 | 238 | Args: 239 | data: The data to transform. 240 | 241 | Returns: 242 | A tensor of shape ``[num_datapoints, num_clusters]`` with the distances to the cluster 243 | centroids. 244 | 245 | Attention: 246 | When calling this function in a multi-process environment, each process receives only 247 | a subset of the predictions. If you want to aggregate predictions, make sure to gather 248 | the values returned from this method. 249 | """ 250 | loader = DataLoader( 251 | dataset_from_tensors(data), 252 | batch_size=self.batch_size or len(data), 253 | collate_fn=collate_tensor, 254 | ) 255 | result = self.trainer().predict( 256 | KMeansLightningModule(self.model_, predict_target="distances"), loader 257 | ) 258 | return torch.cat(cast(List[torch.Tensor], result)) 259 | -------------------------------------------------------------------------------- /tests/bayes/core/test_normal.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | import pytest 3 | import torch 4 | from sklearn.mixture._gaussian_mixture import _compute_log_det_cholesky # type: ignore 5 | from sklearn.mixture._gaussian_mixture import _compute_precision_cholesky # type: ignore 6 | from torch.distributions import MultivariateNormal 7 | from pycave.bayes.core import cholesky_precision, covariance, log_normal, sample_normal 8 | from pycave.bayes.core._jit import _cholesky_logdet # type: ignore 9 | from tests._data.normal import ( 10 | sample_data, 11 | sample_diag_covars, 12 | sample_full_covars, 13 | sample_means, 14 | sample_spherical_covars, 15 | ) 16 | 17 | # ------------------------------------------------------------------------------------------------- 18 | # CHOLESKY PRECISIONS 19 | # ------------------------------------------------------------------------------------------------- 20 | 21 | 22 | @pytest.mark.parametrize("covars", sample_spherical_covars([70, 5, 200])) 23 | def test_cholesky_precision_spherical(covars: torch.Tensor): 24 | expected = _compute_precision_cholesky(covars.numpy(), "spherical") # type: ignore 25 | actual = cholesky_precision(covars, "spherical") 26 | assert torch.allclose( 27 | torch.as_tensor(expected, dtype=torch.float), actual, rtol=1e-4, atol=1e-4 28 | ) 29 | 30 | 31 | @pytest.mark.parametrize("covars", sample_diag_covars([70, 5, 200], [3, 50, 100])) 32 | def test_cholesky_precision_diag(covars: torch.Tensor): 33 | expected = _compute_precision_cholesky(covars.numpy(), "diag") # type: ignore 34 | actual = cholesky_precision(covars, "diag") 35 | assert torch.allclose( 36 | torch.as_tensor(expected, dtype=torch.float), actual, rtol=1e-4, atol=1e-4 37 | ) 38 | 39 | 40 | @pytest.mark.parametrize("covars", sample_full_covars([70, 5, 200], [3, 50, 100])) 41 | def test_cholesky_precision_full(covars: torch.Tensor): 42 | expected = _compute_precision_cholesky(covars.numpy(), "full") # type: ignore 43 | actual = cholesky_precision(covars, "full") 44 | assert 
torch.allclose( 45 | torch.as_tensor(expected, dtype=torch.float), actual, rtol=1e-4, atol=1e-4 46 | ) 47 | 48 | 49 | @pytest.mark.parametrize("covars", sample_full_covars([1, 1, 1], [3, 50, 100])) 50 | def test_cholesky_precision_tied(covars: torch.Tensor): 51 | expected = _compute_precision_cholesky(covars.numpy(), "tied") # type: ignore 52 | actual = cholesky_precision(covars, "tied") 53 | assert torch.allclose( 54 | torch.as_tensor(expected, dtype=torch.float), actual, rtol=1e-4, atol=1e-4 55 | ) 56 | 57 | 58 | # ------------------------------------------------------------------------------------------------- 59 | # COVARIANCES 60 | # ------------------------------------------------------------------------------------------------- 61 | 62 | 63 | @pytest.mark.parametrize("covars", sample_spherical_covars([70, 5, 200])) 64 | def test_covariances_spherical(covars: torch.Tensor): 65 | precision_cholesky = _compute_precision_cholesky(covars.numpy(), "spherical") # type: ignore 66 | actual = covariance(torch.as_tensor(precision_cholesky, dtype=torch.float), "spherical") 67 | assert torch.allclose(covars, actual) 68 | 69 | 70 | @pytest.mark.parametrize("covars", sample_diag_covars([70, 5, 200], [3, 50, 100])) 71 | def test_covariances_diag(covars: torch.Tensor): 72 | precision_cholesky = _compute_precision_cholesky(covars.numpy(), "diag") # type: ignore 73 | actual = covariance(torch.as_tensor(precision_cholesky, dtype=torch.float), "diag") 74 | assert torch.allclose(covars, actual) 75 | 76 | 77 | @pytest.mark.parametrize("covars", sample_full_covars([70, 5, 200], [3, 50, 100])) 78 | def test_covariances_full(covars: torch.Tensor): 79 | precision_cholesky = _compute_precision_cholesky(covars.numpy(), "full") # type: ignore 80 | actual = covariance(torch.as_tensor(precision_cholesky, dtype=torch.double), "full") 81 | assert torch.allclose(covars, covars.transpose(1, 2)) 82 | assert torch.allclose(covars.to(torch.double), actual) 83 | 84 | 85 | @pytest.mark.parametrize("covars", sample_full_covars([1, 1, 1], [3, 50, 100])) 86 | def test_covariances_tied(covars: torch.Tensor): 87 | precision_cholesky = _compute_precision_cholesky(covars.numpy(), "tied") # type: ignore 88 | actual = covariance(torch.as_tensor(precision_cholesky, dtype=torch.double), "tied") 89 | assert torch.allclose(covars, covars.T) 90 | assert torch.allclose(covars.to(torch.double), actual) 91 | 92 | 93 | # ------------------------------------------------------------------------------------------------- 94 | # CHOLESKY LOG DETERMINANTS 95 | # ------------------------------------------------------------------------------------------------- 96 | 97 | 98 | @pytest.mark.parametrize("covars", sample_spherical_covars([70, 5, 200])) 99 | def test_cholesky_logdet_spherical(covars: torch.Tensor): 100 | expected = _compute_log_det_cholesky( # type: ignore 101 | _compute_precision_cholesky(covars.numpy(), "spherical"), "spherical", 100 # type: ignore 102 | ) 103 | actual = _cholesky_logdet( # type: ignore 104 | 100, 105 | cholesky_precision(covars, "spherical"), 106 | "spherical", 107 | ) 108 | assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual) 109 | 110 | 111 | @pytest.mark.parametrize("covars", sample_diag_covars([70, 5, 200], [3, 50, 100])) 112 | def test_cholesky_logdet_diag(covars: torch.Tensor): 113 | expected = _compute_log_det_cholesky( # type: ignore 114 | _compute_precision_cholesky(covars.numpy(), "diag"), # type: ignore 115 | "diag", 116 | covars.size(1), 117 | ) 118 | actual = _cholesky_logdet( # 
type: ignore 119 | covars.size(1), 120 | cholesky_precision(covars, "diag"), 121 | "diag", 122 | ) 123 | assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual) 124 | 125 | 126 | @pytest.mark.parametrize("covars", sample_full_covars([70, 5, 200], [3, 50, 100])) 127 | def test_cholesky_logdet_full(covars: torch.Tensor): 128 | expected = _compute_log_det_cholesky( # type: ignore 129 | _compute_precision_cholesky(covars.numpy(), "full"), # type: ignore 130 | "full", 131 | covars.size(1), 132 | ) 133 | actual = _cholesky_logdet( # type: ignore 134 | covars.size(1), 135 | cholesky_precision(covars, "full"), 136 | "full", 137 | ) 138 | assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual) 139 | 140 | 141 | @pytest.mark.parametrize("covars", sample_full_covars([1, 1, 1], [3, 50, 100])) 142 | def test_cholesky_logdet_tied(covars: torch.Tensor): 143 | expected = _compute_log_det_cholesky( # type: ignore 144 | _compute_precision_cholesky(covars.numpy(), "tied"), # type: ignore 145 | "tied", 146 | covars.size(0), 147 | ) 148 | actual = _cholesky_logdet( # type: ignore 149 | covars.size(0), 150 | cholesky_precision(covars, "tied"), 151 | "tied", 152 | ) 153 | assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual) 154 | 155 | 156 | # ------------------------------------------------------------------------------------------------- 157 | # LOG NORMAL 158 | # ------------------------------------------------------------------------------------------------- 159 | 160 | 161 | @pytest.mark.parametrize( 162 | "x, means, covars", 163 | zip( 164 | sample_data([10, 50, 100], [3, 50, 100]), 165 | sample_means([70, 5, 200], [3, 50, 100]), 166 | sample_spherical_covars([70, 5, 200]), 167 | ), 168 | ) 169 | def test_log_normal_spherical(x: torch.Tensor, means: torch.Tensor, covars: torch.Tensor): 170 | covar_matrices = torch.stack([torch.eye(means.size(-1)) * c for c in covars]) 171 | precisions_cholesky = cholesky_precision(covars, "spherical") 172 | actual = log_normal(x, means, precisions_cholesky, covariance_type="spherical") 173 | _assert_log_prob(actual, x, means, covar_matrices) 174 | 175 | 176 | @pytest.mark.parametrize( 177 | "x, means, covars", 178 | zip( 179 | sample_data([10, 50, 100], [3, 50, 100]), 180 | sample_means([70, 5, 200], [3, 50, 100]), 181 | sample_diag_covars([70, 5, 200], [3, 50, 100]), 182 | ), 183 | ) 184 | def test_log_normal_diag(x: torch.Tensor, means: torch.Tensor, covars: torch.Tensor): 185 | covar_matrices = torch.stack([torch.diag(c) for c in covars]) 186 | precisions_cholesky = cholesky_precision(covars, "diag") 187 | actual = log_normal(x, means, precisions_cholesky, covariance_type="diag") 188 | _assert_log_prob(actual, x, means, covar_matrices) 189 | 190 | 191 | @pytest.mark.parametrize( 192 | "x, means, covars", 193 | zip( 194 | sample_data([10, 50, 100], [3, 50, 100]), 195 | sample_means([70, 5, 200], [3, 50, 100]), 196 | sample_full_covars([70, 5, 200], [3, 50, 100]), 197 | ), 198 | ) 199 | def test_log_normal_full(x: torch.Tensor, means: torch.Tensor, covars: torch.Tensor): 200 | precisions_cholesky = cholesky_precision(covars, "full") 201 | actual = log_normal(x, means, precisions_cholesky, covariance_type="full") 202 | _assert_log_prob(actual.float(), x, means, covars) 203 | 204 | 205 | @pytest.mark.parametrize( 206 | "x, means, covars", 207 | zip( 208 | sample_data([10, 50, 100], [3, 50, 100]), 209 | sample_means([70, 5, 200], [3, 50, 100]), 210 | sample_full_covars([1, 1, 1], [3, 50, 100]), 211 | ), 212 | ) 
213 | def test_log_normal_tied(x: torch.Tensor, means: torch.Tensor, covars: torch.Tensor): 214 | precisions_cholesky = cholesky_precision(covars, "tied") 215 | actual = log_normal(x, means, precisions_cholesky, covariance_type="tied") 216 | _assert_log_prob(actual, x, means, covars) 217 | 218 | 219 | # ------------------------------------------------------------------------------------------------- 220 | # SAMPLING 221 | # ------------------------------------------------------------------------------------------------- 222 | 223 | 224 | @pytest.mark.flaky(max_runs=3, min_passes=1) 225 | def test_sample_normal_spherical(): 226 | mean = torch.tensor([1.5, 3.5]) 227 | covar = torch.tensor(4.0) 228 | target_covar = torch.tensor([[4.0, 0.0], [0.0, 4.0]]) 229 | 230 | n = 1_000_000 231 | precisions = cholesky_precision(covar, "spherical") 232 | samples = sample_normal(n, mean, precisions, "spherical") 233 | 234 | sample_mean = samples.mean(0) 235 | sample_covar = (samples - sample_mean).t().matmul(samples - sample_mean) / n 236 | 237 | assert torch.allclose(mean, sample_mean, atol=1e-2) 238 | assert torch.allclose(target_covar, sample_covar, atol=1e-2) 239 | 240 | 241 | @pytest.mark.flaky(max_runs=3, min_passes=1) 242 | def test_sample_normal_diag(): 243 | mean = torch.tensor([1.5, 3.5]) 244 | covar = torch.tensor([0.5, 4.5]) 245 | target_covar = torch.tensor([[0.5, 0.0], [0.0, 4.5]]) 246 | 247 | n = 1_000_000 248 | precisions = cholesky_precision(covar, "diag") 249 | samples = sample_normal(n, mean, precisions, "diag") 250 | 251 | sample_mean = samples.mean(0) 252 | sample_covar = (samples - sample_mean).t().matmul(samples - sample_mean) / n 253 | 254 | assert torch.allclose(mean, sample_mean, atol=1e-2) 255 | assert torch.allclose(target_covar, sample_covar, atol=1e-2) 256 | 257 | 258 | @pytest.mark.flaky(max_runs=3, min_passes=1) 259 | def test_sample_normal_full(): 260 | mean = torch.tensor([1.5, 3.5]) 261 | covar = torch.tensor([[4.0, 2.5], [2.5, 2.0]]) 262 | 263 | n = 1_000_000 264 | precisions = cholesky_precision(covar, "tied") 265 | samples = sample_normal(n, mean, precisions, "full") 266 | 267 | sample_mean = samples.mean(0) 268 | sample_covar = (samples - sample_mean).t().matmul(samples - sample_mean) / n 269 | 270 | assert torch.allclose(mean, sample_mean, atol=1e-2) 271 | assert torch.allclose(covar, sample_covar, atol=1e-2) 272 | 273 | 274 | # ------------------------------------------------------------------------------------------------- 275 | 276 | 277 | def _assert_log_prob( 278 | actual: torch.Tensor, x: torch.Tensor, means: torch.Tensor, covars: torch.Tensor 279 | ) -> None: 280 | distribution = MultivariateNormal(means, covars) 281 | expected = distribution.log_prob(x.unsqueeze(1)) 282 | assert torch.allclose(actual, expected, rtol=1e-3) 283 | -------------------------------------------------------------------------------- /pycave/clustering/kmeans/lightning_module.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=abstract-method 2 | import math 3 | from typing import List, Literal 4 | import pytorch_lightning as pl 5 | import torch 6 | from pytorch_lightning.callbacks import EarlyStopping 7 | from torchmetrics import MeanMetric 8 | from pycave.utils import NonparametricLightningModule 9 | from .metrics import ( 10 | BatchAverager, 11 | BatchSummer, 12 | CentroidAggregator, 13 | DistanceSampler, 14 | UniformSampler, 15 | ) 16 | from .model import KMeansModel 17 | 18 | # 
------------------------------------------------------------------------------------------------- 19 | # TRAINING 20 | 21 | 22 | class KMeansLightningModule(NonparametricLightningModule): 23 | """ 24 | Lightning module for training and evaluating a K-Means model. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | model: KMeansModel, 30 | convergence_tolerance: float = 1e-4, 31 | predict_target: Literal["assignments", "distances", "inertias"] = "assignments", 32 | ): 33 | """ 34 | Args: 35 | model: The model to train. 36 | convergence_tolerance: Training is conducted until the Frobenius norm of the change 37 | between cluster centroids falls below this threshold. 38 | predict_target: Whether to predict cluster assigments or distances to clusters. 39 | """ 40 | super().__init__() 41 | 42 | self.model = model 43 | self.convergence_tolerance = convergence_tolerance 44 | self.predict_target = predict_target 45 | 46 | # Initialize aggregators 47 | self.centroid_aggregator = CentroidAggregator( 48 | num_clusters=self.model.config.num_clusters, 49 | num_features=self.model.config.num_features, 50 | dist_sync_fn=self.all_gather, 51 | ) 52 | 53 | # Initialize metrics 54 | self.metric_inertia = MeanMetric() 55 | 56 | def configure_callbacks(self) -> List[pl.Callback]: 57 | if self.convergence_tolerance == 0: 58 | return [] 59 | early_stopping = EarlyStopping( 60 | "frobenius_norm_change", 61 | patience=100000, 62 | stopping_threshold=self.convergence_tolerance, 63 | check_on_train_epoch_end=True, 64 | ) 65 | return [early_stopping] 66 | 67 | def on_train_epoch_start(self) -> None: 68 | self.centroid_aggregator.reset() 69 | 70 | def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None: 71 | # First, we compute the cluster assignments 72 | _, assignments, inertias = self.model.forward(batch) 73 | 74 | # Then, we update the centroids 75 | self.centroid_aggregator.update(batch, assignments) 76 | 77 | # And log the inertia 78 | self.metric_inertia.update(inertias) 79 | self.log("inertia", self.metric_inertia, on_step=False, on_epoch=True, prog_bar=True) 80 | 81 | def nonparametric_training_epoch_end(self) -> None: 82 | centroids = self.centroid_aggregator.compute() 83 | self.log("frobenius_norm_change", torch.linalg.norm(self.model.centroids - centroids)) 84 | self.model.centroids.copy_(centroids) 85 | 86 | def test_step(self, batch: torch.Tensor, _batch_idx: int) -> None: 87 | _, _, inertias = self.model.forward(batch) 88 | self.metric_inertia.update(inertias) 89 | self.log("inertia", self.metric_inertia) 90 | 91 | def predict_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: 92 | distances, assignments, inertias = self.model.forward(batch) 93 | if self.predict_target == "assignments": 94 | return assignments 95 | if self.predict_target == "inertias": 96 | return inertias 97 | return distances 98 | 99 | 100 | # ------------------------------------------------------------------------------------------------- 101 | # INIT STRATEGIES 102 | 103 | 104 | class KmeansRandomInitLightningModule(NonparametricLightningModule): 105 | """ 106 | Lightning module for initializing K-Means centroids randomly. 107 | 108 | Within the first epoch, all items are sampled. Thus, this module should only be trained for a 109 | single epoch. 110 | """ 111 | 112 | def __init__(self, model: KMeansModel): 113 | """ 114 | Args: 115 | model: The model to initialize. 
116 | """ 117 | super().__init__() 118 | 119 | self.model = model 120 | 121 | self.sampler = UniformSampler( 122 | num_choices=self.model.config.num_clusters, 123 | num_features=self.model.config.num_features, 124 | dist_sync_fn=self.all_gather_first, 125 | ) 126 | 127 | def on_train_epoch_start(self) -> None: 128 | self.sampler.reset() 129 | 130 | def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None: 131 | self.sampler.update(batch) 132 | 133 | def nonparametric_training_epoch_end(self) -> None: 134 | choices = self.sampler.compute() 135 | self.model.centroids.copy_(choices) 136 | 137 | 138 | class KmeansPlusPlusInitLightningModule(NonparametricLightningModule): 139 | """ 140 | Lightning module for K-Means++ initialization. It performs the following operations: 141 | 142 | - In the first epoch, a centroid is chosen at random. 143 | - In even epochs, candidates for the next centroid are sampled, based on the squared distance 144 | to their nearest cluster center. 145 | - In odd epochs, a candidate is selected deterministically as the next centroid. 146 | 147 | In total, initialization thus requires ``2 * k - 1`` epochs where ``k`` is the number of 148 | clusters. 149 | """ 150 | 151 | def __init__(self, model: KMeansModel, is_batch_training: bool): 152 | """ 153 | Args: 154 | model: The model to initialize. 155 | is_batch_training: Whether training is performed on mini-batches instead of the entire 156 | data at once. 157 | """ 158 | super().__init__() 159 | 160 | self.model = model 161 | self.is_batch_training = is_batch_training 162 | 163 | self.uniform_sampler = UniformSampler( 164 | num_choices=1, 165 | num_features=self.model.config.num_features, 166 | dist_sync_fn=self.all_gather_first, 167 | ) 168 | num_candidates = 2 + int(math.log(self.model.config.num_clusters)) 169 | self.distance_sampler = DistanceSampler( 170 | num_choices=num_candidates, 171 | num_features=self.model.config.num_features, 172 | dist_sync_fn=self.all_gather_first, 173 | ) 174 | self.candidate_inertia_summer = BatchSummer( 175 | num_candidates, 176 | dist_sync_fn=self.all_gather, 177 | ) 178 | 179 | # Some buffers required for running initialization 180 | self.centroid_candidates: torch.Tensor 181 | self.register_buffer( 182 | "centroid_candidates", 183 | torch.empty(num_candidates, self.model.config.num_features), 184 | persistent=False, 185 | ) 186 | 187 | if not self.is_batch_training: 188 | self.shortest_distance_cache: torch.Tensor 189 | self.register_buffer("shortest_distance_cache", torch.empty(1), persistent=False) 190 | 191 | def on_train_epoch_start(self) -> None: 192 | if self.current_epoch == 0: 193 | self.uniform_sampler.reset() 194 | elif self._is_current_epoch_sampling: 195 | self.distance_sampler.reset() 196 | else: 197 | self.candidate_inertia_summer.reset() 198 | 199 | def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None: 200 | if self.current_epoch == 0: 201 | self.uniform_sampler.update(batch) 202 | return 203 | # In all other epochs, we either sample a number of candidates from the remaining 204 | # datapoints or select a candidate deterministically. In any case, the shortest 205 | # distance is required. 206 | if self.current_epoch == 1: 207 | # In the first epoch, we can skip any argmin as the shortest distances are computed 208 | # with respect to the first centroid. 
209 | shortest_distances = torch.cdist(batch, self.model.centroids[:1]).squeeze(1) 210 | if not self.is_batch_training: 211 | self.shortest_distance_cache = shortest_distances 212 | elif self.is_batch_training: 213 | # For batch training, we always need to recompute all distances since we can't 214 | # cache them (this is the whole reason for batch training). 215 | distances = torch.cdist(batch, self.model.centroids[: self._init_epoch + 1]) 216 | shortest_distances = distances.gather( 217 | 1, distances.min(1, keepdim=True).indices # min is faster than argmin on CPU 218 | ).squeeze(1) 219 | else: 220 | # If we're not doing batch training, we only need to compute the distance to the 221 | # newest centroid (and only if we're currently sampling) 222 | if self._is_current_epoch_sampling: 223 | latest_distance = torch.cdist( 224 | batch, self.model.centroids[self._init_epoch - 1].unsqueeze(0) 225 | ).squeeze(1) 226 | shortest_distances = torch.minimum(self.shortest_distance_cache, latest_distance) 227 | self.shortest_distance_cache = shortest_distances 228 | else: 229 | shortest_distances = self.shortest_distance_cache 230 | 231 | if self._is_current_epoch_sampling: 232 | # After computing the shortest distances, we can finally do the sampling 233 | self.distance_sampler.update(batch, shortest_distances) 234 | else: 235 | # Or, we select a candidate by the lowest resulting inertia 236 | distances = torch.cdist(batch, self.centroid_candidates) 237 | updated_distances = torch.minimum(distances, shortest_distances.unsqueeze(1)) 238 | self.candidate_inertia_summer.update(updated_distances) 239 | 240 | def nonparametric_training_epoch_end(self) -> None: 241 | if self.current_epoch == 0: 242 | choice = self.uniform_sampler.compute() 243 | self.model.centroids[0].copy_(choice[0] if choice.dim() > 0 else choice) 244 | elif self._is_current_epoch_sampling: 245 | candidates = self.distance_sampler.compute() 246 | self.centroid_candidates.copy_(candidates) 247 | else: 248 | new_inertias = self.candidate_inertia_summer.compute() 249 | choice = new_inertias.argmin() 250 | self.model.centroids[self._init_epoch].copy_(self.centroid_candidates[choice]) 251 | 252 | @property 253 | def _init_epoch(self) -> int: 254 | return (self.current_epoch + 1) // 2 255 | 256 | @property 257 | def _is_current_epoch_sampling(self) -> bool: 258 | return self.current_epoch % 2 == 1 259 | 260 | 261 | # ------------------------------------------------------------------------------------------------- 262 | # MISC 263 | 264 | 265 | class FeatureVarianceLightningModule(NonparametricLightningModule): 266 | """ 267 | Lightning module for computing the average variance of a dataset's features. 268 | 269 | In the first epoch, it computes the features' means, then it can compute their variances. 270 | """ 271 | 272 | def __init__(self, variances: torch.Tensor): 273 | """ 274 | Args: 275 | variances: The output tensor where the variances are stored. 
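        Example:
            A rough sketch of the intended two-pass usage; ``num_features``, ``trainer`` and
            ``loader`` are illustrative assumptions::

                variances = torch.empty(num_features)
                module = FeatureVarianceLightningModule(variances)
                trainer.fit(module, loader)  # first epoch: means, second epoch: variances
                variances.mean()             # average feature variance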
276 | """ 277 | super().__init__() 278 | 279 | self.mean_aggregator = BatchAverager( 280 | num_values=variances.size(0), 281 | for_variance=False, 282 | dist_sync_fn=self.all_gather, 283 | ) 284 | self.variance_aggregator = BatchAverager( 285 | num_values=variances.size(0), 286 | for_variance=True, 287 | dist_sync_fn=self.all_gather, 288 | ) 289 | 290 | self.means: torch.Tensor 291 | self.register_buffer("means", torch.empty(variances.size(0)), persistent=False) 292 | 293 | self.variances: torch.Tensor 294 | self.register_buffer("variances", variances, persistent=False) 295 | 296 | def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None: 297 | if self.current_epoch == 0: 298 | self.mean_aggregator.update(batch) 299 | else: 300 | self.variance_aggregator.update((batch - self.means.unsqueeze(0)).square()) 301 | 302 | def nonparametric_training_epoch_end(self) -> None: 303 | if self.current_epoch == 0: 304 | self.means.copy_(self.mean_aggregator.compute()) 305 | else: 306 | self.variances.copy_(self.variance_aggregator.compute()) 307 | -------------------------------------------------------------------------------- /pycave/bayes/gmm/estimator.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import logging 3 | from typing import Any, cast, List, Tuple 4 | import torch 5 | from lightkit import ConfigurableBaseEstimator 6 | from lightkit.data import collate_tensor, DataLoader, dataset_from_tensors, TensorLike 7 | from lightkit.estimator import PredictorMixin 8 | from pycave.bayes.core import CovarianceType 9 | from pycave.clustering import KMeans 10 | from .lightning_module import ( 11 | GaussianMixtureKmeansInitLightningModule, 12 | GaussianMixtureLightningModule, 13 | GaussianMixtureRandomInitLightningModule, 14 | ) 15 | from .model import GaussianMixtureModel, GaussianMixtureModelConfig 16 | from .types import GaussianMixtureInitStrategy 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class GaussianMixture( 22 | ConfigurableBaseEstimator[GaussianMixtureModel], # type: ignore 23 | PredictorMixin[TensorLike, torch.Tensor], 24 | ): 25 | """ 26 | Probabilistic model assuming that data is generated from a mixture of Gaussians. 27 | 28 | The mixture is assumed to be composed of a fixed number of components with individual means 29 | and covariances. More information on Gaussian mixture models (GMMs) is available on 30 | `Wikipedia `_. 31 | 32 | See also: 33 | .. currentmodule:: pycave.bayes.gmm 34 | .. autosummary:: 35 | :nosignatures: 36 | :template: classes/pytorch_module.rst 37 | 38 | GaussianMixtureModel 39 | GaussianMixtureModelConfig 40 | """ 41 | 42 | #: The fitted PyTorch module with all estimated parameters. 43 | model_: GaussianMixtureModel 44 | #: A boolean indicating whether the model converged during training. 45 | converged_: bool 46 | #: The number of iterations the model was fitted for, excluding initialization. 47 | num_iter_: int 48 | #: The average per-datapoint negative log-likelihood at the last training step. 
49 | nll_: float 50 | 51 | def __init__( 52 | self, 53 | num_components: int = 1, 54 | *, 55 | covariance_type: CovarianceType = "diag", 56 | init_strategy: GaussianMixtureInitStrategy = "kmeans", 57 | init_means: torch.Tensor | None = None, 58 | convergence_tolerance: float = 1e-3, 59 | covariance_regularization: float = 1e-6, 60 | batch_size: int | None = None, 61 | trainer_params: dict[str, Any] | None = None, 62 | ): 63 | """ 64 | Args: 65 | num_components: The number of components in the GMM. The dimensionality of each 66 | component is automatically inferred from the data. 67 | covariance_type: The type of covariance to assume for all Gaussian components. 68 | init_strategy: The strategy for initializing component means and covariances. 69 | init_means: An optional initial guess for the means of the components. If provided, 70 | must be a tensor of shape ``[num_components, num_features]``. If this is given, 71 | the ``init_strategy`` is ignored and the means are handled as if K-means 72 | initialization has been run. 73 | convergence_tolerance: The change in the per-datapoint negative log-likelihood which 74 | implies that training has converged. 75 | covariance_regularization: A small value which is added to the diagonal of the 76 | covariance matrix to ensure that it is positive semi-definite. 77 | batch_size: The batch size to use when fitting the model. If not provided, the full 78 | data will be used as a single batch. Set this if the full data does not fit into 79 | memory. 80 | num_workers: The number of workers to use for loading the data. Only used if a PyTorch 81 | dataset is passed to :meth:`fit` or related methods. 82 | trainer_params: Initialization parameters to use when initializing a PyTorch Lightning 83 | trainer. By default, it disables various stdout logs unless PyCave is configured to 84 | do verbose logging. Checkpointing and logging are disabled regardless of the log 85 | level. This estimator further sets the following overridable defaults: 86 | 87 | - ``max_epochs=100`` 88 | 89 | Note: 90 | The number of epochs passed to the initializer only define the number of optimization 91 | epochs. Prior to that, initialization is run which may perform additional iterations 92 | through the data. 93 | 94 | Note: 95 | For batch training, the number of epochs run (i.e. the number of passes through the 96 | data), does not align with the number of epochs passed to the initializer. This is 97 | because the EM algorithm needs to be split up across two epochs. The actual number of 98 | minimum/maximum epochs is, thus, doubled. Nonetheless, :attr:`num_iter_` indicates how 99 | many EM iterations have been run. 100 | """ 101 | super().__init__( 102 | default_params=dict(max_epochs=100), 103 | user_params=trainer_params, 104 | ) 105 | 106 | self.num_components = num_components 107 | self.covariance_type = covariance_type 108 | self.init_strategy = init_strategy 109 | self.init_means = init_means 110 | self.convergence_tolerance = convergence_tolerance 111 | self.covariance_regularization = covariance_regularization 112 | 113 | self.batch_size = batch_size 114 | 115 | def fit(self, data: TensorLike) -> GaussianMixture: 116 | """ 117 | Fits the Gaussian mixture on the provided data, estimating component priors, means and 118 | covariances. Parameters are estimated using the EM algorithm. 119 | 120 | Args: 121 | data: The tabular data to fit on. The dimensionality of the Gaussian mixture is 122 | automatically inferred from this data. 
123 | 124 | Returns: 125 | The fitted Gaussian mixture. 126 | """ 127 | # Initialize the model 128 | num_features = len(data[0]) 129 | config = GaussianMixtureModelConfig( 130 | num_components=self.num_components, 131 | num_features=num_features, 132 | covariance_type=self.covariance_type, # type: ignore 133 | ) 134 | self.model_ = GaussianMixtureModel(config) 135 | 136 | # Set up the data loading 137 | loader = DataLoader( 138 | dataset_from_tensors(data), 139 | batch_size=self.batch_size or len(data), 140 | collate_fn=collate_tensor, 141 | ) 142 | is_batch_training = self._num_batches_per_epoch(loader) > 1 143 | 144 | # Run k-means if required or copy means 145 | if self.init_means is not None: 146 | self.model_.means.copy_(self.init_means) 147 | elif self.init_strategy in ("kmeans", "kmeans++"): 148 | logger.info("Fitting K-means estimator...") 149 | params = self.trainer_params_user 150 | if self.init_strategy == "kmeans++": 151 | params = {**(params or {}), **dict(max_epochs=0)} 152 | 153 | estimator = KMeans( 154 | self.num_components, 155 | batch_size=self.batch_size, 156 | trainer_params=params, 157 | ).fit(data) 158 | self.model_.means.copy_(estimator.model_.centroids) 159 | 160 | # Run initialization 161 | logger.info("Running initialization...") 162 | if self.init_strategy in ("kmeans", "kmeans++") and self.init_means is None: 163 | module = GaussianMixtureKmeansInitLightningModule( 164 | self.model_, 165 | covariance_regularization=self.covariance_regularization, 166 | ) 167 | self.trainer(max_epochs=1).fit(module, loader) 168 | else: 169 | module = GaussianMixtureRandomInitLightningModule( 170 | self.model_, 171 | covariance_regularization=self.covariance_regularization, 172 | is_batch_training=is_batch_training, 173 | use_model_means=self.init_means is not None, 174 | ) 175 | self.trainer(max_epochs=1 + int(is_batch_training)).fit(module, loader) 176 | 177 | # Fit model 178 | logger.info("Fitting Gaussian mixture...") 179 | module = GaussianMixtureLightningModule( 180 | self.model_, 181 | convergence_tolerance=self.convergence_tolerance, 182 | covariance_regularization=self.covariance_regularization, 183 | is_batch_training=is_batch_training, 184 | ) 185 | trainer = self.trainer( 186 | max_epochs=cast(int, self.trainer_params["max_epochs"]) * (1 + int(is_batch_training)) 187 | ) 188 | trainer.fit(module, loader) 189 | 190 | # Assign convergence properties 191 | self.num_iter_ = module.current_epoch 192 | if is_batch_training: 193 | self.num_iter_ //= 2 194 | self.converged_ = trainer.should_stop 195 | self.nll_ = cast(float, trainer.callback_metrics["nll"].item()) 196 | return self 197 | 198 | def sample(self, num_datapoints: int) -> torch.Tensor: 199 | """ 200 | Samples datapoints from the fitted Gaussian mixture. 201 | 202 | Args: 203 | num_datapoints: The number of datapoints to sample. 204 | 205 | Returns: 206 | A tensor of shape ``[num_datapoints, dim]`` providing the samples. 207 | 208 | Note: 209 | This method does not parallelize across multiple processes, i.e. performs no 210 | synchronization. 211 | """ 212 | return self.model_.sample(num_datapoints) 213 | 214 | def score(self, data: TensorLike) -> float: 215 | """ 216 | Computes the average negative log-likelihood (NLL) of the provided datapoints. 217 | 218 | Args: 219 | data: The datapoints for which to evaluate the NLL. 220 | 221 | Returns: 222 | The average NLL of all datapoints. 223 | 224 | Note: 225 | See :meth:`score_samples` to obtain NLL values for individual datapoints.
226 | """ 227 | loader = DataLoader( 228 | dataset_from_tensors(data), 229 | batch_size=self.batch_size or len(data), 230 | collate_fn=collate_tensor, 231 | ) 232 | result = self.trainer().test( 233 | GaussianMixtureLightningModule(self.model_), loader, verbose=False 234 | ) 235 | return result[0]["nll"] 236 | 237 | def score_samples(self, data: TensorLike) -> torch.Tensor: 238 | """ 239 | Computes the negative log-likelihood (NLL) of each of the provided datapoints. 240 | 241 | Args: 242 | data: The datapoints for which to compute the NLL. 243 | 244 | Returns: 245 | A tensor of shape ``[num_datapoints]`` with the NLL for each datapoint. 246 | 247 | Attention: 248 | When calling this function in a multi-process environment, each process receives only 249 | a subset of the predictions. If you want to aggregate predictions, make sure to gather 250 | the values returned from this method. 251 | """ 252 | loader = DataLoader( 253 | dataset_from_tensors(data), 254 | batch_size=self.batch_size or len(data), 255 | collate_fn=collate_tensor, 256 | ) 257 | result = self.trainer().predict(GaussianMixtureLightningModule(self.model_), loader) 258 | return torch.cat([x[1] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)]) 259 | 260 | def predict(self, data: TensorLike) -> torch.Tensor: 261 | """ 262 | Computes the most likely components for each of the provided datapoints. 263 | 264 | Args: 265 | data: The datapoints for which to obtain the most likely components. 266 | 267 | Returns: 268 | A tensor of shape ``[num_datapoints]`` with the indices of the most likely components. 269 | 270 | Note: 271 | Use :meth:`predict_proba` to obtain probabilities for each component instead of the 272 | most likely component only. 273 | 274 | Attention: 275 | When calling this function in a multi-process environment, each process receives only 276 | a subset of the predictions. If you want to aggregate predictions, make sure to gather 277 | the values returned from this method. 278 | """ 279 | return self.predict_proba(data).argmax(-1) 280 | 281 | def predict_proba(self, data: TensorLike) -> torch.Tensor: 282 | """ 283 | Computes a distribution over the components for each of the provided datapoints. 284 | 285 | Args: 286 | data: The datapoints for which to compute the component assignment probabilities. 287 | 288 | Returns: 289 | A tensor of shape ``[num_datapoints, num_components]`` with the assignment 290 | probabilities for each component and datapoint. Note that each row of the tensor sums 291 | to 1, i.e. the returned tensor provides a proper distribution over the components for 292 | each datapoint. 293 | 294 | Attention: 295 | When calling this function in a multi-process environment, each process receives only 296 | a subset of the predictions. If you want to aggregate predictions, make sure to gather 297 | the values returned from this method.
298 | """ 299 | loader = DataLoader( 300 | dataset_from_tensors(data), 301 | batch_size=self.batch_size or len(data), 302 | collate_fn=collate_tensor, 303 | ) 304 | result = self.trainer().predict(GaussianMixtureLightningModule(self.model_), loader) 305 | return torch.cat([x[0] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)]) 306 | -------------------------------------------------------------------------------- /pycave/bayes/gmm/lightning_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import pytorch_lightning as pl 3 | import torch 4 | from pytorch_lightning.callbacks import EarlyStopping 5 | from torchmetrics import MeanMetric 6 | from pycave.bayes.core import cholesky_precision 7 | from pycave.utils import NonparametricLightningModule 8 | from .metrics import CovarianceAggregator, MeanAggregator, PriorAggregator 9 | from .model import GaussianMixtureModel 10 | 11 | # ------------------------------------------------------------------------------------------------- 12 | # TRAINING 13 | 14 | 15 | class GaussianMixtureLightningModule(NonparametricLightningModule): 16 | """ 17 | Lightning module for training and evaluating a Gaussian mixture model. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | model: GaussianMixtureModel, 23 | convergence_tolerance: float = 1e-3, 24 | covariance_regularization: float = 1e-6, 25 | is_batch_training: bool = False, 26 | ): 27 | """ 28 | Args: 29 | model: The Gaussian mixture model to use for training/evaluation. 30 | convergence_tolerance: The change in the per-datapoint negative log-likelihood which 31 | implies that training has converged. 32 | covariance_regularization: A small value which is added to the diagonal of the 33 | covariance matrix to ensure that it is positive semi-definite. 34 | is_batch_training: Whether training is performed on mini-batches instead of the entire 35 | data at once. In the case of batching, the EM-algorithm is "split" across two 36 | epochs. 
37 | """ 38 | super().__init__() 39 | 40 | self.model = model 41 | self.convergence_tolerance = convergence_tolerance 42 | self.is_batch_training = is_batch_training 43 | 44 | # For batch training, we store a model copy such that we can "replay" responsibilities 45 | if self.is_batch_training: 46 | self.model_copy = GaussianMixtureModel(self.model.config) 47 | self.model_copy.load_state_dict(self.model.state_dict()) 48 | 49 | # Initialize aggregators 50 | self.prior_aggregator = PriorAggregator( 51 | num_components=self.model.config.num_components, 52 | dist_sync_fn=self.all_gather, 53 | ) 54 | self.mean_aggregator = MeanAggregator( 55 | num_components=self.model.config.num_components, 56 | num_features=self.model.config.num_features, 57 | dist_sync_fn=self.all_gather, 58 | ) 59 | self.covar_aggregator = CovarianceAggregator( 60 | num_components=self.model.config.num_components, 61 | num_features=self.model.config.num_features, 62 | covariance_type=self.model.config.covariance_type, 63 | reg=covariance_regularization, 64 | dist_sync_fn=self.all_gather, 65 | ) 66 | 67 | # Initialize metrics 68 | self.metric_nll = MeanMetric(dist_sync_fn=self.all_gather) 69 | 70 | def configure_callbacks(self) -> list[pl.Callback]: 71 | if self.convergence_tolerance == 0: 72 | return [] 73 | early_stopping = EarlyStopping( 74 | "nll", 75 | min_delta=self.convergence_tolerance, 76 | patience=2 if self.is_batch_training else 1, 77 | check_on_train_epoch_end=True, 78 | strict=False, # Allows to not log every epoch 79 | ) 80 | return [early_stopping] 81 | 82 | def on_train_epoch_start(self) -> None: 83 | self.prior_aggregator.reset() 84 | self.mean_aggregator.reset() 85 | self.covar_aggregator.reset() 86 | 87 | def nonparametric_training_step(self, batch: torch.Tensor, _batch_idx: int) -> None: 88 | ### E-Step 89 | if self._computes_responsibilities_on_live_model: 90 | log_responsibilities, log_probs = self.model.forward(batch) 91 | else: 92 | log_responsibilities, log_probs = self.model_copy.forward(batch) 93 | responsibilities = log_responsibilities.exp() 94 | 95 | # Compute the NLL for early stopping 96 | if self._should_log_nll: 97 | self.metric_nll.update(-log_probs) 98 | self.log("nll", self.metric_nll, on_step=False, on_epoch=True, prog_bar=True) 99 | 100 | ### (Partial) M-Step 101 | if self._should_update_means: 102 | self.prior_aggregator.update(responsibilities) 103 | self.mean_aggregator.update(batch, responsibilities) 104 | if self._should_update_covars: 105 | means = self.mean_aggregator.compute() 106 | self.covar_aggregator.update(batch, responsibilities, means) 107 | else: 108 | self.covar_aggregator.update(batch, responsibilities, self.model.means) 109 | 110 | def nonparametric_training_epoch_end(self) -> None: 111 | # Prior to updating the model, we might need to copy it in the case of batch training 112 | if self._requires_to_copy_live_model: 113 | self.model_copy.load_state_dict(self.model.state_dict()) 114 | 115 | # Finalize the M-Step 116 | if self._should_update_means: 117 | priors = self.prior_aggregator.compute() 118 | self.model.component_probs.copy_(priors) 119 | 120 | means = self.mean_aggregator.compute() 121 | self.model.means.copy_(means) 122 | 123 | if self._should_update_covars: 124 | covars = self.covar_aggregator.compute() 125 | self.model.precisions_cholesky.copy_( 126 | cholesky_precision(covars, self.model.config.covariance_type) 127 | ) 128 | 129 | def test_step(self, batch: torch.Tensor, _batch_idx: int) -> None: 130 | _, log_probs = self.model.forward(batch) 131 | 
self.metric_nll.update(-log_probs) 132 | self.log("nll", self.metric_nll) 133 | 134 | def predict_step( 135 | self, batch: torch.Tensor, batch_idx: int 136 | ) -> tuple[torch.Tensor, torch.Tensor]: 137 | log_responsibilities, log_probs = self.model.forward(batch) 138 | return log_responsibilities.exp(), -log_probs 139 | 140 | @property 141 | def _computes_responsibilities_on_live_model(self) -> bool: 142 | if not self.is_batch_training: 143 | return True 144 | return self.current_epoch % 2 == 0 145 | 146 | @property 147 | def _requires_to_copy_live_model(self) -> bool: 148 | if not self.is_batch_training: 149 | return False 150 | return self.current_epoch % 2 == 0 151 | 152 | @property 153 | def _should_log_nll(self) -> bool: 154 | if not self.is_batch_training: 155 | return True 156 | return self.current_epoch % 2 == 1 157 | 158 | @property 159 | def _should_update_means(self) -> bool: 160 | if not self.is_batch_training: 161 | return True 162 | return self.current_epoch % 2 == 0 163 | 164 | @property 165 | def _should_update_covars(self) -> bool: 166 | if not self.is_batch_training: 167 | return True 168 | return self.current_epoch % 2 == 1 169 | 170 | 171 | # ------------------------------------------------------------------------------------------------- 172 | # INIT STRATEGIES 173 | 174 | 175 | class GaussianMixtureKmeansInitLightningModule(NonparametricLightningModule): 176 | """ 177 | Lightning module for initializing a Gaussian mixture from centroids found via K-Means. 178 | """ 179 | 180 | def __init__(self, model: GaussianMixtureModel, covariance_regularization: float): 181 | """ 182 | Args: 183 | model: The model whose parameters to initialize. 184 | covariance_regularization: A small value which is added to the diagonal of the 185 | covariance matrix to ensure that it is positive semi-definite. 186 | """ 187 | super().__init__() 188 | 189 | self.model = model 190 | 191 | self.prior_aggregator = PriorAggregator( 192 | num_components=self.model.config.num_components, 193 | dist_sync_fn=self.all_gather, 194 | ) 195 | self.covar_aggregator = CovarianceAggregator( 196 | num_components=self.model.config.num_components, 197 | num_features=self.model.config.num_features, 198 | covariance_type=self.model.config.covariance_type, 199 | reg=covariance_regularization, 200 | dist_sync_fn=self.all_gather, 201 | ) 202 | 203 | def on_train_epoch_start(self) -> None: 204 | self.prior_aggregator.reset() 205 | self.covar_aggregator.reset() 206 | 207 | def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None: 208 | # Just like for k-means, responsibilities are one-hot assignments to the clusters 209 | responsibilities = _one_hot_responsibilities(batch, self.model.means) 210 | 211 | # Then, we can update the aggregators 212 | self.prior_aggregator.update(responsibilities) 213 | self.covar_aggregator.update(batch, responsibilities, self.model.means) 214 | 215 | def nonparametric_training_epoch_end(self) -> None: 216 | priors = self.prior_aggregator.compute() 217 | self.model.component_probs.copy_(priors) 218 | 219 | covars = self.covar_aggregator.compute() 220 | self.model.precisions_cholesky.copy_( 221 | cholesky_precision(covars, self.model.config.covariance_type) 222 | ) 223 | 224 | 225 | class GaussianMixtureRandomInitLightningModule(NonparametricLightningModule): 226 | """ 227 | Lightning module for initializing a Gaussian mixture randomly or using the assignments for 228 | arbitrary means that were not found via K-means. 
229 | 230 | For batch training, this requires two epochs, otherwise, it requires a single epoch. 231 | """ 232 | 233 | def __init__( 234 | self, 235 | model: GaussianMixtureModel, 236 | covariance_regularization: float, 237 | is_batch_training: bool, 238 | use_model_means: bool, 239 | ): 240 | """ 241 | Args: 242 | model: The model whose parameters to initialize. 243 | covariance_regularization: A small value which is added to the diagonal of the 244 | covariance matrix to ensure that it is positive semi-definite. 245 | is_batch_training: Whether training is performed on mini-batches instead of the entire 246 | data at once. 247 | use_model_means: Whether the model's means ought to be used for one-hot component 248 | assignments. 249 | """ 250 | super().__init__() 251 | 252 | self.model = model 253 | self.is_batch_training = is_batch_training 254 | self.use_model_means = use_model_means 255 | 256 | self.prior_aggregator = PriorAggregator( 257 | num_components=self.model.config.num_components, 258 | dist_sync_fn=self.all_gather, 259 | ) 260 | self.mean_aggregator = MeanAggregator( 261 | num_components=self.model.config.num_components, 262 | num_features=self.model.config.num_features, 263 | dist_sync_fn=self.all_gather, 264 | ) 265 | self.covar_aggregator = CovarianceAggregator( 266 | num_components=self.model.config.num_components, 267 | num_features=self.model.config.num_features, 268 | covariance_type=self.model.config.covariance_type, 269 | reg=covariance_regularization, 270 | dist_sync_fn=self.all_gather, 271 | ) 272 | 273 | # For batch training, we store a model copy such that we can "replay" responsibilities 274 | if self.is_batch_training and self.use_model_means: 275 | self.model_copy = GaussianMixtureModel(self.model.config) 276 | self.model_copy.load_state_dict(self.model.state_dict()) 277 | 278 | def on_train_epoch_start(self) -> None: 279 | self.prior_aggregator.reset() 280 | self.mean_aggregator.reset() 281 | self.covar_aggregator.reset() 282 | 283 | def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None: 284 | if self.use_model_means: 285 | if self.current_epoch == 0: 286 | responsibilities = _one_hot_responsibilities(batch, self.model.means) 287 | else: 288 | responsibilities = _one_hot_responsibilities(batch, self.model_copy.means) 289 | else: 290 | responsibilities = torch.rand( 291 | batch.size(0), 292 | self.model.config.num_components, 293 | device=batch.device, 294 | dtype=batch.dtype, 295 | ) 296 | responsibilities = responsibilities / responsibilities.sum(1, keepdim=True) 297 | 298 | if self.current_epoch == 0: 299 | self.prior_aggregator.update(responsibilities) 300 | self.mean_aggregator.update(batch, responsibilities) 301 | if not self.is_batch_training: 302 | means = self.mean_aggregator.compute() 303 | self.covar_aggregator.update(batch, responsibilities, means) 304 | else: 305 | # Only reached if batch training 306 | self.covar_aggregator.update(batch, responsibilities, self.model.means) 307 | 308 | def nonparametric_training_epoch_end(self) -> None: 309 | if self.current_epoch == 0 and self.is_batch_training: 310 | self.model_copy.load_state_dict(self.model.state_dict()) 311 | 312 | if self.current_epoch == 0: 313 | priors = self.prior_aggregator.compute() 314 | self.model.component_probs.copy_(priors) 315 | 316 | means = self.mean_aggregator.compute() 317 | self.model.means.copy_(means) 318 | 319 | if (self.current_epoch == 0 and not self.is_batch_training) or self.current_epoch == 1: 320 | covars = self.covar_aggregator.compute() 
321 | self.model.precisions_cholesky.copy_( 322 | cholesky_precision(covars, self.model.config.covariance_type) 323 | ) 324 | 325 | 326 | def _one_hot_responsibilities(data: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor: 327 | distances = torch.cdist(data, centroids) 328 | assignments = distances.min(1).indices 329 | onehot = torch.eye( 330 | centroids.size(0), 331 | device=data.device, 332 | dtype=data.dtype, 333 | ) 334 | return onehot[assignments] 335 | --------------------------------------------------------------------------------
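Taken together, the estimator in pycave/bayes/gmm/estimator.py and the Lightning modules above implement EM for Gaussian mixtures behind a scikit-learn-like interface. As a quick orientation, the following minimal sketch shows how that public API might be driven end to end. It is not part of the repository; it assumes that GaussianMixture is re-exported from pycave.bayes (as the package layout suggests, otherwise import it from pycave.bayes.gmm) and uses synthetic data, so the concrete numbers are illustrative only:

import torch
from pycave.bayes import GaussianMixture  # assumed re-export; see pycave/bayes/__init__.py

# Two well-separated synthetic clusters in two dimensions.
data = torch.cat([torch.randn(500, 2) - 3.0, torch.randn(500, 2) + 3.0])

# Fit a two-component GMM with diagonal covariances. Leaving `batch_size` unset
# processes the full dataset as a single batch per epoch.
gmm = GaussianMixture(num_components=2, covariance_type="diag").fit(data)

print(gmm.converged_, gmm.num_iter_, gmm.nll_)  # convergence info recorded by fit()
labels = gmm.predict(data)        # most likely component per datapoint, shape [1000]
nll = gmm.score_samples(data)     # per-datapoint negative log-likelihood, shape [1000]
samples = gmm.sample(100)         # draw new datapoints from the fitted mixture

Passing a batch_size (e.g. batch_size=128) would instead switch the estimator into mini-batch mode, in which the EM update is split across pairs of epochs as described in the estimator's __init__ docstring, so the number of passes over the data doubles while num_iter_ still counts EM iterations.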