├── pyrichlet
│   ├── tests
│   │   ├── __init__.py
│   │   ├── _base.py
│   │   ├── test_convergence.py
│   │   ├── test_mixture_models_variational.py
│   │   ├── test_mixture_models_gibbs.py
│   │   └── test_weight_models.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── validators.py
│   │   ├── _loaders.py
│   │   ├── functions.py
│   │   └── _data
│   │       ├── chaetocnema.csv
│   │       └── penguins.csv
│   ├── exceptions.py
│   ├── __init__.py
│   ├── mixture_models
│   │   ├── _beta_bernoulli_mixture.py
│   │   ├── _dirichlet_process_mixture.py
│   │   ├── _geometric_process_mixture.py
│   │   ├── _beta_in_dirichlet_mixture.py
│   │   ├── _pitman_yor_process_mixture.py
│   │   ├── _equal_weighted_mixture.py
│   │   ├── _frequency_weighted_mixture.py
│   │   ├── _dirichlet_distribution_mixture.py
│   │   ├── __init__.py
│   │   ├── _beta_binomial_mixture.py
│   │   ├── _beta_in_beta_mixture.py
│   │   ├── _utils.py
│   │   └── _base.py
│   ├── weight_models
│   │   ├── __init__.py
│   │   ├── _equal.py
│   │   ├── _frequency.py
│   │   ├── _beta_binomial.py
│   │   ├── _dirichlet_distribution.py
│   │   ├── _geometric_process.py
│   │   ├── _beta_bernoulli.py
│   │   ├── _beta_in_dirichlet.py
│   │   ├── _dirichlet_process.py
│   │   ├── _pitman_yor_process.py
│   │   ├── _beta_in_beta.py
│   │   └── _base.py
│   └── _version.py
├── .gitattributes
├── docs
│   ├── .gitignore
│   ├── requirements.txt
│   ├── _templates
│   │   └── autosummary
│   │       └── class.rst
│   ├── index.rst
│   ├── Makefile
│   ├── make.bat
│   ├── models.rst
│   ├── installation.rst
│   ├── conf.py
│   └── usage.rst
├── MANIFEST.in
├── pyproject.toml
├── setup.cfg
├── .readthedocs.yaml
├── .github
│   └── workflows
│       └── python-package.yml
├── README.md
├── setup.py
├── .gitignore
└── LICENSE

/pyrichlet/tests/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | pyrichlet/_version.py export-subst
2 | 
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | mixture_models/
2 | weight_models/
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx>=5.0.0
2 | numpydoc>=1.1.0
3 | sphinx_rtd_theme
--------------------------------------------------------------------------------
/pyrichlet/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from ._loaders import load_chaetocnema, load_penguins
2 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include versioneer.py
2 | include pyrichlet/_version.py
3 | include pyrichlet/utils/_data/*.csv
--------------------------------------------------------------------------------
/pyrichlet/exceptions.py:
--------------------------------------------------------------------------------
1 | class NotFittedError(Exception):
2 |     """Exception raised when a method requires the model to be fitted first."""
3 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 |     "setuptools>=42",
4 |     "wheel",
5 |     "versioneer-518"
6 | ]
7 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
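`NotFittedError` above is the guard that the weight and mixture models use for methods that only make sense on a fitted object (see, e.g., `weight_models/_frequency.py` later in this listing). The following is a minimal sketch of that pattern, assuming only the exception class itself; `DemoWeight` is an illustrative name, not part of the package:

```python
import numpy as np

from pyrichlet.exceptions import NotFittedError


class DemoWeight:
    # Illustrative only: mirrors the guard pattern used by the weight models.
    def __init__(self):
        self.variational_d = None  # populated by fit_variational

    def fit_variational(self, variational_d):
        # Store the variational assignment matrix (components x observations).
        self.variational_d = np.asarray(variational_d)

    def variational_mean_w_j(self, j):
        # Methods that require a fitted state raise NotFittedError otherwise.
        if self.variational_d is None:
            raise NotFittedError
        return self.variational_d.sum(1)[j] / self.variational_d.sum()
```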
/pyrichlet/__init__.py: -------------------------------------------------------------------------------- 1 | from .mixture_models import * 2 | from .weight_models import * 3 | from . import utils 4 | from . import _version 5 | __version__ = _version.get_versions()['version'] 6 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [versioneer] 2 | VCS = git 3 | style = pep440 4 | versionfile_source = pyrichlet/_version.py 5 | versionfile_build = pyrichlet/_version.py 6 | tag_prefix = 7 | parentdir_prefix = pyrichlet- -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :members: 7 | :inherited-members: 8 | 9 | .. autosummary:: 10 | -------------------------------------------------------------------------------- /pyrichlet/utils/validators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def rng_parser(rng): 5 | if rng is None: 6 | return np.random.default_rng() 7 | if type(rng) is int: 8 | return np.random.default_rng(rng) 9 | if type(rng) is np.random.Generator: 10 | return rng 11 | raise TypeError("Invalid random number generator") 12 | -------------------------------------------------------------------------------- /pyrichlet/mixture_models/_beta_bernoulli_mixture.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseGaussianMixture 2 | from ..weight_models import BetaBernoulli 3 | 4 | 5 | class BetaBernoulliMixture(BaseGaussianMixture): 6 | def __init__(self, *, p=1, alpha=1, rng=None, **kwargs): 7 | weight_model = BetaBernoulli(p=p, alpha=alpha, rng=rng) 8 | super().__init__(weight_model=weight_model, rng=rng, **kwargs) 9 | -------------------------------------------------------------------------------- /pyrichlet/mixture_models/_dirichlet_process_mixture.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseGaussianMixture 2 | from ..weight_models import DirichletProcess 3 | 4 | 5 | class DirichletProcessMixture(BaseGaussianMixture): 6 | def __init__(self, *, alpha=1, rng=None, **kwargs): 7 | weight_model = DirichletProcess(alpha=alpha, rng=rng) 8 | super().__init__(weight_model=weight_model, rng=rng, **kwargs) 9 | -------------------------------------------------------------------------------- /pyrichlet/mixture_models/_geometric_process_mixture.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseGaussianMixture 2 | from ..weight_models import GeometricProcess 3 | 4 | 5 | class GeometricProcessMixture(BaseGaussianMixture): 6 | def __init__(self, *, a=1, b=1, rng=None, **kwargs): 7 | weight_model = GeometricProcess(a=a, b=b, rng=rng) 8 | super().__init__(weight_model=weight_model, rng=rng, **kwargs) 9 | -------------------------------------------------------------------------------- /pyrichlet/mixture_models/_beta_in_dirichlet_mixture.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseGaussianMixture 2 | from ..weight_models import BetaInDirichlet 3 | 4 | 5 | class 
BetaInDirichletMixture(BaseGaussianMixture): 6 | def __init__(self, *, alpha=1, a=0, rng=None, **kwargs): 7 | weight_model = BetaInDirichlet(alpha=alpha, a=a, rng=rng) 8 | super().__init__(weight_model=weight_model, rng=rng, **kwargs) 9 | -------------------------------------------------------------------------------- /pyrichlet/mixture_models/_pitman_yor_process_mixture.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseGaussianMixture 2 | from ..weight_models import PitmanYorProcess 3 | 4 | 5 | class PitmanYorMixture(BaseGaussianMixture): 6 | def __init__(self, *, alpha=1, pyd=0, rng=None, **kwargs): 7 | weight_model = PitmanYorProcess(pyd=pyd, alpha=alpha, rng=rng) 8 | super().__init__(weight_model=weight_model, rng=rng, **kwargs) 9 | -------------------------------------------------------------------------------- /pyrichlet/mixture_models/_equal_weighted_mixture.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseGaussianMixture 2 | from ..weight_models import EqualWeighting 3 | 4 | 5 | class EqualWeightedMixture(BaseGaussianMixture): 6 | def __init__(self, *, n=1, rng=None, **kwargs): 7 | weight_model = EqualWeighting(n=n, rng=rng) 8 | self.n = n 9 | super().__init__(weight_model=weight_model, rng=rng, **kwargs) 10 | -------------------------------------------------------------------------------- /pyrichlet/mixture_models/_frequency_weighted_mixture.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseGaussianMixture 2 | from ..weight_models import FrequencyWeighting 3 | 4 | 5 | class FrequencyWeightedMixture(BaseGaussianMixture): 6 | def __init__(self, *, n=1, rng=None, **kwargs): 7 | weight_model = FrequencyWeighting(n=n, rng=rng) 8 | self.n = n 9 | super().__init__(weight_model=weight_model, rng=rng, **kwargs) 10 | -------------------------------------------------------------------------------- /pyrichlet/mixture_models/_dirichlet_distribution_mixture.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseGaussianMixture 2 | from ..weight_models import DirichletDistribution 3 | 4 | 5 | class DirichletDistributionMixture(BaseGaussianMixture): 6 | def __init__(self, *, n=1, alpha=1, rng=None, **kwargs): 7 | weight_model = DirichletDistribution(n=n, alpha=alpha, rng=rng) 8 | self.n = n 9 | super().__init__(weight_model=weight_model, rng=rng, **kwargs) 10 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. pyrichlet documentation master file, created by 2 | sphinx-quickstart on Wed Dec 29 15:51:40 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to pyrichlet's documentation! 7 | ===================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | installation 14 | usage 15 | models 16 | 17 | .. 
toctree:: 18 | :caption: Reference: 19 | :hidden: 20 | 21 | Index 22 | -------------------------------------------------------------------------------- /pyrichlet/weight_models/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseWeight 2 | from ._dirichlet_distribution import DirichletDistribution 3 | from ._dirichlet_process import DirichletProcess 4 | from ._pitman_yor_process import PitmanYorProcess 5 | from ._geometric_process import GeometricProcess 6 | from ._beta_in_beta import BetaInBeta 7 | from ._beta_in_dirichlet import BetaInDirichlet 8 | from ._beta_bernoulli import BetaBernoulli 9 | from ._beta_binomial import BetaBinomial 10 | from ._equal import EqualWeighting 11 | from ._frequency import FrequencyWeighting 12 | -------------------------------------------------------------------------------- /pyrichlet/utils/_loaders.py: -------------------------------------------------------------------------------- 1 | try: 2 | from importlib.resources import files 3 | except ImportError: 4 | from importlib_resources import files 5 | import pandas as pd 6 | from . import _data 7 | 8 | 9 | def load_chaetocnema(): 10 | df = pd.read_csv(str(files(_data) / 'chaetocnema.csv'), 11 | usecols=list(range(8)), 12 | index_col=0) 13 | return df.iloc[:, :-1], df.iloc[:, -1].to_numpy() 14 | 15 | 16 | def load_penguins(): 17 | df = pd.read_csv(str(files(_data) / 'penguins.csv'), 18 | usecols=[0, 2, 3, 4, 5]) 19 | df = df.dropna() 20 | return df.iloc[:, 1:], df.iloc[:, 0].to_numpy() 21 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
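# For example, `make html` is forwarded to Sphinx as
#   sphinx-build -M html . _build
# since SOURCEDIR is "." and BUILDDIR is "_build".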
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /pyrichlet/mixture_models/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseGaussianMixture 2 | from ._dirichlet_distribution_mixture import DirichletDistributionMixture 3 | from ._dirichlet_process_mixture import DirichletProcessMixture 4 | from ._pitman_yor_process_mixture import PitmanYorMixture 5 | from ._geometric_process_mixture import GeometricProcessMixture 6 | from ._beta_in_beta_mixture import BetaInBetaMixture 7 | from ._beta_in_dirichlet_mixture import BetaInDirichletMixture 8 | from ._beta_bernoulli_mixture import BetaBernoulliMixture 9 | from ._beta_binomial_mixture import BetaBinomialMixture 10 | from ._equal_weighted_mixture import EqualWeightedMixture 11 | from ._frequency_weighted_mixture import FrequencyWeightedMixture 12 | -------------------------------------------------------------------------------- /pyrichlet/mixture_models/_beta_binomial_mixture.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseGaussianMixture 2 | from ..weight_models import BetaBinomial 3 | 4 | 5 | class BetaBinomialMixture(BaseGaussianMixture): 6 | def __init__(self, *, n=0, alpha=1, mu_prior=None, lambda_prior=1, 7 | psi_prior=None, nu_prior=None, total_iter=1000, burn_in=100, 8 | subsample_steps=1, rng=None): 9 | weight_model = BetaBinomial(n=n, alpha=alpha, rng=rng) 10 | super().__init__(weight_model=weight_model, mu_prior=mu_prior, 11 | lambda_prior=lambda_prior, psi_prior=psi_prior, 12 | nu_prior=nu_prior, total_iter=total_iter, 13 | burn_in=burn_in, subsample_steps=subsample_steps, 14 | rng=rng) 15 | -------------------------------------------------------------------------------- /pyrichlet/tests/_base.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import scipy.stats 4 | 5 | 6 | class BaseTest(unittest.TestCase): 7 | def setUp(self): 8 | self.rng = np.random.default_rng(0) 9 | n = 50 10 | means = np.array([-10, 10]) 11 | sds = np.array([1, 1]) 12 | weights = np.array([0.5, 0.5]) 13 | theta = self.rng.choice(range(len(weights)), size=n, p=weights) 14 | self.y = np.array([ 15 | scipy.stats.multivariate_normal.rvs( 16 | means[j], sds[j], 17 | random_state=self.rng 18 | ) for j in theta 19 | ]) 20 | self.y_density = np.array([scipy.stats.multivariate_normal.pdf( 21 | self.y, means[j], sds[j]) * weights[j] for j in range(2)]) 22 | self.y_density = self.y_density.sum(axis=0) 23 | -------------------------------------------------------------------------------- /pyrichlet/mixture_models/_beta_in_beta_mixture.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseGaussianMixture 2 | from ..weight_models import BetaInBeta 3 | 4 | 5 | class BetaInBetaMixture(BaseGaussianMixture): 6 | def __init__(self, *, x=0, alpha=1, a=1, b=1, mu_prior=None, 7 | lambda_prior=1, psi_prior=None, nu_prior=None, 8 | p_method="geometric", total_iter=1000, 9 | burn_in=100, subsample_steps=1, rng=None): 10 | weight_model = BetaInBeta(x=x, alpha=alpha, a=a, b=b, 11 | p_method=p_method, rng=rng) 12 | super().__init__(weight_model=weight_model, mu_prior=mu_prior, 13 | lambda_prior=lambda_prior, psi_prior=psi_prior, 14 | nu_prior=nu_prior, total_iter=total_iter, 15 | 
burn_in=burn_in, subsample_steps=subsample_steps, 16 | rng=rng) 17 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-20.04 11 | tools: 12 | python: "3.9" 13 | # You can also specify other tool versions: 14 | # nodejs: "16" 15 | # rust: "1.55" 16 | # golang: "1.17" 17 | 18 | # Build documentation in the docs/ directory with Sphinx 19 | sphinx: 20 | builder: html 21 | configuration: docs/conf.py 22 | fail_on_warning: false 23 | 24 | # If using Sphinx, optionally build your docs in additional formats such as PDF 25 | # formats: 26 | # - pdf 27 | 28 | # Optionally declare the Python requirements required to build your docs 29 | python: 30 | install: 31 | - requirements: docs/requirements.txt 32 | - method: setuptools 33 | path: . -------------------------------------------------------------------------------- /docs/models.rst: -------------------------------------------------------------------------------- 1 | Models 2 | ====== 3 | 4 | Weighting Models 5 | ---------------- 6 | 7 | The following weighting models are implemented under the module 8 | `pyrichlet.weight\_models`: 9 | 10 | .. currentmodule:: pyrichlet.weight_models 11 | 12 | .. autosummary:: 13 | :toctree: weight_models 14 | :nosignatures: 15 | 16 | BaseWeight 17 | DirichletDistribution 18 | DirichletProcess 19 | PitmanYorProcess 20 | GeometricProcess 21 | BetaInBeta 22 | BetaInDirichlet 23 | BetaBernoulli 24 | BetaBinomial 25 | EqualWeighting 26 | FrequencyWeighting 27 | 28 | 29 | Mixture Models 30 | -------------- 31 | 32 | The following mixture models are implemented under the module 33 | `pyrichlet.mixture\_models`: 34 | 35 | .. currentmodule:: pyrichlet.mixture_models 36 | 37 | .. 
autosummary:: 38 | :toctree: mixture_models 39 | :nosignatures: 40 | 41 | BaseGaussianMixture 42 | DirichletDistributionMixture 43 | DirichletProcessMixture 44 | PitmanYorMixture 45 | GeometricProcessMixture 46 | BetaInBetaMixture 47 | BetaInDirichletMixture 48 | BetaBernoulliMixture 49 | BetaBinomialMixture 50 | EqualWeightedMixture 51 | FrequencyWeightedMixture 52 | -------------------------------------------------------------------------------- /pyrichlet/tests/test_convergence.py: -------------------------------------------------------------------------------- 1 | from pyrichlet import mixture_models as mm 2 | 3 | from ._base import BaseTest 4 | 5 | 6 | class TestMixtureModels(BaseTest): 7 | def test_dirichlet_distribution(self): 8 | n = 10 9 | mixture = mm.DirichletDistributionMixture(n=n, rng=self.rng) 10 | mixture.fit_variational(self.y) 11 | assert mixture.var_converged 12 | 13 | def test_dirichlet_process(self): 14 | n = 10 15 | mixture = mm.DirichletProcessMixture(rng=self.rng) 16 | mixture.fit_variational(self.y, n_groups=n) 17 | assert mixture.var_converged 18 | 19 | def test_pitman_yor_process(self): 20 | n = 10 21 | mixture = mm.PitmanYorMixture(pyd=0.1, rng=self.rng) 22 | mixture.fit_variational(self.y, n_groups=n) 23 | assert mixture.var_converged 24 | 25 | def test_geometric_process(self): 26 | n = 10 27 | mixture = mm.GeometricProcessMixture(rng=self.rng) 28 | mixture.fit_variational(self.y, n_groups=n) 29 | assert mixture.var_converged 30 | 31 | def test_equal_weighting(self): 32 | n = 10 33 | mixture = mm.EqualWeightedMixture(n=n, rng=self.rng) 34 | mixture.fit_variational(self.y) 35 | assert mixture.var_converged 36 | 37 | def test_frequency_weighting(self): 38 | n = 10 39 | mixture = mm.FrequencyWeightedMixture(n=n, rng=self.rng) 40 | mixture.fit_variational(self.y) 41 | assert mixture.var_converged 42 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: true 18 | matrix: 19 | python-version: ["3.9"] 20 | 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install flake8 pytest 31 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 32 | python -m pip install . 33 | - name: Lint with flake8 34 | run: | 35 | # stop the build if there are Python syntax errors or undefined names 36 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 37 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 38 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
39 |       - name: Test with pytest
40 |         run: |
41 |           pytest
42 | 
--------------------------------------------------------------------------------
/pyrichlet/utils/functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.special import digamma, loggamma
3 | from scipy.stats import beta, betabinom, binom
4 | 
5 | 
6 | def log_likelihood_beta(v, a, b):
7 |     return beta.logpdf(v, a, b)
8 | 
9 | 
10 | def log_likelihood_beta_binom(x, n, a, b):
11 |     return betabinom.logpmf(x, n, a, b)
12 | 
13 | 
14 | def log_likelihood_binom(x, n, p):
15 |     return binom.logpmf(x, n, p)
16 | 
17 | 
18 | def dirichlet_log_eppf(alpha, partition):
19 |     k = len(partition)
20 |     n = sum(partition)
21 |     res = loggamma(alpha) - loggamma(alpha + n)
22 |     res += k * np.log(alpha)
23 |     res += np.sum(loggamma(partition))
24 |     return res
25 | 
26 | 
27 | def mean_log_beta(a, b):
28 |     return digamma(a) - digamma(a + b)
29 | 
30 | 
31 | def density_students_t(x, mu, precision, nu):
32 |     dim = x.shape[1]
33 |     density = np.exp(loggamma((nu + dim) / 2) - loggamma(nu / 2))
34 |     density *= np.linalg.det(precision) ** 0.5 / (nu * np.pi) ** (dim / 2)
35 |     density *= (1 + np.einsum('ij,jk,ik->i',
36 |                               x - mu, precision, x - mu
37 |                               ) / nu) ** (-(nu + dim) / 2)
38 |     return density
39 | 
40 | 
41 | def density_normal(x, mu, precision):
42 |     dim = x.shape[1]
43 |     density = np.sqrt((2 * np.pi) ** -dim * np.linalg.det(precision))
44 |     density *= np.exp(- np.einsum('ij,jk,ik->i',
45 |                                   x - mu, precision, x - mu
46 |                                   ) / 2)
47 |     return density
48 | 
--------------------------------------------------------------------------------
/pyrichlet/weight_models/_equal.py:
--------------------------------------------------------------------------------
1 | from ._base import BaseWeight
2 | from ..exceptions import NotFittedError
3 | 
4 | import numpy as np
5 | 
6 | 
7 | class EqualWeighting(BaseWeight):
8 |     def __init__(self, n=1, rng=None):
9 |         super().__init__(rng=rng)
10 |         self.n = n
11 | 
12 |     def weighting_log_likelihood(self):
13 |         return 0
14 | 
15 |     def random(self, size=None):
16 |         self.w = np.repeat(1 / self.n, self.n)
17 |         return self.w
18 | 
19 |     def complete(self, size):
20 |         return self.random(size)
21 | 
22 |     def fit_variational(self, variational_d: np.ndarray):
23 |         self.variational_d = variational_d
24 |         self.n = self.variational_d.shape[0]
25 |         self.variational_k = self.n
26 | 
27 |     def variational_mean_log_w_j(self, j):
28 |         return np.log(1 / self.n) * (j < self.n)
29 | 
30 |     def variational_mean_log_p_d__w(self, variational_d=None):
31 |         _variational_d = variational_d
32 |         if _variational_d is None:
33 |             _variational_d = self.variational_d
34 |         if _variational_d is None:
35 |             raise NotFittedError
36 |         return _variational_d.shape[1] * np.log(1 / self.n)
37 | 
38 |     def variational_mean_log_p_w(self):
39 |         return 0
40 | 
41 |     def variational_mean_log_q_w(self):
42 |         return 0
43 | 
44 |     def variational_mean_w_j(self, j):
45 |         return 1 / self.n * (j < self.n)
46 | 
47 |     def variational_mode_w_j(self, j):
48 |         return 1 / self.n * (j < self.n)
49 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![Documentation Status](https://readthedocs.org/projects/pyrichlet/badge/?version=main)](https://pyrichlet.readthedocs.io/en/main/?badge=main)
2 | ![PyPI - Version](https://img.shields.io/pypi/v/pyrichlet)
3 | ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/pyrichlet)
4 | 
5 | 
6 | # Project description
7 | 
8 | Pyrichlet is a package for doing density estimation and clustering using
9 | Gaussian mixtures with Bayesian nonparametric (BNP) weighting models.
10 | 
11 | # Installation
12 | 
13 | With pip:
14 | 
15 | ```
16 | pip install pyrichlet
17 | ```
18 | 
19 | For a specific version:
20 | 
21 | ```
22 | pip install pyrichlet==0.0.9
23 | ```
24 | 
25 | 
26 | # Usage
27 | 
28 | This is a quick guide. For more detailed usage, see
29 | https://pyrichlet.readthedocs.io/en/main/index.html.
30 | 
31 | 
32 | The mixture models that this package implements are
33 | 
34 | - `DirichletDistributionMixture`
35 | - `DirichletProcessMixture`
36 | - `PitmanYorMixture`
37 | - `GeometricProcessMixture`
38 | - `BetaInBetaMixture`
39 | - `BetaInDirichletMixture`
40 | - `BetaBernoulliMixture`
41 | - `BetaBinomialMixture`
42 | - `EqualWeightedMixture`
43 | - `FrequencyWeightedMixture`
44 | 
45 | They can be fitted to an array or dataframe using a Gibbs sampler or
46 | variational Bayes methods,
47 | 
48 | ```python
49 | from pyrichlet import mixture_models
50 | 
51 | mm = mixture_models.DirichletProcessMixture()
52 | y = [1, 2, 3, 4]
53 | mm.fit_gibbs(y, init_groups=2)
54 | 
55 | mm.fit_variational(y, n_groups=2)
56 | ```
57 | 
58 | and use the fitted class to do density estimation
59 | 
60 | ```python
61 | x = 2.5
62 | f_x = mm.gibbs_eap_density(x)
63 | f_x = mm.var_eap_density(x)
64 | ```
65 | 
66 | or clustering
67 | 
68 | ```python
69 | mm.var_map_cluster()
70 | mm.gibbs_map_cluster()
71 | mm.gibbs_eap_spectral_consensus_cluster()
72 | ```
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | # Copyright (C) 2020-2021 Fidel Selva
4 | # License: Apache License 2.0
5 | 
6 | import setuptools
7 | import versioneer
8 | 
9 | with open("README.md", "r", encoding="utf-8") as fh:
10 |     long_description = fh.read()
11 | 
12 | setuptools.setup(
13 |     name="pyrichlet",
14 |     version=versioneer.get_version(),
15 |     cmdclass=versioneer.get_cmdclass(),
16 |     author="Fidel Selva",
17 |     author_email="cfso100@gmail.com",
18 |     description="A package for doing density estimation and clustering using "
19 |                 "Gaussian mixtures with BNP weighting models",
20 |     long_description=long_description,
21 |     long_description_content_type="text/markdown",
22 |     url="https://github.com/cabo40/pyrichlet",
23 |     project_urls={
24 |         'Bug Tracker': 'https://github.com/cabo40/pyrichlet/issues',
25 |         'Documentation': 'https://pyrichlet.readthedocs.io',
26 |         'Source Code': 'https://github.com/cabo40/pyrichlet'
27 |     },
28 |     license='Apache License, Version 2.0',
29 |     packages=setuptools.find_packages(exclude=['tests']),
30 |     install_requires=[
31 |         'numpy',
32 |         'scipy',
33 |         'pandas',
34 |         'scikit-learn',
35 |         'tqdm',
36 |         'matplotlib',
37 |         'importlib_resources; python_version < "3.9"',
38 |     ],
39 |     classifiers=[
40 |         "Intended Audience :: Science/Research",
41 |         "Programming Language :: Python",
42 |         "Programming Language :: Python :: 3",
43 |         "Programming Language :: Python :: 3.9",
44 |         "Programming Language :: Python :: 3.10",
45 |         "Programming Language :: Python :: 3.11",
46 |         "License :: OSI Approved :: Apache Software License",
47 |         "Operating System :: OS Independent",
48 |     ],
49 |     python_requires='>=3.7',
50 |     include_package_data=True,
51 | )
52 | 
--------------------------------------------------------------------------------
/docs/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 | 
4 | This document outlines various methods for installing `pyrichlet`.
5 | 
6 | Installing the Latest Version
7 | -----------------------------
8 | 
9 | To install the latest version of a package from the Python Package Index
10 | (PyPI), use the following command:
11 | 
12 | ::
13 | 
14 |     pip install pyrichlet
15 | 
16 | Installing a Specific Version
17 | -----------------------------
18 | 
19 | To install a specific version of a package, use the following command:
20 | 
21 | ::
22 | 
23 |     pip install pyrichlet==<version>
24 | 
25 | Replace ``<version>`` with the desired version number, for example to install
26 | version `0.0.9`:
27 | 
28 | ::
29 | 
30 |     pip install pyrichlet==0.0.9
31 | 
32 | 
33 | Installing from the Repository
34 | ------------------------------
35 | 
36 | There are two ways to install `pyrichlet` from its repository:
37 | 
38 | Using the URL
39 | ^^^^^^^^^^^^^
40 | 
41 | ::
42 | 
43 |     pip install git+https://github.com/cabo40/pyrichlet@<branch>
44 | 
45 | Replace ``<branch>`` with the specific branch you want to install
46 | from (optional, defaults to the default branch `main`).
47 | For example,
48 | 
49 | ::
50 | 
51 |     pip install git+https://github.com/cabo40/pyrichlet@main
52 | 
53 | 
54 | Using a local clone
55 | ^^^^^^^^^^^^^^^^^^^
56 | 
57 | 
58 | Clone the repository to your local machine.
59 | 
60 | ::
61 | 
62 |     git clone https://github.com/cabo40/pyrichlet.git
63 | 
64 | 
65 | Navigate to the cloned repository directory.
66 | 
67 | ::
68 | 
69 |     cd pyrichlet
70 | 
71 | Install as a local package:
72 | 
73 | ::
74 | 
75 |     pip install .
76 | 
77 | Uninstalling
78 | ------------
79 | 
80 | To uninstall `pyrichlet`, regardless of the installation method, run:
81 | 
82 | ::
83 | 
84 |     pip uninstall pyrichlet
85 | 
86 | Afterwards you can remove any directory created during the installation, for
87 | example if it was installed as a local package.
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 | 
7 | # -- Path setup --------------------------------------------------------------
8 | 
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | import sys
15 | 
16 | sys.path.insert(0, os.path.abspath(".."))
17 | 
18 | # -- Project information -----------------------------------------------------
19 | 
20 | project = 'pyrichlet'
21 | copyright = '2021, Fidel Selva'
22 | author = 'Fidel Selva'
23 | 
24 | # The full version, including alpha/beta/rc tags
25 | release = '0.0.9'
26 | 
27 | # -- General configuration ---------------------------------------------------
28 | 
29 | # Add any Sphinx extension module names here, as strings. They can be
30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
31 | # ones.
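32 | # The autosummary directives in docs/models.rst write their stub pages to
33 | # docs/weight_models/ and docs/mixture_models/, which is why docs/.gitignore
34 | # excludes those two directories.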
35 | extensions = [
36 |     'sphinx.ext.autodoc',
37 |     'sphinx.ext.autosummary',
38 |     'numpydoc',
39 |     'sphinx_rtd_theme',
40 |     'sphinx.ext.doctest',
41 | ]
42 | 
43 | autosummary_generate = True
44 | autosummary_imported_members = True
45 | numpydoc_show_class_members = False
46 | 
47 | # Add any paths that contain templates here, relative to this directory.
48 | templates_path = ['_templates']
49 | 
50 | # List of patterns, relative to source directory, that match files and
51 | # directories to ignore when looking for source files.
52 | # This pattern also affects html_static_path and html_extra_path.
53 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
54 | 
55 | # -- Options for HTML output -------------------------------------------------
56 | 
57 | # The theme to use for HTML and HTML Help pages. See the documentation for
58 | # a list of builtin themes.
59 | #
60 | # html_theme = 'alabaster'
61 | html_theme = "sphinx_rtd_theme"
62 | 
63 | # Add any paths that contain custom static files (such as style sheets) here,
64 | # relative to this directory. They are copied after the builtin static files,
65 | # so a file named "default.css" will overwrite the builtin "default.css".
66 | html_static_path = ['_static']
67 | 
--------------------------------------------------------------------------------
/pyrichlet/weight_models/_frequency.py:
--------------------------------------------------------------------------------
1 | from ._base import BaseWeight
2 | from ..exceptions import NotFittedError
3 | 
4 | import numpy as np
5 | 
6 | 
7 | class FrequencyWeighting(BaseWeight):
8 |     def __init__(self, n=1, rng=None):
9 |         super().__init__(rng=rng)
10 |         self.n = n
11 | 
12 |     def weighting_log_likelihood(self):
13 |         return 0
14 | 
15 |     def random(self, size=None):
16 |         if len(self.d) == 0:
17 |             self.w = np.repeat(1 / self.n, self.n)
18 |         else:
19 |             self.w = np.bincount(self.d)
20 |             self.w = self.w / self.w.sum()
21 |         return self.w
22 | 
23 |     def complete(self, size):
24 |         return self.random(size)
25 | 
26 |     def fit_variational(self, variational_d: np.ndarray):
27 |         self.variational_d = variational_d
28 |         self.variational_k = variational_d.shape[1]
29 |         if self.variational_k == 0:
30 |             self.variational_k = self.n
31 | 
32 |     def variational_mean_log_w_j(self, j):
33 |         if self.variational_d is None:
34 |             raise NotFittedError
35 |         if j >= self.variational_k:
36 |             return -np.inf
37 |         if self.variational_d.shape[1]:
38 |             return np.log(self.variational_d.sum(1)[j] / self.variational_d.sum())
39 |         return np.log(1 / self.variational_k)
40 | 
41 |     def variational_mean_log_p_d__w(self, variational_d=None):
42 |         if variational_d is None:
43 |             if self.variational_d is None:
44 |                 raise NotFittedError
45 |             variational_d = self.variational_d
46 |         else:
47 |             self.variational_d = variational_d
48 |         return np.sum(variational_d.sum(1) * np.log(self.variational_d.sum(1) /
49 |                                                     self.variational_d.sum()))
50 | 
51 |     def variational_mean_log_p_w(self):
52 |         if self.variational_d is None:
53 |             raise NotFittedError
54 |         return 0
55 | 
56 |     def variational_mean_log_q_w(self):
57 |         if self.variational_d is None:
58 |             raise NotFittedError
59 |         return 0
60 | 
61 |     def variational_mean_w_j(self, j):
62 |         if j >= self.variational_k:
63 |             return 0
64 |         return self.variational_d.sum(1)[j] / self.variational_d.sum()
65 | 
66 |     def variational_mode_w_j(self, j):
67 |         if j >= self.variational_k:
68 |             return 0
69 |         return self.variational_d.sum(1)[j] / self.variational_d.sum()
70 | 
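# Usage sketch (illustrative; assumes the ``fit``/``random`` API shown in
# docs/usage.rst): once an assignment vector is fitted, ``random()`` simply
# normalizes the bincount of the assignments,
#
#     >>> from pyrichlet import weight_models
#     >>> wm = weight_models.FrequencyWeighting(n=3, rng=0)
#     >>> wm.fit([0, 0, 1, 2])
#     >>> wm.random()
#     array([0.5 , 0.25, 0.25])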
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Pycharm 132 | .idea/ 133 | -------------------------------------------------------------------------------- /pyrichlet/tests/test_mixture_models_variational.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pyrichlet import mixture_models as mm 3 | 4 | from ._base import BaseTest 5 | 6 | 7 | class TestMixtureModels(BaseTest): 8 | def test_dirichlet_distribution(self): 9 | n = 2 10 | mixture = mm.DirichletDistributionMixture(n=n, rng=self.rng) 11 | mixture.fit_variational(self.y) 12 | fitted_density = mixture.var_eap_density() 13 | mean_squared_error = np.power(fitted_density - self.y_density, 14 | 2).mean() 15 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 16 | 17 | def test_dirichlet_process(self): 18 | mixture = mm.DirichletProcessMixture(rng=self.rng) 19 | mixture.fit_variational(self.y, n_groups=10) 20 | fitted_density = mixture.var_eap_density() 21 | mean_squared_error = np.power(fitted_density - self.y_density, 22 | 2).mean() 23 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 24 | 25 | def test_pitman_yor_process(self): 26 | mixture = mm.PitmanYorMixture(rng=self.rng) 27 | mixture.fit_variational(self.y, n_groups=10) 28 | fitted_density = mixture.var_eap_density() 29 | mean_squared_error = np.power(fitted_density - self.y_density, 30 | 2).mean() 31 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 32 | 33 | def test_geometric_process(self): 34 | mixture = mm.GeometricProcessMixture(rng=self.rng) 35 | mixture.fit_variational(self.y, n_groups=10) 36 | fitted_density = mixture.var_eap_density() 37 | mean_squared_error = np.power(fitted_density - self.y_density, 38 | 2).mean() 39 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 40 | 41 | def test_equal_weighting(self): 42 | n = 2 43 | mixture = mm.EqualWeightedMixture(n=n, rng=self.rng) 44 | mixture.fit_variational(self.y) 45 | fitted_density = mixture.var_eap_density() 46 | mean_squared_error = np.power(fitted_density - self.y_density, 47 | 2).mean() 48 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 49 | 50 | def test_frequency_weighting(self): 51 | n = 2 52 | mixture = mm.FrequencyWeightedMixture(n=n, rng=self.rng) 53 | mixture.fit_variational(self.y) 54 | fitted_density = mixture.var_eap_density() 55 | mean_squared_error = np.power(fitted_density - self.y_density, 56 | 2).mean() 57 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 58 | -------------------------------------------------------------------------------- /pyrichlet/utils/_data/chaetocnema.csv: -------------------------------------------------------------------------------- 1 | "","X10","X12","X14","X18","X40","X48","species","area" 2 | "1",191,131,53,150,15,104," A"," A" 3 | "2",185,134,50,147,13,105," A"," A" 4 | "3",200,137,52,144,14,102," A"," A" 5 | "4",173,127,50,144,16,97," A"," A" 6 | "5",171,118,49,153,13,106," A"," A" 7 | "6",160,118,47,140,15,99," A"," 
A" 8 | "7",188,134,54,151,14,98," A"," B" 9 | "8",186,129,51,143,14,110," A"," C" 10 | "9",174,131,52,144,14,116," A"," C" 11 | "10",163,115,47,142,15,95," A"," D" 12 | "11",190,143,52,141,13,99," A"," D" 13 | "12",174,131,50,150,15,105," A"," D" 14 | "13",201,130,51,148,13,110," A"," D" 15 | "14",190,133,53,154,15,106," A"," D" 16 | "15",182,130,51,147,14,105," A"," E" 17 | "16",184,131,51,137,14,95," A"," E" 18 | "17",177,127,49,134,15,105," A"," E" 19 | "18",178,126,53,157,14,116," A"," F" 20 | "19",210,140,54,149,13,107," A"," G" 21 | "20",182,121,51,147,13,111," A"," G" 22 | "21",186,136,56,148,14,111," A"," G" 23 | "22",158,141,58,145,8,107," C"," P" 24 | "23",146,119,51,140,11,111," C"," P" 25 | "24",151,130,51,140,11,113," C"," P" 26 | "25",122,113,45,131,10,102," C"," P" 27 | "26",138,121,53,139,11,106," C"," P" 28 | "27",132,115,49,139,10,98," C"," P" 29 | "28",131,127,51,136,12,107," C"," P" 30 | "29",135,123,50,129,11,107," C"," P" 31 | "30",125,119,51,140,10,110," C"," P" 32 | "31",130,120,48,137,9,106," C"," P" 33 | "32",130,131,51,141,11,108," C"," P" 34 | "33",138,127,52,138,9,101," C"," P" 35 | "34",130,116,52,143,9,111," C"," P" 36 | "35",143,123,54,142,11,95," C"," P" 37 | "36",154,135,56,144,10,123," C"," P" 38 | "37",147,132,54,138,10,102," C"," P" 39 | "38",141,131,51,140,10,106," C"," P" 40 | "39",131,116,47,130,9,102," C"," P" 41 | "40",144,121,53,137,11,104," C"," Q" 42 | "41",137,146,53,137,10,113," C"," R" 43 | "42",143,119,53,136,9,105," C"," R" 44 | "43",135,127,52,140,10,108," C"," R" 45 | "44",186,107,49,120,14,84," B"," A" 46 | "45",211,122,49,123,16,95," B"," A" 47 | "46",201,114,47,130,14,74," B"," A" 48 | "47",242,131,54,131,16,90," B"," A" 49 | "48",184,108,43,116,16,75," B"," A" 50 | "49",211,118,51,122,15,90," B"," A" 51 | "50",217,122,49,127,15,73," B"," A" 52 | "51",223,127,51,132,16,84," B"," A" 53 | "52",208,125,50,125,14,88," B"," B" 54 | "53",199,124,46,119,13,78," B"," C" 55 | "54",211,129,49,122,13,83," B"," C" 56 | "55",218,126,49,120,15,85," B"," C" 57 | "56",203,122,49,119,14,73," B"," C" 58 | "57",192,116,49,123,15,90," B"," C" 59 | "58",195,123,47,125,15,77," B"," D" 60 | "59",211,122,48,125,14,73," B"," D" 61 | "60",187,123,47,129,14,75," B"," D" 62 | "61",192,109,46,130,13,90," B"," E" 63 | "62",223,124,53,129,13,82," B"," E" 64 | "63",188,114,48,122,12,74," B"," E" 65 | "64",216,120,50,129,15,86," B"," H" 66 | "65",185,114,46,124,15,92," B"," I" 67 | "66",178,119,47,120,13,78," B"," L" 68 | "67",187,111,49,119,16,66," B"," L" 69 | "68",187,112,49,119,14,55," B"," L" 70 | "69",201,130,54,133,13,84," B"," L" 71 | "70",187,120,47,121,15,86," B"," L" 72 | "71",210,119,50,128,14,68," B"," M" 73 | "72",196,114,51,129,14,86," B"," M" 74 | "73",195,110,49,124,13,89," B"," N" 75 | "74",187,124,49,129,14,88," B"," O" 76 | -------------------------------------------------------------------------------- /pyrichlet/weight_models/_beta_binomial.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseWeight 2 | import numpy as np 3 | 4 | from ..utils.functions import log_likelihood_beta, log_likelihood_beta_binom, \ 5 | log_likelihood_binom 6 | 7 | 8 | class BetaBinomial(BaseWeight): 9 | def __init__(self, n=0, alpha=1, rng=None): 10 | super().__init__(rng=rng) 11 | self.n = n 12 | self.alpha = alpha 13 | 14 | self.v = np.array([], dtype=np.float64) 15 | self.binomials = np.array([], dtype=int) 16 | 17 | def weighting_log_likelihood(self): 18 | b = self.binomials[0] 19 | res = 
log_likelihood_beta_binom(b, self.n, 1, self.alpha) 20 | for j in range(1, len(self.v) - 1): 21 | v = self.v[j] 22 | res += log_likelihood_beta(v, 1 + b, 23 | self.alpha + self.n - b) 24 | b = self.binomials[j] 25 | res += log_likelihood_binom(b, self.n, v) 26 | 27 | res += log_likelihood_beta(self.v[-1], 1 + b, 28 | self.alpha + self.n - b) 29 | return res 30 | 31 | def random(self, size=None): 32 | if size is None: 33 | if len(self.d) == 0: 34 | raise ValueError( 35 | "Weight structure not fitted and `n` not passed.") 36 | size = 1 37 | self.v = self.v[:0] 38 | if len(self.d) == 0: 39 | self.complete(size) 40 | else: 41 | self._random_binomials() 42 | a_c = np.bincount(self.d) 43 | b_c = np.concatenate((np.cumsum(a_c[::-1])[-2::-1], [0])) 44 | beta_phased = self.binomials[:-1] + self.binomials[1:] 45 | a = 1 + a_c + beta_phased 46 | b = self.alpha + b_c + 2 * self.n - beta_phased 47 | self.v = self._rng.beta(a=a, b=b) 48 | self.w = self.v * np.cumprod(np.concatenate(([1], 49 | 1 - self.v[:-1]))) 50 | if size is not None: 51 | self.complete(size) 52 | return self.w 53 | 54 | def complete(self, size): 55 | super().complete(size) 56 | if len(self.v) == 0: 57 | v0 = self._rng.beta(1, self.alpha) 58 | self.binomials = self._rng.binomial(self.n, v0, size=1) 59 | self.v = self._rng.beta(1 + self.binomials[-1], 60 | self.alpha + self.n - self.binomials[-1], 61 | size=1) 62 | while len(self.v) < size: 63 | self.binomials = np.append(self.binomials, 64 | self._rng.binomial(self.n, self.v[-1])) 65 | self.v = np.append( 66 | self.v, self._rng.beta( 67 | 1 + self.binomials[-1], 68 | self.alpha + self.n - self.binomials[-1]) 69 | ) 70 | self.w = self.v * np.cumprod(np.concatenate(([1], 71 | 1 - self.v[:-1]))) 72 | return self.w 73 | 74 | def _random_binomials(self): 75 | a_c = np.bincount(self.d) 76 | b_c = np.concatenate((np.cumsum(a_c[::-1])[-2::-1], [0])) 77 | a_c = np.append(0, a_c) 78 | b_c = np.append(0, b_c) 79 | beta_rv = self._rng.beta(1 + a_c, self.alpha + b_c) 80 | self.binomials = self._rng.binomial(self.n, beta_rv) 81 | return self.binomials 82 | -------------------------------------------------------------------------------- /pyrichlet/tests/test_mixture_models_gibbs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pyrichlet import mixture_models as mm 3 | 4 | from ._base import BaseTest 5 | 6 | 7 | class TestMixtureModels(BaseTest): 8 | def test_dirichlet_distribution(self): 9 | n = 2 10 | mixture = mm.DirichletDistributionMixture(n=n, rng=self.rng) 11 | mixture.fit_gibbs(self.y) 12 | fitted_density = mixture.gibbs_map_density() 13 | mean_squared_error = np.power(fitted_density - self.y_density, 14 | 2).mean() 15 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 16 | 17 | def test_dirichlet_process(self): 18 | mixture = mm.DirichletProcessMixture(rng=self.rng) 19 | mixture.fit_gibbs(self.y, init_groups=2) 20 | fitted_density = mixture.gibbs_map_density() 21 | mean_squared_error = np.power(fitted_density - self.y_density, 22 | 2).mean() 23 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 24 | 25 | def test_pitman_yor_process(self): 26 | mixture = mm.PitmanYorMixture(rng=self.rng) 27 | mixture.fit_gibbs(self.y, init_groups=2) 28 | fitted_density = mixture.gibbs_map_density() 29 | mean_squared_error = np.power(fitted_density - self.y_density, 30 | 2).mean() 31 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 32 | 33 | def test_geometric_process(self): 34 | mixture = 
mm.GeometricProcessMixture(rng=self.rng) 35 | mixture.fit_gibbs(self.y, init_groups=2) 36 | fitted_density = mixture.gibbs_map_density() 37 | mean_squared_error = np.power(fitted_density - self.y_density, 38 | 2).mean() 39 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 40 | 41 | def test_beta_in_dirichlet(self): 42 | mixture = mm.BetaInDirichletMixture(a=0.1, rng=self.rng) 43 | mixture.fit_gibbs(self.y, init_groups=2) 44 | fitted_density = mixture.gibbs_map_density() 45 | mean_squared_error = np.power(fitted_density - self.y_density, 46 | 2).mean() 47 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 48 | 49 | def test_beta_in_beta(self): 50 | mixture = mm.BetaInBetaMixture(rng=self.rng) 51 | mixture.fit_gibbs(self.y, init_groups=2) 52 | fitted_density = mixture.gibbs_map_density() 53 | mean_squared_error = np.power(fitted_density - self.y_density, 54 | 2).mean() 55 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 56 | 57 | def test_beta_bernoulli(self): 58 | mixture = mm.BetaBernoulliMixture(rng=self.rng) 59 | mixture.fit_gibbs(self.y, init_groups=2) 60 | fitted_density = mixture.gibbs_map_density() 61 | mean_squared_error = np.power(fitted_density - self.y_density, 62 | 2).mean() 63 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 64 | 65 | def test_beta_binomial(self): 66 | mixture = mm.BetaBinomialMixture(rng=self.rng) 67 | mixture.fit_gibbs(self.y, init_groups=2) 68 | fitted_density = mixture.gibbs_map_density() 69 | mean_squared_error = np.power(fitted_density - self.y_density, 70 | 2).mean() 71 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 72 | 73 | def test_equal_weighting(self): 74 | n = 2 75 | mixture = mm.EqualWeightedMixture(n=n, rng=self.rng) 76 | mixture.fit_gibbs(self.y) 77 | fitted_density = mixture.gibbs_map_density() 78 | mean_squared_error = np.power(fitted_density - self.y_density, 79 | 2).mean() 80 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 81 | 82 | def test_frequency_weighting(self): 83 | n = 2 84 | mixture = mm.FrequencyWeightedMixture(n=n, rng=self.rng) 85 | mixture.fit_gibbs(self.y) 86 | fitted_density = mixture.gibbs_map_density() 87 | mean_squared_error = np.power(fitted_density - self.y_density, 88 | 2).mean() 89 | self.assertAlmostEqual(mean_squared_error, 0, places=1) 90 | -------------------------------------------------------------------------------- /pyrichlet/tests/test_weight_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pyrichlet import weight_models as wm 3 | 4 | from ._base import BaseTest 5 | 6 | 7 | class TestWeightModels(BaseTest): 8 | n_sims = int(1e4) 9 | 10 | def test_dirichlet_distribution(self): 11 | n = 5 12 | weight_structure = wm.DirichletDistribution(n=n, rng=self.rng) 13 | w = np.array( 14 | [weight_structure.random()[:2] for _ in range(self.n_sims)]) 15 | self.assertAlmostEqual(w[:, 0].mean(), 1 / n, places=1) 16 | self.assertAlmostEqual(w[:, 0].var(0), (n - 1) / n ** 2 / (n + 1), 17 | places=1) 18 | self.assertAlmostEqual(np.corrcoef(w.T)[0, 1], -1 / (n - 1), places=1) 19 | 20 | def test_dirichlet_process(self): 21 | n = 2 22 | weight_structure = wm.DirichletProcess(rng=self.rng) 23 | w = np.array([weight_structure.random(n) for _ in range(self.n_sims)]) 24 | self.assertAlmostEqual(w[:, 0].mean(), 1 / 2, places=1) 25 | cov_matrix = np.cov(w.T) 26 | self.assertAlmostEqual(cov_matrix[0, 0], 1 / 12, places=1) 27 | self.assertAlmostEqual(cov_matrix[0, 1], -1 / 24, places=1) 28 | 29 | def 
test_pitman_yor_process(self):
30 |         n = 2
31 |         weight_structure = wm.PitmanYorProcess(rng=self.rng)
32 |         w = np.array([weight_structure.random(n) for _ in range(self.n_sims)])
33 |         self.assertAlmostEqual(w[:, 0].mean(), 1 / 2, places=1)
34 |         cov_matrix = np.cov(w.T)
35 |         self.assertAlmostEqual(cov_matrix[0, 0], 1 / 12, places=1)
36 |         self.assertAlmostEqual(cov_matrix[0, 1], -1 / 24, places=1)
37 | 
38 |     def test_geometric_process(self):
39 |         n = 2
40 |         weight_structure = wm.GeometricProcess(rng=self.rng)
41 |         w = np.array([weight_structure.random(n) for _ in range(self.n_sims)])
42 |         self.assertAlmostEqual(w[:, 0].mean(), 1 / 2, places=1)
43 |         cov_matrix = np.cov(w.T)
44 |         self.assertAlmostEqual(cov_matrix[0, 0], 1 / 12, places=1)
45 |         self.assertAlmostEqual(cov_matrix[0, 1], 0, places=1)
46 | 
47 |     def test_beta_in_dirichlet(self):
48 |         n = 2
49 |         weight_structure = wm.BetaInDirichlet(rng=self.rng)
50 |         w = np.array([weight_structure.random(n) for _ in range(self.n_sims)])
51 |         self.assertAlmostEqual(w[:, 0].mean(), 1 / 2, places=1)
52 |         cov_matrix = np.cov(w.T)
53 |         self.assertAlmostEqual(cov_matrix[0, 0], 1 / 12, places=1)
54 |         self.assertAlmostEqual(cov_matrix[0, 1], 0, places=1)
55 | 
56 |     def test_beta_in_beta(self):
57 |         n = 2
58 |         weight_structure = wm.BetaInBeta(rng=self.rng)
59 |         w = np.array([weight_structure.random(n) for _ in range(self.n_sims)])
60 |         self.assertAlmostEqual(w[:, 0].mean(), 1 / 2, places=1)
61 |         cov_matrix = np.cov(w.T)
62 |         self.assertAlmostEqual(cov_matrix[0, 0], 1 / 12, places=1)
63 |         self.assertAlmostEqual(cov_matrix[0, 1], - 1 / 24, places=1)
64 | 
65 |     def test_beta_bernoulli(self):
66 |         n = 2
67 |         weight_structure = wm.BetaBernoulli(rng=self.rng)
68 |         w = np.array([weight_structure.random(n) for _ in range(self.n_sims)])
69 |         self.assertAlmostEqual(w[:, 0].mean(), 1 / 2, places=1)
70 |         cov_matrix = np.cov(w.T)
71 |         self.assertAlmostEqual(cov_matrix[0, 0], 1 / 12, places=1)
72 |         self.assertAlmostEqual(cov_matrix[0, 1], - 1 / 24, places=1)
73 | 
74 |     def test_beta_binomial(self):
75 |         n = 2
76 |         weight_structure = wm.BetaBinomial(rng=self.rng)
77 |         w = np.array([weight_structure.random(n) for _ in range(self.n_sims)])
78 |         self.assertAlmostEqual(w[:, 0].mean(), 1 / 2, places=1)
79 |         cov_matrix = np.cov(w.T)
80 |         self.assertAlmostEqual(cov_matrix[0, 0], 1 / 12, places=1)
81 |         self.assertAlmostEqual(cov_matrix[0, 1], - 1 / 24, places=1)
82 | 
83 |     def test_equal_weighting(self):
84 |         n = 100
85 |         weight_structure = wm.EqualWeighting(n=n, rng=self.rng)
86 |         w = weight_structure.random()
87 |         self.assertAlmostEqual(w.var(), 0)
88 |         self.assertAlmostEqual(w[0], 1 / n)
89 | 
90 |     def test_frequency_weighting(self):
91 |         n = 100
92 |         weight_structure = wm.FrequencyWeighting(n=n, rng=self.rng)
93 |         w = weight_structure.random()
94 |         self.assertAlmostEqual(w.var(), 0)
95 |         self.assertAlmostEqual(w[0], 1 / n)
96 | 
--------------------------------------------------------------------------------
/docs/usage.rst:
--------------------------------------------------------------------------------
1 | Usage
2 | =====
3 | .. note::
4 |     The weighting structure models that this package implements are listed under
5 |     :doc:`models`.
6 | 
7 | .. hint::
8 |     Code blocks using the ``rng`` parameter use it only for reproducibility
9 |     purposes.
10 | 
11 | There are two families of classes in `pyrichlet`, weighting models and mixture
12 | models.
13 | A weighting model object can be imported and initialized as:
14 | 
15 | .. doctest::
16 | 
17 |     >>> from pyrichlet import weight_models
18 |     >>> wm = weight_models.DirichletProcess(rng=0)
19 | 
20 | then `wm` can be used to draw a sample vector of weights with length `10` from
21 | its prior distribution
22 | 
23 | .. doctest::
24 | 
25 |     >>> wm.random(10)
26 |     array([7.02467947e-01, 2.12012016e-01, 8.52339410e-02, 2.75312557e-04,
27 |            8.69186887e-06, 8.68128787e-07, 2.27208506e-07, 6.14190832e-07,
28 |            6.03943713e-08, 2.02785792e-07])
29 | 
30 | we can get a new realization for the vector by running ``wm.random`` again, or
31 | get more weights for the same realization with
32 | 
33 | .. doctest::
34 | 
35 |     >>> wm.complete(12)
36 |     array([7.02467947e-01, 2.12012016e-01, 8.52339410e-02, 2.75312557e-04,
37 |            8.69186887e-06, 8.68128787e-07, 2.27208506e-07, 6.14190832e-07,
38 |            6.03943713e-08, 2.02785792e-07, 7.66088536e-08, 2.75049683e-08])
39 | 
40 | or get `100` independent random assignments using the current truncated
41 | structure
42 | 
43 | .. doctest::
44 | 
45 |     >>> wm.random_assignment(100)
46 |     array([0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
47 |            0, 2, 0, 0, 0, 2, 0, 2, 0, 0, 0, 2, 2, 0, 1, 0, 0, 1, 0, 1, 1, 2,
48 |            0, 1, 2, 2, 0, 1, 2, 2, 0, 2, 1, 1, 0, 0, 1, 2, 0, 0, 0, 2, 0, 1,
49 |            0, 0, 1, 0, 1, 0, 2, 0, 1, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0,
50 |            1, 0, 0, 0, 1, 0, 2, 0, 0, 0, 1, 0])
51 | 
52 | these methods can be applied to the posterior distribution after fitting a
53 | dataset of assignments.
54 | Once fitted, the minimum vector size can be inferred from the data
55 | 
56 | .. doctest::
57 | 
58 |     >>> wm.fit([0, 1, 2, 3, 3, 3, 5])
59 |     >>> wm.random()
60 |     array([0.22802716, 0.14438267, 0.2624806 , 0.3200459 , 0.0360923 ,
61 |            0.00697771])
62 |     >>> wm.random_assignment(20)
63 |     array([1, 3, 3, 2, 1, 5, 1, 0, 3, 3, 3, 4, 3, 3, 3, 1, 0, 3, 3, 0])
64 | 
65 | 
66 | The fitted data can be replaced by calling the ``fit`` method again, or by
67 | resetting the weighting structure
68 | 
69 | .. doctest::
70 | 
71 |     >>> wm.reset()
72 | 
73 | For each weighting structure there is an associated Gaussian mixture model.
74 | 
75 | .. doctest::
76 | 
77 |     >>> from pyrichlet import mixture_models
78 |     >>> mm = mixture_models.DirichletProcessMixture(rng=0)
79 | 
80 | The mixture models can fit data represented as an array or as a dataframe
81 | 
82 | .. doctest::
83 | 
84 |     >>> mm.fit_gibbs([1, 2, 3, 4], init_groups=2)
85 | 
86 | we can get the EAP density at a single point
87 | 
88 | .. doctest::
89 | 
90 |     >>> mm.gibbs_eap_density(2.5)
91 |     array([0.25341657])
92 | 
93 | or at several
94 | 
95 | .. doctest::
96 | 
97 |     >>> mm.gibbs_eap_density([1.5, 2.5, 3.5])
98 |     array([0.18017642, 0.25341657, 0.19041053])
99 | 
100 | weighting structures can also be fitted using variational inference, after
101 | which we can calculate the EAP density
102 | 
103 | .. doctest::
104 | 
105 |     >>> mm.fit_variational([1, 2, 3, 4], n_groups=2)
106 |     >>> mm.var_eap_density([1.5, 2.5, 3.5])
107 |     array([0.20426694, 0.32282514, 0.20426694])
108 | 
109 | mixture models can also be used for clustering
110 | 
111 | .. doctest::
112 | 
113 |     >>> mm.var_map_cluster()
114 |     array([0, 0, 1, 1])
115 |     >>> mm.gibbs_map_cluster()
116 |     array([0, 0, 0, 0])
117 |     >>> mm.gibbs_eap_spectral_consensus_cluster()
118 |     array([0, 0, 0, 0], dtype=int32)
119 | 
120 | Depending on the dataset, fitting can take a noticeable time to finish.
121 | To show the progress of the fitting method, the parameter `show_progress` can
122 | be set
123 | 
124 | 
125 | .. 
code-block:: python 126 | 127 | >>> mm.fit_gibbs([1, 2, 3, 4], init_groups=2, show_progress=True) 128 | -------------------------------------------------------------------------------- /pyrichlet/weight_models/_dirichlet_distribution.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseWeight 2 | from ..exceptions import NotFittedError 3 | from ..utils.functions import mean_log_beta 4 | 5 | import numpy as np 6 | from scipy.stats import dirichlet 7 | from scipy.special import loggamma 8 | 9 | 10 | class DirichletDistribution(BaseWeight): 11 | def __init__(self, n=1, alpha=1, rng=None): 12 | super().__init__(rng=rng) 13 | assert type(n) == int, "parameter n must be of type int" 14 | self.n = n 15 | if type(alpha) in (list, np.ndarray): 16 | self.n = len(alpha) 17 | self.alpha = np.array(alpha, dtype=np.float64) 18 | elif type(alpha) in (int, float): 19 | self.alpha = np.array([alpha] * self.n, dtype=np.float64) 20 | 21 | def weighting_log_likelihood(self): 22 | if len(self.w) == 0: 23 | return 0 24 | return dirichlet.logpdf(self.w, self.alpha) 25 | 26 | def random(self, size=None): 27 | if len(self.d) > 0: 28 | if max(self.d) >= len(self.alpha): 29 | raise ValueError( 30 | 'fitted structure is incompatible with this model' 31 | ) 32 | else: 33 | a_c = np.bincount(self.d) 34 | a_c.resize(len(self.alpha), refcheck=False) 35 | self.w = self._rng.dirichlet(self.alpha + a_c) 36 | else: 37 | self.w = self._rng.dirichlet(self.alpha) 38 | return self.w 39 | 40 | def complete(self, size=None): 41 | super().complete(size) 42 | if len(self.w) == 0: 43 | self.random() 44 | return self.w 45 | 46 | def fit_variational(self, variational_d): 47 | assert len(variational_d) == self.n, "variational distribution must " \ 48 | "have the same length as the " \ 49 | "Dirichlet distribution's " \ 50 | "dimension" 51 | self.variational_k = self.n 52 | self.variational_d = variational_d 53 | self.variational_params = self.alpha + np.sum(self.variational_d, 1) 54 | 55 | def variational_mean_log_w_j(self, j): 56 | if self.variational_d is None: 57 | raise NotFittedError 58 | return mean_log_beta(self.variational_params[j], 59 | self.variational_params.sum() - self.variational_params[j]) 60 | 61 | def variational_mean_log_p_d__w(self, variational_d=None): 62 | if variational_d is None: 63 | _variational_d = self.variational_d 64 | if _variational_d is None: 65 | raise NotFittedError 66 | else: 67 | _variational_d = variational_d 68 | res = 0 69 | for j, nj in enumerate(np.sum(_variational_d, 1)): 70 | res += nj * self.variational_mean_log_w_j(j) 71 | return res 72 | 73 | def variational_mean_log_p_w(self): 74 | if self.variational_d is None: 75 | raise NotFittedError 76 | log_sum_w_j = 0 77 | for j in range(self.variational_k): 78 | log_sum_w_j += ((self.alpha[j] - 1) * 79 | self.variational_mean_log_w_j(j)) 80 | res = self._log_normalization_constant(self.alpha) 81 | res += log_sum_w_j 82 | return res 83 | 84 | def variational_mean_log_q_w(self): 85 | if self.variational_d is None: 86 | raise NotFittedError 87 | res = 0 88 | for j in range(self.variational_k): 89 | res += ((self.variational_params[j] - 1) * 90 | self.variational_mean_log_w_j(j)) 91 | res += self._log_normalization_constant( 92 | self.variational_params 93 | ) 94 | return res 95 | 96 | def variational_mean_w_j(self, j): 97 | if j >= self.variational_k: 98 | return 0 99 | return self.variational_params[j] / self.variational_params.sum() 100 | 101 | def variational_mode_w_j(self, j): 102 | if j >= self.variational_k: 103 |
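            # Components at or beyond the variational truncation carry no
            # mass under q, so their mode (like their mean) is zero.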
return 0 104 | alpha = self.variational_params.sum() 105 | if self.variational_params[j] <= 1: 106 | if alpha - self.variational_params[j] <= 1: 107 | raise ValueError('multimodal distribution') 108 | else: 109 | return 0 110 | elif alpha - self.variational_params[j] <= 1: 111 | return 1 112 | res = ((self.variational_params[j] - 1) / 113 | (alpha - 2)) 114 | return res 115 | 116 | def _log_normalization_constant(self, alpha=None): 117 | if alpha is None: 118 | _alpha = self.alpha 119 | else: 120 | _alpha = alpha 121 | log_sum = loggamma(np.sum(_alpha)) 122 | for a_j in _alpha: 123 | log_sum -= loggamma(a_j) 124 | return log_sum 125 | -------------------------------------------------------------------------------- /pyrichlet/weight_models/_geometric_process.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseWeight 2 | from ..exceptions import NotFittedError 3 | from ..utils.functions import mean_log_beta, log_likelihood_beta 4 | 5 | import numpy as np 6 | from scipy.special import loggamma 7 | 8 | 9 | class GeometricProcess(BaseWeight): 10 | def __init__(self, a=1, b=1, rng=None): 11 | super().__init__(rng=rng) 12 | self.a = a 13 | self.b = b 14 | self.p = self._rng.beta(a=self.a, b=self.b) 15 | 16 | self.v = np.array([], dtype=np.float64) 17 | 18 | def weighting_log_likelihood(self): 19 | ret = log_likelihood_beta(self.p, self.a, self.b) 20 | return ret 21 | 22 | def random(self, size=None): 23 | if size is None and len(self.d) == 0: 24 | raise ValueError("Weight structure not fitted and `n` not passed.") 25 | if size is None: 26 | size = max(self.d) + 1 27 | self.v = self.v[:0] 28 | self.p = self._rng.beta(self.a + len(self.d), self.b + self.d.sum()) 29 | self.complete(size) 30 | return self.w 31 | 32 | def complete(self, size): 33 | super().complete(size) 34 | if len(self.v) < size: 35 | self.v = np.repeat(self.p, size) 36 | self.w = self.v * np.cumprod(np.concatenate(([1], 37 | 1 - self.v[:-1]))) 38 | return self.w 39 | 40 | def tail(self, x): 41 | if x >= 1 or x < 0: 42 | raise ValueError("Tail parameter not in range [0,1)") 43 | size = int(np.log(1 - x) / np.log(1 - self.p)) 44 | self.complete(size) 45 | return self.w 46 | 47 | def fit_variational(self, variational_d): 48 | self.variational_d = variational_d 49 | self.variational_k = len(self.variational_d) 50 | self.variational_params = np.empty(2, dtype=np.float64) 51 | self.variational_params[0] = self.a + max(len(self.variational_d[0]), 52 | 1) - 1 53 | self.variational_params[1] = ( 54 | self.b + (self.variational_d[1:].T * 55 | range(1, self.variational_k)).sum() 56 | ) 57 | 58 | def variational_mean_log_w_j(self, j): 59 | if self.variational_d is None: 60 | raise NotFittedError 61 | res = mean_log_beta(self.variational_params[0], 62 | self.variational_params[1] 63 | ) 64 | if j > 0: 65 | res += mean_log_beta(self.variational_params[1], 66 | self.variational_params[0] 67 | ) * j 68 | return res 69 | 70 | def variational_mean_log_p_d__w(self, variational_d=None): 71 | if variational_d is None: 72 | _variational_d = self.variational_d 73 | if _variational_d is None: 74 | raise NotFittedError 75 | else: 76 | _variational_d = variational_d 77 | res = mean_log_beta(self.variational_params[0], 78 | self.variational_params[1] 79 | ) * len(_variational_d[0]) 80 | e_log_v_bar = mean_log_beta(self.variational_params[1], 81 | self.variational_params[0] 82 | ) 83 | for j, nj in enumerate(np.sum(_variational_d, 1)): 84 | res += nj * e_log_v_bar 85 | return res 86 | 87 | def 
variational_mean_log_p_w(self): 88 | if self.variational_d is None: 89 | raise NotFittedError 90 | params = self.variational_params 91 | res = mean_log_beta(params[0], params[1]) * (self.a - 1) 92 | res += mean_log_beta(params[1], params[0]) * (self.b - 1) 93 | res += loggamma(self.a + self.b) 94 | res -= loggamma(self.a) + loggamma(self.b) 95 | return res 96 | 97 | def variational_mean_log_q_w(self): 98 | if self.variational_d is None: 99 | raise NotFittedError 100 | params = self.variational_params 101 | res = mean_log_beta(params[0], params[1]) * (params[0] - 1) 102 | res += mean_log_beta(params[1], params[0]) * (params[1] - 1) 103 | res += loggamma(params[0] + params[1]) 104 | res -= loggamma(params[0]) + loggamma(params[1]) 105 | return res 106 | 107 | def variational_mean_w_j(self, j): 108 | if j >= self.variational_k: 109 | return 1 110 | p = self.variational_params[0] / self.variational_params.sum() 111 | return p * (1 - p) ** j 112 | 113 | def variational_mode_w_j(self, j): 114 | if j > self.variational_k: 115 | return 116 | if self.variational_params[0] <= 1: 117 | if self.variational_params[1] <= 1: 118 | raise ValueError('multimodal distribution') 119 | else: 120 | return 0 121 | elif self.variational_params[1] <= 1: 122 | return 1 * (j == 0) 123 | p = ((self.variational_params[0] - 1) / 124 | (self.variational_params.sum() - 2)) 125 | res = (1 - p) ** j * p 126 | return res 127 | -------------------------------------------------------------------------------- /pyrichlet/weight_models/_beta_bernoulli.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseWeight 2 | import numpy as np 3 | from scipy.special import beta as betaf 4 | 5 | from ..utils.functions import log_likelihood_beta 6 | 7 | 8 | class BetaBernoulli(BaseWeight): 9 | def __init__(self, p=1, alpha=1, rng=None): 10 | super().__init__(rng=rng) 11 | self.p = p 12 | self.alpha = alpha 13 | self.v = np.array([], dtype=np.float64) 14 | self.bernoullis = np.array([1], dtype=int) 15 | 16 | def weighting_log_likelihood(self): 17 | ret = self._bernoulli_structure_log_likelihood() 18 | ret += self._beta_structure_log_likelihood() 19 | return ret 20 | 21 | def _bernoulli_structure_log_likelihood(self): 22 | ret = 0 23 | if self.p == 0: 24 | return ret 25 | for b in self.bernoullis: 26 | ret += b * np.log(self.p) 27 | return ret 28 | 29 | def _beta_structure_log_likelihood(self): 30 | ret = log_likelihood_beta(self.v[0], 1, self.alpha) 31 | for j in range(1, len(self.v)): 32 | if not self.bernoullis[j]: 33 | ret += log_likelihood_beta(self.v[j], 1, self.alpha) 34 | return ret 35 | 36 | def random(self, size=None): 37 | if size is None: 38 | if len(self.d) == 0: 39 | raise ValueError( 40 | "Weight structure not fitted and `n` not passed.") 41 | size = 1 42 | self.v = self.v[:0] 43 | if len(self.d) == 0: 44 | self._random_bernoullis(size) 45 | mask_change = self.bernoullis 46 | mask_change = np.cumsum(mask_change) 47 | self.v = self._rng.beta(a=1, b=self.alpha, size=mask_change[-1] + 1) 48 | self.v = self.v[mask_change] 49 | self.w = self.v * np.cumprod(np.concatenate(([1], 50 | 1 - self.v[:-1]))) 51 | else: 52 | self._random_bernoullis(self.d.max() + 1) 53 | mask_change = self.bernoullis 54 | mask_change = np.cumsum(mask_change) 55 | a_c = np.bincount(self.d) 56 | b_c = np.concatenate((np.cumsum(a_c[::-1])[-2::-1], [0])) 57 | 58 | a_c = np.bincount(mask_change, a_c) 59 | b_c = np.bincount(mask_change, b_c) 60 | 61 | self.v = self._rng.beta(a=1 + a_c, b=self.alpha + b_c) 
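            # Sticks between two successive bernoulli == 1 events share a
            # single beta variable: the assignment counts a_c and survival
            # counts b_c are pooled per shared stick (np.bincount over
            # `mask_change`) before one posterior beta draw per group, which
            # is expanded back to one v per stick below.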
62 | self.v = self.v[mask_change] 63 | self.w = self.v * np.cumprod(np.concatenate(([1], 64 | 1 - self.v[:-1]))) 65 | self.complete(size) 66 | return self.w 67 | 68 | def complete(self, size): 69 | super().complete(size) 70 | if len(self.v) < size: 71 | if len(self.v) == 0: 72 | self.v = self._rng.beta(a=1, b=self.alpha, size=1) 73 | mask_change = self._rng.binomial(n=1, 74 | p=self.p, 75 | size=size - len(self.v)) 76 | self.bernoullis = np.concatenate((self.bernoullis, mask_change)) 77 | mask_change = np.cumsum(mask_change) 78 | temp_v = np.concatenate(( 79 | [self.v[-1]], 80 | self._rng.beta(a=1, b=self.alpha, size=mask_change[-1]))) 81 | self.v = np.concatenate((self.v, temp_v[mask_change])) 82 | self.w = self.v * np.cumprod(np.concatenate(([1], 83 | 1 - self.v[:-1]))) 84 | return self.w 85 | 86 | def _random_bernoullis(self, size): 87 | bernoullis = self._rng.binomial(n=1, p=self.p, size=size) 88 | self.bernoullis[0] = 1 89 | if len(self.d) > 0: 90 | size_fit = self.d.max() 91 | a_c = np.bincount(self.d) 92 | b_c = np.concatenate((np.cumsum(a_c[::-1])[-2::-1], [0])) 93 | bernoullis[0] = 0 94 | for j in range(1, size_fit): 95 | a_j_prime, b_j_prime, g_plus = 0, 0, 1 96 | k = j + 1 97 | for b in bernoullis[j + 1:]: 98 | if b == 1: 99 | break 100 | a_j_prime += a_c[k] 101 | b_j_prime += b_c[k] 102 | k += 1 103 | g_plus = betaf(a_j_prime, b_j_prime) * self.alpha 104 | a_j_prime += a_c[j] 105 | b_j_prime += b_c[j] 106 | k = j - 1 107 | for b in bernoullis[:j][::-1]: 108 | if b == 1: 109 | break 110 | a_j_prime += a_c[k] 111 | b_j_prime += b_c[k] 112 | k -= 1 113 | g_minus = betaf(a_j_prime, b_j_prime) * self.alpha 114 | p = self.p 115 | p_times_plus = p * g_plus if p > 0 else 0 116 | not_p_times_minus = (1 - p) * g_minus if p < 1 else 0 117 | p = p_times_plus / (p_times_plus + not_p_times_minus) 118 | bernoullis[j] = self._rng.binomial(n=1, p=p) 119 | self.bernoullis = bernoullis 120 | return self.bernoullis 121 | -------------------------------------------------------------------------------- /pyrichlet/mixture_models/_utils.py: -------------------------------------------------------------------------------- 1 | from scipy.stats import invwishart, multivariate_normal, norm 2 | from scipy.special import digamma, loggamma 3 | from sklearn.cluster import KMeans 4 | import numpy as np 5 | 6 | 7 | def random_normal_invw(mu, lam, psi, nu, rng=None): 8 | if rng is None: 9 | rng = np.random.default_rng() 10 | ret_sigma = invwishart.rvs(nu, psi, 11 | random_state=rng) 12 | ret_mu = multivariate_normal.rvs(mu, ret_sigma / lam, 13 | random_state=rng) 14 | return ret_mu, ret_sigma 15 | 16 | 17 | def log_likelihood_normal_invw(mu, sigma, mu0, lam0, psi0, nu0): 18 | log_sigma = invwishart.logpdf(sigma, nu0, psi0) 19 | log_mu = multivariate_normal.logpdf(mu, mu0, sigma / lam0) 20 | return log_sigma + log_mu 21 | 22 | 23 | def posterior_norm_invw_params(y, mu, lam, psi, nu): 24 | n = len(y) 25 | ret_mu = (lam * mu + n * y.mean(axis=0)) / (lam + n) 26 | ret_lam = lam + n 27 | ret_psi = psi + n * np.cov(y.T, bias=True) + ( 28 | (lam * n) / (lam + n) * np.outer(y.mean(axis=0) - mu, 29 | y.mean(axis=0) - mu)) 30 | ret_nu = nu + n 31 | return {"mu": ret_mu, "lambda": ret_lam, "psi": ret_psi, "nu": ret_nu} 32 | 33 | 34 | def gumbel_max_sampling(logp, size=None, *, rng=None): 35 | if size is None: 36 | ret = np.argmax(logp - 37 | np.log(-np.log(rng.uniform(size=logp.shape))), axis=0) 38 | else: 39 | ret = [] 40 | for i in range(size): 41 | ret.append(np.argmax(logp - 42 | 
np.log(-np.log(rng.uniform(size=len(logp)))))) 43 | ret = np.array(ret) 44 | return ret 45 | 46 | 47 | def rejection_sample(f, max_y, a=0, b=1, size=None, *, rng=None): 48 | if size is None: 49 | x = rng.uniform(a, b) 50 | y = rng.uniform(0, max_y) 51 | while y > f(x): 52 | x = rng.uniform(a, b) 53 | y = rng.uniform(0, max_y) 54 | return x 55 | else: 56 | x = rng.uniform(a, b, size) 57 | y = rng.uniform(0, max_y, size) 58 | while np.any(y > f(x)): 59 | x[y > f(x)] = rng.uniform(a, b, np.sum(y > f(x))) 60 | y[y > f(x)] = rng.uniform(0, max_y, np.sum(y > f(x))) 61 | return x 62 | 63 | 64 | def mixture_density(x, w, theta, dim=None, component=None): 65 | k = len(w) 66 | 67 | ret = [] 68 | if component is None: 69 | iterator = range(k) 70 | else: 71 | iterator = [component] 72 | for j in iterator: 73 | if dim is None: 74 | ret.append(multivariate_normal.pdf(x, 75 | theta[j][0], 76 | theta[j][1], 77 | 1)*w[j]) 78 | elif type(dim) in [list, np.ndarray]: 79 | ret.append(multivariate_normal.pdf(x, 80 | theta[j][0][dim], 81 | theta[j][1][:, dim][dim, :], 82 | 1)*w[j]) 83 | elif type(dim) is int: 84 | ret.append(norm.pdf(x[:, 0], 85 | theta[j][0][dim], 86 | np.sqrt(theta[j][1][dim, dim]))*w[j]) 87 | ret = np.array(ret).T 88 | ret = np.atleast_2d(ret).sum(1) 89 | return ret 90 | 91 | 92 | def cluster(x, w, theta): 93 | k = len(w) 94 | assign_prob = [] 95 | for j in range(k): 96 | assign_prob.append(multivariate_normal.pdf(x, 97 | theta[j][0], 98 | theta[j][1], 99 | 1)) 100 | assign_prob = np.array(assign_prob).T 101 | assign_prob = assign_prob * w 102 | assign_prob /= assign_prob.sum(1)[:, None] 103 | grp = np.argmax(assign_prob, axis=1) 104 | uncertainty = 1 - assign_prob[range(len(x)), grp] 105 | u_grp, ret = np.unique(grp, return_inverse=True) 106 | return ret, uncertainty 107 | 108 | 109 | def log_wishart_normalization_term(precision, scale): 110 | dim = precision.shape[0] 111 | res = np.log(np.linalg.norm(precision)) * (-scale / 2) 112 | inverse_term = np.log(2) * (scale * dim / 2) 113 | inverse_term += np.log(np.pi) * (dim * (dim - 1) / 4) 114 | for i in range(dim): 115 | inverse_term += loggamma((scale - i) / 2) 116 | res -= inverse_term 117 | return res 118 | 119 | 120 | def e_log_norm_wishart(precision, scale): 121 | dim = precision.shape[0] 122 | res = np.log(np.linalg.norm(precision)) 123 | res += dim * np.log(2) 124 | for i in range(dim): 125 | res += digamma((scale - i) / 2) 126 | return res 127 | 128 | 129 | def entropy_wishart(precision, scale): 130 | dim = precision.shape[0] 131 | res = -log_wishart_normalization_term(precision, scale) 132 | res -= (scale - dim - 1) / 2 * e_log_norm_wishart(precision, scale) 133 | res += scale * dim / 2 134 | return res 135 | 136 | 137 | def kmeans_cluster_size_biased(y, k, rng): 138 | # TODO wait for sklearn to implement Generator as random_state 139 | # input https://github.com/scikit-learn/scikit-learn/issues/16988 140 | km = KMeans( 141 | n_clusters=k, 142 | n_init=10, 143 | random_state=np.random.RandomState(rng.bit_generator) 144 | ) 145 | d = km.fit_predict(y) 146 | _, d, d_count = np.unique(d, return_inverse=True, return_counts=True) 147 | # Rename assignation vector by size bias, the group with more counts gets 148 | # the label 0, and so on 149 | d = np.argsort(np.argsort(-d_count))[d] 150 | return d 151 | -------------------------------------------------------------------------------- /pyrichlet/weight_models/_beta_in_dirichlet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 
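# Stick-breaking model in which the proportions v_j come from a discrete
# random base: the atoms `_v_base` are beta draws and `_d_base` assigns each
# stick to an atom with urn-style probabilities (counts `_count_base` versus
# innovation mass `a`), so distinct sticks may share the same proportion.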
from scipy.special import loggamma 3 | 4 | from ._base import BaseWeight 5 | from ..mixture_models._utils import gumbel_max_sampling 6 | from ..utils.functions import log_likelihood_beta, dirichlet_log_eppf 7 | 8 | 9 | class BetaInDirichlet(BaseWeight): 10 | def __init__(self, alpha=1, a=0, rng=None): 11 | super().__init__(rng=rng) 12 | self.a = a 13 | self.alpha = alpha 14 | self.v = np.array([], dtype=np.float64) 15 | self._v_base = np.array([], dtype=np.float64) 16 | self._d_base = np.array([], dtype=np.int64) 17 | self._count_base = [] 18 | 19 | def weighting_log_likelihood(self): 20 | v_unique, v_counts = np.unique(self.v, return_counts=True) 21 | ret = 0 22 | for vj in v_unique: 23 | ret += log_likelihood_beta(vj, 1, self.alpha) 24 | ret += dirichlet_log_eppf(self.a, v_counts) 25 | return ret 26 | 27 | def random(self, size=None, u=None): 28 | if size is None and len(self.d) == 0: 29 | raise ValueError("Weight structure not fitted and `n` not passed.") 30 | if size is not None: 31 | if type(size) is not int: 32 | raise TypeError("size parameter must be integer or None") 33 | self.v = self.v[:0] 34 | self._v_base = self._v_base[:0] 35 | if len(self.d) == 0: 36 | self.complete(size) 37 | return self.w 38 | n = max(self.d) + 1 39 | if len(self._d_base) < n: 40 | self._d_base = np.concatenate( 41 | [self._d_base, [0] * (n - len(self._d_base))]) 42 | elif len(self._d_base) > n: 43 | self._d_base = self._d_base[:n] 44 | k = max(self._d_base) + 1 45 | v_base = np.empty(k, dtype=np.float64) 46 | a_c = np.bincount(self.d) 47 | b_c = np.concatenate((np.cumsum(a_c[::-1])[-2::-1], [0])) 48 | # Update inner weights given the inner assignations 49 | for jj in range(k): 50 | a_c_base = np.sum(a_c[self._d_base == jj]) 51 | b_c_base = np.sum(b_c[self._d_base == jj]) 52 | v_base[jj] = self._rng.beta(a=1 + a_c_base, 53 | b=self.alpha + b_c_base) 54 | self._v_base = v_base 55 | # Update the inner assignations given inner weights and other 56 | # assignations. It is, update inner d_jj given d_{-jj}, inner v. 
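        # For each stick j, the candidate log-probabilities are: an existing
        # inner atom dd, with weight log(count) + a_c[j] * log(v_dd)
        # + b_c[j] * log(1 - v_dd); or a new atom, with weight log(a) plus
        # the marginal beta term. A value is drawn by Gumbel-max sampling.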
57 | for j in range(n): 58 | if self.a == 0: 59 | if j < len(self._d_base): 60 | self._d_base[j] = 0 61 | else: 62 | self._d_base = np.append(self._d_base, 0) 63 | continue 64 | d_base_reduced = np.delete(self._d_base, j) 65 | k = np.max(d_base_reduced) 66 | log_prob = [] 67 | new_base = False 68 | for dd in range(k + 2): 69 | base_count = sum(d_base_reduced == dd) 70 | if base_count == 0 and not new_base: 71 | new_base = True 72 | temp_log_prob = ( 73 | np.log(self.a) + loggamma(1 + a_c[j]) 74 | + loggamma(self.alpha + b_c[j]) 75 | - loggamma(1 + self.alpha + a_c[j] + b_c[j])) 76 | log_prob.append(temp_log_prob) 77 | continue 78 | if base_count == 0: 79 | if dd < k + 1: 80 | log_prob.append(-np.inf) 81 | continue 82 | temp_log_prob = ( 83 | np.log(base_count) 84 | + a_c[j] * np.log(self._v_base[dd]) 85 | + b_c[j] * np.log(1 - self._v_base[dd])) 86 | log_prob.append(temp_log_prob) 87 | log_prob = np.array(log_prob) 88 | if j < len(self._d_base): 89 | self._d_base[j] = gumbel_max_sampling(log_prob, rng=self._rng) 90 | else: 91 | self._d_base = np.append(self._d_base, 92 | gumbel_max_sampling(log_prob, 93 | rng=self._rng)) 94 | if max(self._d_base) <= len(self._v_base): 95 | self._v_base = np.concatenate(( 96 | self._v_base, [self._rng.beta(a=1, b=self.alpha)])) 97 | self._count_base = [ 98 | np.sum(self._d_base == j) for j in range(len(self._v_base))] 99 | self.v = self._v_base[self._d_base] 100 | self.w = self.v * np.cumprod(np.concatenate(([1], 101 | 1 - self.v[:-1]))) 102 | return self.w 103 | 104 | def complete(self, size): 105 | if len(self._v_base) == 0: 106 | self._v_base = self._rng.beta(1, self.alpha, size=1) 107 | self._count_base = [1] 108 | while len(self.v) < size: 109 | p = np.array(self._count_base + [self.a], dtype=np.float64) 110 | p /= p.sum() 111 | jj = self._rng.choice(range(len(self._v_base) + 1), p=p) 112 | if jj < len(self._v_base): 113 | self.v = np.append(self.v, self._v_base[jj]) 114 | self._count_base[jj] += 1 115 | else: 116 | new_v_base = self._rng.beta(1, self.alpha) 117 | self._v_base = np.append(self._v_base, new_v_base) 118 | self._count_base += [1] 119 | self.v = np.append(self.v, new_v_base) 120 | self._d_base = np.append(self._d_base, jj) 121 | self.w = self.v * np.cumprod(np.concatenate(([1], 122 | 1 - self.v[:-1]))) 123 | -------------------------------------------------------------------------------- /pyrichlet/weight_models/_dirichlet_process.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseWeight 2 | from ..exceptions import NotFittedError 3 | from ..utils.functions import mean_log_beta, log_likelihood_beta 4 | 5 | import numpy as np 6 | from scipy.special import loggamma 7 | 8 | 9 | class DirichletProcess(BaseWeight): 10 | def __init__(self, alpha=1, rng=None): 11 | super().__init__(rng=rng) 12 | self.alpha = alpha 13 | self.v = np.array([], dtype=np.float64) 14 | 15 | def weighting_log_likelihood(self): 16 | v = self.w[0] 17 | ret = log_likelihood_beta(v, 1, self.alpha) 18 | prod_v = 1 - v 19 | for wj in self.w[1:]: 20 | v = wj / prod_v 21 | ret += log_likelihood_beta(v, 1, self.alpha) 22 | prod_v *= (1 - v) 23 | return ret 24 | 25 | def random(self, size=None): 26 | if size is None and len(self.d) == 0: 27 | raise ValueError("Weight structure not fitted and `n` not passed.") 28 | if size is not None: 29 | if type(size) is not int: 30 | raise TypeError("size parameter must be integer or None") 31 | self.v = self.v[:0] 32 | if len(self.d) == 0: 33 | self.complete(size) 34 | else: 35 | 
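            # Conjugate stick-breaking update: a_c[j] counts the observations
            # assigned to group j and b_c[j] those assigned past j, so each
            # stick is v_j | d ~ Beta(1 + a_c[j], alpha + b_c[j]).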
a_c = np.bincount(self.d) 36 | b_c = np.concatenate((np.cumsum(a_c[::-1])[-2::-1], [0])) 37 | 38 | if size is not None and size < len(a_c): 39 | a_c = a_c[:size] 40 | b_c = b_c[:size] 41 | 42 | self.v = self._rng.beta(a=1 + a_c, b=self.alpha + b_c) 43 | self.w = self.v * np.cumprod(np.concatenate(([1], 44 | 1 - self.v[:-1]))) 45 | if size is not None: 46 | self.complete(size) 47 | return self.w 48 | 49 | def complete(self, size): 50 | super().complete(size) 51 | if len(self.v) < size: 52 | self.v = np.concatenate( 53 | (self.v, 54 | self._rng.beta(a=1, b=self.alpha, size=size - len(self.v)))) 55 | self.w = self.v * np.cumprod(np.concatenate(([1], 56 | 1 - self.v[:-1]))) 57 | return self.w 58 | 59 | def fit_variational(self, variational_d): 60 | self.variational_d = variational_d 61 | self.variational_k = len(self.variational_d) 62 | self.variational_params = np.empty((self.variational_k, 2), 63 | dtype=np.float64) 64 | a_c = np.sum(self.variational_d, 1) 65 | b_c = np.concatenate((np.cumsum(a_c[::-1])[-2::-1], [0])) 66 | self.variational_params[:, 0] = 1 + a_c 67 | self.variational_params[:, 1] = self.alpha + b_c 68 | 69 | def variational_mean_log_w_j(self, j): 70 | if self.variational_d is None: 71 | raise NotFittedError 72 | if j >= self.variational_k: 73 | return -np.inf 74 | res = 0 75 | for jj in range(j): 76 | res += mean_log_beta(self.variational_params[jj][1], 77 | self.variational_params[jj][0]) 78 | if j < self.variational_k - 1: 79 | res += mean_log_beta(self.variational_params[j, 0], 80 | self.variational_params[j, 1]) 81 | return res 82 | 83 | def variational_mean_log_p_d__w(self, variational_d=None): 84 | if variational_d is None: 85 | _variational_d = self.variational_d 86 | if _variational_d is None: 87 | raise NotFittedError 88 | else: 89 | _variational_d = variational_d 90 | res = 0 91 | for j, nj in enumerate(np.sum(_variational_d, 1)): 92 | res += nj * self.variational_mean_log_w_j(j) 93 | return res 94 | 95 | def variational_mean_log_p_w(self): 96 | if self.variational_d is None: 97 | raise NotFittedError 98 | res = 0 99 | for params in self.variational_params[:-1]: 100 | res += mean_log_beta(params[1], params[0]) 101 | res *= self.alpha - 1 102 | res += self.variational_k * np.log(self.alpha) 103 | return res 104 | 105 | def variational_mean_log_q_w(self): 106 | if self.variational_d is None: 107 | raise NotFittedError 108 | res = 0 109 | for params in self.variational_params[:-1]: 110 | res += (params[0] - 1) * mean_log_beta(params[0], params[1]) 111 | res += (params[1] - 1) * mean_log_beta(params[1], params[0]) 112 | res += loggamma(params[0] + params[1]) 113 | res -= loggamma(params[0]) + loggamma(params[1]) 114 | return res 115 | 116 | def variational_mean_w_j(self, j): 117 | if j >= self.variational_k: 118 | return 0 119 | res = 1 120 | for jj in range(j): 121 | res *= (self.variational_params[jj][1] / 122 | self.variational_params[jj].sum()) 123 | if j < self.variational_k - 1: 124 | res *= self.variational_params[j, 0] / self.variational_params[ 125 | j].sum() 126 | return res 127 | 128 | def variational_mode_w_j(self, j): 129 | if j >= self.variational_k: 130 | return 0 131 | res = 1 132 | for jj in range(j): 133 | if self.variational_params[jj, 1] <= 1: 134 | if self.variational_params[jj, 0] <= 1: 135 | raise ValueError('multimodal distribution') 136 | else: 137 | return 0 138 | elif self.variational_params[jj, 0] <= 1: 139 | continue 140 | res *= ((self.variational_params[jj, 1] - 1) / 141 | (self.variational_params[jj].sum() - 2)) 142 | if j == 
self.variational_k - 1: 143 | return res 144 | if self.variational_params[j, 0] <= 1: 145 | if self.variational_params[j, 1] <= 1: 146 | raise ValueError('multimodal distribution') 147 | else: 148 | return 0 149 | elif self.variational_params[j, 1] <= 1: 150 | return res 151 | res *= ((self.variational_params[j, 0] - 1) / 152 | (self.variational_params[j].sum() - 2)) 153 | return res 154 | -------------------------------------------------------------------------------- /pyrichlet/weight_models/_pitman_yor_process.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseWeight 2 | from ..exceptions import NotFittedError 3 | from ..utils.functions import mean_log_beta, log_likelihood_beta 4 | 5 | import numpy as np 6 | from scipy.special import loggamma 7 | 8 | 9 | class PitmanYorProcess(BaseWeight): 10 | def __init__(self, pyd=0, alpha=1, truncation_length=-1, rng=None): 11 | super().__init__(rng=rng) 12 | assert -pyd < alpha, "alpha param must be greater than -pyd" 13 | self.pyd = pyd 14 | self.alpha = alpha 15 | self.v = np.array([], dtype=np.float64) 16 | self.truncation_length = truncation_length 17 | 18 | def weighting_log_likelihood(self): 19 | v = self.w[0] 20 | ret = log_likelihood_beta(v, 1 - self.pyd, self.alpha) 21 | prod_v = 1 - v 22 | for j in range(1, len(self.w)): 23 | v = self.w[j] / prod_v 24 | ret += log_likelihood_beta(v, 1 - self.pyd, 25 | self.alpha + j * self.pyd) 26 | prod_v *= (1 - v) 27 | return ret 28 | 29 | def random(self, size=None): 30 | if size is None and len(self.d) == 0: 31 | raise ValueError("Weight structure not fitted and `n` not passed.") 32 | if len(self.d) == 0: 33 | pitman_yor_bias = np.arange(size) 34 | self.v = self._rng.beta(a=1 - self.pyd, 35 | b=self.alpha + pitman_yor_bias * self.pyd, 36 | size=size) 37 | self.w = self.v * np.cumprod(np.concatenate(([1], 38 | 1 - self.v[:-1]))) 39 | else: 40 | a_c = np.bincount(self.d) 41 | b_c = np.concatenate((np.cumsum(a_c[::-1])[-2::-1], [0])) 42 | 43 | if size is not None and size < len(a_c): 44 | a_c = a_c[:size] 45 | b_c = b_c[:size] 46 | 47 | pitman_yor_bias = np.arange(len(a_c)) 48 | self.v = self._rng.beta( 49 | a=1 - self.pyd + a_c, 50 | b=self.alpha + pitman_yor_bias * self.pyd + b_c 51 | ) 52 | self.w = self.v * np.cumprod(np.concatenate(([1], 53 | 1 - self.v[:-1]))) 54 | if size is not None: 55 | self.complete(size) 56 | return self.w 57 | 58 | def complete(self, size): 59 | super().complete(size) 60 | if self.get_size() < size: 61 | pitman_yor_bias = np.arange(self.get_size(), size) 62 | self.v = np.concatenate( 63 | ( 64 | self.v, 65 | self._rng.beta(a=1 - self.pyd, 66 | b=self.alpha + pitman_yor_bias * self.pyd) 67 | ) 68 | ) 69 | self.w = self.v * np.cumprod(np.concatenate(([1], 70 | 1 - self.v[:-1]))) 71 | return self.w 72 | 73 | def fit_variational(self, variational_d): 74 | self.variational_d = variational_d 75 | self.variational_k = len(self.variational_d) 76 | self.variational_params = np.empty((self.variational_k, 2), 77 | dtype=np.float64) 78 | a_c = np.sum(self.variational_d, 1) 79 | b_c = np.concatenate((np.cumsum(a_c[::-1])[-2::-1], [0])) 80 | self.variational_params[:, 0] = 1 - self.pyd + a_c 81 | self.variational_params[:, 1] = self.alpha + ( 82 | 1 + np.arange(self.variational_params.shape[0]) 83 | ) * self.pyd + b_c 84 | 85 | def variational_mean_log_w_j(self, j): 86 | if self.variational_d is None: 87 | raise NotFittedError 88 | res = 0 89 | for jj in range(j): 90 | res += mean_log_beta(self.variational_params[jj][1], 91 | 
self.variational_params[jj][0]) 92 | res += mean_log_beta(self.variational_params[j, 0], 93 | self.variational_params[j, 1] 94 | ) 95 | return res 96 | 97 | def variational_mean_log_p_d__w(self, variational_d=None): 98 | if variational_d is None: 99 | _variational_d = self.variational_d 100 | if _variational_d is None: 101 | raise NotFittedError 102 | else: 103 | _variational_d = variational_d 104 | res = 0 105 | for j, nj in enumerate(np.sum(_variational_d, 1)): 106 | res += nj * self.variational_mean_log_w_j(j) 107 | return res 108 | 109 | def variational_mean_log_p_w(self): 110 | if self.variational_d is None: 111 | raise NotFittedError 112 | res = 0 113 | for j, params in enumerate(self.variational_params): 114 | res += mean_log_beta(params[0], params[1]) * -self.pyd 115 | res += mean_log_beta(params[1], params[0]) * ( 116 | self.alpha + (j + 1) * self.pyd - 1 117 | ) 118 | res += loggamma(self.alpha + j * self.pyd + 1) 119 | res -= loggamma(self.alpha + (j + 1) * self.pyd + 1) 120 | res -= loggamma(1 - self.pyd) 121 | return res 122 | 123 | def variational_mean_log_q_w(self): 124 | if self.variational_d is None: 125 | raise NotFittedError 126 | res = 0 127 | for params in self.variational_params: 128 | res += (params[0] - 1) * mean_log_beta(params[0], params[1]) 129 | res += (params[1] - 1) * mean_log_beta(params[1], params[0]) 130 | res += loggamma(params[0] + params[1]) 131 | res -= loggamma(params[0]) + loggamma(params[1]) 132 | return res 133 | 134 | def variational_mean_w_j(self, j): 135 | if j > self.variational_k: 136 | return 0 137 | res = 1 138 | for jj in range(j): 139 | res *= (self.variational_params[jj][1] / 140 | self.variational_params[jj].sum()) 141 | res *= self.variational_params[j, 0] / self.variational_params[j].sum() 142 | return res 143 | 144 | def variational_mode_w_j(self, j): 145 | if j > self.variational_k: 146 | return 0 147 | res = 1 148 | for jj in range(j): 149 | if self.variational_params[jj, 1] <= 1: 150 | if self.variational_params[jj, 0] <= 1: 151 | raise ValueError('multimodal distribution') 152 | else: 153 | return 0 154 | elif self.variational_params[jj, 0] <= 1: 155 | continue 156 | res *= ((self.variational_params[jj, 1] - 1) / 157 | (self.variational_params[jj].sum() - 2)) 158 | 159 | if self.variational_params[j, 0] <= 1: 160 | if self.variational_params[j, 1] <= 1: 161 | raise ValueError('multimodal distribution') 162 | else: 163 | return 0 164 | elif self.variational_params[j, 1] <= 1: 165 | return res 166 | res *= ((self.variational_params[j, 0] - 1) / 167 | (self.variational_params[j].sum() - 2)) 168 | return res 169 | -------------------------------------------------------------------------------- /pyrichlet/weight_models/_beta_in_beta.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import minimize, brentq 3 | from scipy.integrate import quad 4 | from scipy.stats import beta 5 | 6 | from ._base import BaseWeight 7 | from ..utils.functions import log_likelihood_beta 8 | 9 | 10 | class BetaInBeta(BaseWeight): 11 | def __init__(self, x=0, alpha=1, a=1, b=1, p=0, 12 | p_method="geometric", 13 | p_optim_max_steps=10, rng=None): 14 | super().__init__(rng=rng) 15 | self.x = x 16 | self.a = a 17 | self.b = b 18 | self.alpha = alpha 19 | self.p = p 20 | 21 | self.p_method = p_method 22 | self.v = np.array([], dtype=np.float64) 23 | self.p_optim_max_steps = p_optim_max_steps 24 | self._validate_params() 25 | 26 | def weighting_log_likelihood(self): 27 | v = self.w[0] 
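        # Invert the stick-breaking map w_j = v_j * prod_{i < j}(1 - v_i) to
        # recover each proportion v_j, then score it against its beta prior.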
28 | beta_a = 1 + self.x / (1 - self.x) * self.p 29 | beta_b = self.alpha + self.x / (1 - self.x) * (1 - self.p) 30 | ret = log_likelihood_beta(v, beta_a, beta_b) 31 | prod_v = 1 - v 32 | for wj in self.w[1:]: 33 | v = wj / prod_v 34 | ret += log_likelihood_beta(v, beta_a, beta_b) 35 | prod_v *= (1 - v) 36 | ret += self._internal_beta_log_likelihood() 37 | return ret 38 | 39 | def _internal_beta_log_likelihood(self): 40 | return log_likelihood_beta(self.p, self.a, self.b) 41 | 42 | def random(self, size=None): 43 | if size is None and len(self.d) == 0: 44 | raise ValueError("Weight structure not fitted and `n` not passed.") 45 | self.v = self.v[:0] 46 | if len(self.d) == 0: 47 | self.complete(size) 48 | else: 49 | if self.x > 0: 50 | self.random_p() 51 | a_c = np.bincount(self.d) 52 | b_c = np.concatenate((np.cumsum(a_c[::-1])[-2::-1], [0])) 53 | 54 | if size is None: 55 | size = len(a_c) 56 | elif size < len(a_c): 57 | a_c = a_c[:size] 58 | b_c = b_c[:size] 59 | else: 60 | pass 61 | 62 | if self.x < 1: 63 | self.v = self._rng.beta( 64 | a=1 + self.x / (1 - self.x) * self.p + a_c, 65 | b=self.alpha + self.x / (1 - self.x) * (1 - self.p) + b_c 66 | ) 67 | else: 68 | self.v = np.repeat(self.p, size) 69 | self.w = self.v * np.cumprod(np.concatenate(([1], 70 | 1 - self.v[:-1]))) 71 | if size is not None: 72 | self.complete(size) 73 | return self.w 74 | 75 | def complete(self, size): 76 | super().complete(size) 77 | if len(self.v) == 0: 78 | self.random_p() 79 | if len(self.v) < size: 80 | if self.x < 1: 81 | concat_value = self._rng.beta( 82 | a=1 + self.x / (1 - self.x) * self.p, 83 | b=self.alpha + self.x / (1 - self.x) * (1 - self.p), 84 | size=size - len(self.v) 85 | ) 86 | else: 87 | concat_value = np.repeat(self.p, size - len(self.v)) 88 | self.v = np.append(self.v, concat_value) 89 | self.w = self.v * np.cumprod(np.concatenate(([1], 90 | 1 - self.v[:-1]))) 91 | return self.w 92 | 93 | def random_p(self): 94 | if self.x == 1: 95 | self.p = self._rng.beta(a=self.a + len(self.d), 96 | b=self.b + self.d.sum()) 97 | return self.p 98 | if self.x == 0: 99 | return self.p 100 | if len(self.d) == 0: 101 | self.p = self._rng.beta(a=self.a, b=self.b) 102 | return self.p 103 | # It's not Dirichlet or Geometric and we need to sample p 104 | if self.p_method == "static": 105 | return self.p 106 | elif self.p_method == "independent": 107 | self.p = self._rng.beta(a=self.a, b=self.b) 108 | return self.p 109 | elif self.p_method == "geometric": 110 | self.p = self._rng.beta(a=self.a + len(self.d), 111 | b=self.b + self.d.sum()) 112 | return self.p 113 | elif self.p_method == "max-likelihood": 114 | max_param = minimize( 115 | lambda p: -self._structure_log_likelihood(p=p), 116 | np.array([self.p]), 117 | bounds=[(0, 1)], 118 | options={'maxiter': self.p_optim_max_steps}) 119 | if max_param.success: 120 | self.p = max_param.x[0] 121 | return self.p 122 | elif self.p_method == "inverse-sampling": 123 | unif = self._rng.uniform() 124 | 125 | def f(p): 126 | return np.exp(self._structure_log_likelihood(p=p)) 127 | 128 | integral_normalization = quad(f, 0, 1)[0] 129 | 130 | def f_integral(p): 131 | return quad(f, 0, p)[0] / integral_normalization - unif 132 | 133 | try: 134 | self.p = brentq(f_integral, a=0, b=1) 135 | except ValueError: 136 | pass 137 | return self.p 138 | else: 139 | raise ValueError(f"unknown p_method: {self.p_method}") 140 | 141 | def _validate_params(self): 142 | accepted_methods = ["static", "independent", "geometric", 143 | "rejection-sampling", "max-likelihood", 144 | "inverse-sampling"] 145 | if
self.p_method not in accepted_methods: 146 | raise ValueError(f"p_method must be one of {accepted_methods}") 147 | 148 | def _structure_log_likelihood(self, v=None, p=None, x=None, alpha=None): 149 | if v is None: 150 | v = self.v 151 | if p is None: 152 | p = self.p 153 | if x is None: 154 | x = self.x 155 | if alpha is None: 156 | alpha = self.alpha 157 | log_likelihood = self._weight_log_likelihood(v=v, p=p, x=x, 158 | alpha=alpha) 159 | log_likelihood += self._p_log_likelihood(p=p) 160 | return log_likelihood 161 | 162 | def _weight_log_likelihood(self, v=None, p=None, x=None, alpha=None, 163 | a=None, b=None): 164 | if v is None: 165 | v = self.v 166 | if p is None: 167 | p = self.p 168 | if x is None: 169 | x = self.x 170 | if alpha is None: 171 | alpha = self.alpha 172 | if a is None: 173 | a = self.a 174 | if b is None: 175 | b = self.b 176 | if x == 1: 177 | if len(v) == 0: 178 | return 0 179 | if np.all(v == v[0]): 180 | return 0 181 | else: 182 | return -np.inf 183 | return np.sum( 184 | beta.logpdf(v, 185 | a=1 + x / (1 - x) * p, 186 | b=alpha + x / (1 - x) * (1 - p))) 187 | 188 | def _p_log_likelihood(self, p=None, a=None, b=None): 189 | if p is None: 190 | p = self.p 191 | if a is None: 192 | a = self.a 193 | if b is None: 194 | b = self.b 195 | return beta.logpdf(p, a=a, b=b) 196 | -------------------------------------------------------------------------------- /pyrichlet/weight_models/_base.py: -------------------------------------------------------------------------------- 1 | """Base class for weighting structure models.""" 2 | from abc import ABC, abstractmethod 3 | import numpy as np 4 | 5 | from ..utils.validators import rng_parser 6 | 7 | 8 | class BaseWeight(ABC): 9 | """Base class for weighting structure models. 10 | 11 | This abstract class specifies an interface for all weighting structure 12 | classes and provides basic common methods for weighting models. 13 | """ 14 | 15 | def __init__(self, rng=None): 16 | self._rng = rng_parser(rng) 17 | 18 | self.w = np.array([], dtype=np.float64) 19 | self.d = np.array([], dtype=np.int64) 20 | self.variational_params = None 21 | self.variational_d = None 22 | self.variational_k = None 23 | 24 | @abstractmethod 25 | def random(self, size=None): 26 | """Do a random draw of the truncated weighting structure up to `size` obs. 27 | 28 | This method does a random draw from the posterior weighting 29 | distribution (or from the prior distribution if nothing has been 30 | fitted) and updates the internal truncated weighting structure 31 | `self.w`. 32 | 33 | Parameters 34 | ---------- 35 | size : int 36 | The desired size of the returned vector of weights 37 | 38 | Returns 39 | ------- 40 | np.array 41 | Array of the weighted structure. 42 | """ 43 | pass 44 | 45 | @abstractmethod 46 | def complete(self, size): 47 | """Return an array of weights with at least `size` elements 48 | 49 | This method appends weights to the truncated weighting structure 50 | `self.w` until reaching a length of `size` and then returns `self.w`. 51 | Note: This method sets a constraint on the minimum number of elements 52 | in the truncated weighting structure. No truncation is induced in case 53 | `size` is less than `len(self.w)` and the full length of `self.w` is 54 | returned. 55 | 56 | Parameters 57 | ---------- 58 | size : int 59 | The desired size of the returned vector of weights 60 | 61 | Returns 62 | ------- 63 | np.array 64 | Array of the weighted structure.
65 | """ 66 | if size is not None and type(size) not in (int, np.int64): 67 | raise TypeError("size parameter must be integer or None") 68 | 69 | def weighting_log_likelihood(self): 70 | """Return the given structure log-likelihood 71 | 72 | This method returns log f(w) for the underlying weighting model. 73 | 74 | Returns 75 | ------- 76 | float 77 | The log-likelihood value 78 | """ 79 | pass 80 | 81 | def fit(self, d): 82 | """Fit the weighting structure to a vector of assignments 83 | 84 | This method fits the parameters of the weighting model given the 85 | internal truncated weighting structure `self.w`. Any call to the 86 | methods `random`, `tail` or `complete` after calling this method 87 | results in a random draw from the posterior distribution. 88 | 89 | Parameters 90 | ---------- 91 | d : array[int], np.array 92 | An array of integers representing the assigned group 93 | """ 94 | self.d = np.array(d) 95 | 96 | def tail(self, x): 97 | """Return an array of weights such that the sum is greater than `x` 98 | 99 | This method appends weights to the truncated weighting structure 100 | `self.w` until the sum of its elements is greater than the input `x` 101 | and then returns `self.w`. 102 | 103 | Parameters 104 | ---------- 105 | x : float 106 | A float in the range $[0,1)$ for which the sum of weights must be 107 | greater. 108 | 109 | Returns 110 | ------- 111 | np.array 112 | Array of the completed weighted structure 113 | """ 114 | if x >= 1 or x < 0: 115 | raise ValueError("Tail parameter not in range [0,1)") 116 | if len(self.w) == 0: 117 | self.random(1) 118 | while self.w.sum() < x: 119 | self.complete(len(self.w) + 1) 120 | return self.w 121 | 122 | def assignation_log_likelihood(self, d=None): 123 | """Returns the log-likelihood of an assignment `d` given the weights""" 124 | if d is None: 125 | d = self.d 126 | self.complete(max(d) + 1) 127 | with np.errstate(divide='ignore'): 128 | ret = np.sum(np.log(self.w[d])) 129 | return ret 130 | 131 | def reset(self): 132 | """Resets the conditional vector `d` to None""" 133 | self.d = np.array([], dtype=np.int64) 134 | 135 | def get_weights(self): 136 | """Returns the last weighting structure drawn""" 137 | return self.w 138 | 139 | def get_normalized_weights(self): 140 | """Returns the last weighting stricture normalized""" 141 | return self.w / np.sum(self.w) 142 | 143 | def get_normalized_cumulative_weights(self): 144 | """Returns the normalized cumulative weights""" 145 | return np.cumsum(self.get_normalized_weights()) 146 | 147 | def get_size(self): 148 | """Returns the size of the truncated weighting structure""" 149 | return len(self.w) 150 | 151 | def random_assignment(self, size=None): 152 | """Returns a sample draw of the categorical assignment from the current 153 | state normalized weighting structure""" 154 | u = self._rng.uniform(size=size) 155 | inverse_sampling = np.greater.outer( 156 | u, self.get_normalized_cumulative_weights() 157 | ) 158 | return np.sum(inverse_sampling, axis=1) 159 | 160 | def fit_variational(self, variational_d): 161 | """Fits the variational distribution q 162 | 163 | This method fits the variational distribution q that minimizes the 164 | Kullback-Leiber divergence from q(w) to p(w|d) ($D_{KL}(q||p)$) where 165 | d has a discrete finite random distribution given by 166 | q(d_i = j) = variational_d[j, i] and q is truncated up to 167 | `k=len(variational_d)` so that q(w_k=1) = 1. 
168 | """ 169 | raise NotImplementedError 170 | 171 | def variational_mean_log_w_j(self, j): 172 | """Returns the mean of the logarithm of w_j 173 | 174 | This method returns the expected value of the logarithm of w_j under 175 | the variational distribution q. 176 | """ 177 | raise NotImplementedError 178 | 179 | def variational_mean_log_p_d__w(self, variational_d=None): 180 | """Returns the mean of log p(d|w) 181 | 182 | This method returns the expected value of the logarithm of the 183 | probability of assignation d given w under the variational 184 | distribution q. 185 | """ 186 | raise NotImplementedError 187 | 188 | def variational_mean_log_p_w(self): 189 | """Returns the mean of log p(w) 190 | 191 | This method returns the expected value of the logarithm of the 192 | probability of assignation d given w under the variational 193 | distribution q. 194 | """ 195 | raise NotImplementedError 196 | 197 | def variational_mean_log_q_w(self): 198 | """Returns the mean of log q(w) 199 | 200 | This method returns the expected value of the logarithm of the 201 | probability of assignation d given w under the variational 202 | distribution q. 203 | """ 204 | raise NotImplementedError 205 | 206 | def variational_mean_w_j(self, j): 207 | """Returns the mean of w_j 208 | 209 | This method returns the expected value of the j-th weighting factor 210 | under the variational distribution q. 211 | """ 212 | raise NotImplementedError 213 | 214 | def variational_mode_w_j(self, j): 215 | """Returns the mean of w_j 216 | 217 | This method returns the expected value of the j-th weighting factor 218 | under the variational distribution q. 219 | """ 220 | raise NotImplementedError 221 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | ------------------------------------------------------------------------------- 204 | 205 | Copyright 2020-2021 Fidel Selva 206 | 207 | Licensed under the Apache License, Version 2.0 (the "License"); 208 | you may not use this file except in compliance with the License. 
209 | You may obtain a copy of the License at 210 | 211 | http://www.apache.org/licenses/LICENSE-2.0 212 | 213 | Unless required by applicable law or agreed to in writing, software 214 | distributed under the License is distributed on an "AS IS" BASIS, 215 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 216 | See the License for the specific language governing permissions and 217 | limitations under the License. -------------------------------------------------------------------------------- /pyrichlet/utils/_data/penguins.csv: -------------------------------------------------------------------------------- 1 | species,island,bill_length_mm,bill_depth_mm,flipper_length_mm,body_mass_g,sex 2 | Adelie,Torgersen,39.1,18.7,181,3750,MALE 3 | Adelie,Torgersen,39.5,17.4,186,3800,FEMALE 4 | Adelie,Torgersen,40.3,18,195,3250,FEMALE 5 | Adelie,Torgersen,,,,, 6 | Adelie,Torgersen,36.7,19.3,193,3450,FEMALE 7 | Adelie,Torgersen,39.3,20.6,190,3650,MALE 8 | Adelie,Torgersen,38.9,17.8,181,3625,FEMALE 9 | Adelie,Torgersen,39.2,19.6,195,4675,MALE 10 | Adelie,Torgersen,34.1,18.1,193,3475, 11 | Adelie,Torgersen,42,20.2,190,4250, 12 | Adelie,Torgersen,37.8,17.1,186,3300, 13 | Adelie,Torgersen,37.8,17.3,180,3700, 14 | Adelie,Torgersen,41.1,17.6,182,3200,FEMALE 15 | Adelie,Torgersen,38.6,21.2,191,3800,MALE 16 | Adelie,Torgersen,34.6,21.1,198,4400,MALE 17 | Adelie,Torgersen,36.6,17.8,185,3700,FEMALE 18 | Adelie,Torgersen,38.7,19,195,3450,FEMALE 19 | Adelie,Torgersen,42.5,20.7,197,4500,MALE 20 | Adelie,Torgersen,34.4,18.4,184,3325,FEMALE 21 | Adelie,Torgersen,46,21.5,194,4200,MALE 22 | Adelie,Biscoe,37.8,18.3,174,3400,FEMALE 23 | Adelie,Biscoe,37.7,18.7,180,3600,MALE 24 | Adelie,Biscoe,35.9,19.2,189,3800,FEMALE 25 | Adelie,Biscoe,38.2,18.1,185,3950,MALE 26 | Adelie,Biscoe,38.8,17.2,180,3800,MALE 27 | Adelie,Biscoe,35.3,18.9,187,3800,FEMALE 28 | Adelie,Biscoe,40.6,18.6,183,3550,MALE 29 | Adelie,Biscoe,40.5,17.9,187,3200,FEMALE 30 | Adelie,Biscoe,37.9,18.6,172,3150,FEMALE 31 | Adelie,Biscoe,40.5,18.9,180,3950,MALE 32 | Adelie,Dream,39.5,16.7,178,3250,FEMALE 33 | Adelie,Dream,37.2,18.1,178,3900,MALE 34 | Adelie,Dream,39.5,17.8,188,3300,FEMALE 35 | Adelie,Dream,40.9,18.9,184,3900,MALE 36 | Adelie,Dream,36.4,17,195,3325,FEMALE 37 | Adelie,Dream,39.2,21.1,196,4150,MALE 38 | Adelie,Dream,38.8,20,190,3950,MALE 39 | Adelie,Dream,42.2,18.5,180,3550,FEMALE 40 | Adelie,Dream,37.6,19.3,181,3300,FEMALE 41 | Adelie,Dream,39.8,19.1,184,4650,MALE 42 | Adelie,Dream,36.5,18,182,3150,FEMALE 43 | Adelie,Dream,40.8,18.4,195,3900,MALE 44 | Adelie,Dream,36,18.5,186,3100,FEMALE 45 | Adelie,Dream,44.1,19.7,196,4400,MALE 46 | Adelie,Dream,37,16.9,185,3000,FEMALE 47 | Adelie,Dream,39.6,18.8,190,4600,MALE 48 | Adelie,Dream,41.1,19,182,3425,MALE 49 | Adelie,Dream,37.5,18.9,179,2975, 50 | Adelie,Dream,36,17.9,190,3450,FEMALE 51 | Adelie,Dream,42.3,21.2,191,4150,MALE 52 | Adelie,Biscoe,39.6,17.7,186,3500,FEMALE 53 | Adelie,Biscoe,40.1,18.9,188,4300,MALE 54 | Adelie,Biscoe,35,17.9,190,3450,FEMALE 55 | Adelie,Biscoe,42,19.5,200,4050,MALE 56 | Adelie,Biscoe,34.5,18.1,187,2900,FEMALE 57 | Adelie,Biscoe,41.4,18.6,191,3700,MALE 58 | Adelie,Biscoe,39,17.5,186,3550,FEMALE 59 | Adelie,Biscoe,40.6,18.8,193,3800,MALE 60 | Adelie,Biscoe,36.5,16.6,181,2850,FEMALE 61 | Adelie,Biscoe,37.6,19.1,194,3750,MALE 62 | Adelie,Biscoe,35.7,16.9,185,3150,FEMALE 63 | Adelie,Biscoe,41.3,21.1,195,4400,MALE 64 | Adelie,Biscoe,37.6,17,185,3600,FEMALE 65 | Adelie,Biscoe,41.1,18.2,192,4050,MALE 66 | Adelie,Biscoe,36.4,17.1,184,2850,FEMALE 67 | 
Adelie,Biscoe,41.6,18,192,3950,MALE 68 | Adelie,Biscoe,35.5,16.2,195,3350,FEMALE 69 | Adelie,Biscoe,41.1,19.1,188,4100,MALE 70 | Adelie,Torgersen,35.9,16.6,190,3050,FEMALE 71 | Adelie,Torgersen,41.8,19.4,198,4450,MALE 72 | Adelie,Torgersen,33.5,19,190,3600,FEMALE 73 | Adelie,Torgersen,39.7,18.4,190,3900,MALE 74 | Adelie,Torgersen,39.6,17.2,196,3550,FEMALE 75 | Adelie,Torgersen,45.8,18.9,197,4150,MALE 76 | Adelie,Torgersen,35.5,17.5,190,3700,FEMALE 77 | Adelie,Torgersen,42.8,18.5,195,4250,MALE 78 | Adelie,Torgersen,40.9,16.8,191,3700,FEMALE 79 | Adelie,Torgersen,37.2,19.4,184,3900,MALE 80 | Adelie,Torgersen,36.2,16.1,187,3550,FEMALE 81 | Adelie,Torgersen,42.1,19.1,195,4000,MALE 82 | Adelie,Torgersen,34.6,17.2,189,3200,FEMALE 83 | Adelie,Torgersen,42.9,17.6,196,4700,MALE 84 | Adelie,Torgersen,36.7,18.8,187,3800,FEMALE 85 | Adelie,Torgersen,35.1,19.4,193,4200,MALE 86 | Adelie,Dream,37.3,17.8,191,3350,FEMALE 87 | Adelie,Dream,41.3,20.3,194,3550,MALE 88 | Adelie,Dream,36.3,19.5,190,3800,MALE 89 | Adelie,Dream,36.9,18.6,189,3500,FEMALE 90 | Adelie,Dream,38.3,19.2,189,3950,MALE 91 | Adelie,Dream,38.9,18.8,190,3600,FEMALE 92 | Adelie,Dream,35.7,18,202,3550,FEMALE 93 | Adelie,Dream,41.1,18.1,205,4300,MALE 94 | Adelie,Dream,34,17.1,185,3400,FEMALE 95 | Adelie,Dream,39.6,18.1,186,4450,MALE 96 | Adelie,Dream,36.2,17.3,187,3300,FEMALE 97 | Adelie,Dream,40.8,18.9,208,4300,MALE 98 | Adelie,Dream,38.1,18.6,190,3700,FEMALE 99 | Adelie,Dream,40.3,18.5,196,4350,MALE 100 | Adelie,Dream,33.1,16.1,178,2900,FEMALE 101 | Adelie,Dream,43.2,18.5,192,4100,MALE 102 | Adelie,Biscoe,35,17.9,192,3725,FEMALE 103 | Adelie,Biscoe,41,20,203,4725,MALE 104 | Adelie,Biscoe,37.7,16,183,3075,FEMALE 105 | Adelie,Biscoe,37.8,20,190,4250,MALE 106 | Adelie,Biscoe,37.9,18.6,193,2925,FEMALE 107 | Adelie,Biscoe,39.7,18.9,184,3550,MALE 108 | Adelie,Biscoe,38.6,17.2,199,3750,FEMALE 109 | Adelie,Biscoe,38.2,20,190,3900,MALE 110 | Adelie,Biscoe,38.1,17,181,3175,FEMALE 111 | Adelie,Biscoe,43.2,19,197,4775,MALE 112 | Adelie,Biscoe,38.1,16.5,198,3825,FEMALE 113 | Adelie,Biscoe,45.6,20.3,191,4600,MALE 114 | Adelie,Biscoe,39.7,17.7,193,3200,FEMALE 115 | Adelie,Biscoe,42.2,19.5,197,4275,MALE 116 | Adelie,Biscoe,39.6,20.7,191,3900,FEMALE 117 | Adelie,Biscoe,42.7,18.3,196,4075,MALE 118 | Adelie,Torgersen,38.6,17,188,2900,FEMALE 119 | Adelie,Torgersen,37.3,20.5,199,3775,MALE 120 | Adelie,Torgersen,35.7,17,189,3350,FEMALE 121 | Adelie,Torgersen,41.1,18.6,189,3325,MALE 122 | Adelie,Torgersen,36.2,17.2,187,3150,FEMALE 123 | Adelie,Torgersen,37.7,19.8,198,3500,MALE 124 | Adelie,Torgersen,40.2,17,176,3450,FEMALE 125 | Adelie,Torgersen,41.4,18.5,202,3875,MALE 126 | Adelie,Torgersen,35.2,15.9,186,3050,FEMALE 127 | Adelie,Torgersen,40.6,19,199,4000,MALE 128 | Adelie,Torgersen,38.8,17.6,191,3275,FEMALE 129 | Adelie,Torgersen,41.5,18.3,195,4300,MALE 130 | Adelie,Torgersen,39,17.1,191,3050,FEMALE 131 | Adelie,Torgersen,44.1,18,210,4000,MALE 132 | Adelie,Torgersen,38.5,17.9,190,3325,FEMALE 133 | Adelie,Torgersen,43.1,19.2,197,3500,MALE 134 | Adelie,Dream,36.8,18.5,193,3500,FEMALE 135 | Adelie,Dream,37.5,18.5,199,4475,MALE 136 | Adelie,Dream,38.1,17.6,187,3425,FEMALE 137 | Adelie,Dream,41.1,17.5,190,3900,MALE 138 | Adelie,Dream,35.6,17.5,191,3175,FEMALE 139 | Adelie,Dream,40.2,20.1,200,3975,MALE 140 | Adelie,Dream,37,16.5,185,3400,FEMALE 141 | Adelie,Dream,39.7,17.9,193,4250,MALE 142 | Adelie,Dream,40.2,17.1,193,3400,FEMALE 143 | Adelie,Dream,40.6,17.2,187,3475,MALE 144 | Adelie,Dream,32.1,15.5,188,3050,FEMALE 145 | Adelie,Dream,40.7,17,190,3725,MALE 146 | 
Adelie,Dream,37.3,16.8,192,3000,FEMALE 147 | Adelie,Dream,39,18.7,185,3650,MALE 148 | Adelie,Dream,39.2,18.6,190,4250,MALE 149 | Adelie,Dream,36.6,18.4,184,3475,FEMALE 150 | Adelie,Dream,36,17.8,195,3450,FEMALE 151 | Adelie,Dream,37.8,18.1,193,3750,MALE 152 | Adelie,Dream,36,17.1,187,3700,FEMALE 153 | Adelie,Dream,41.5,18.5,201,4000,MALE 154 | Chinstrap,Dream,46.5,17.9,192,3500,FEMALE 155 | Chinstrap,Dream,50,19.5,196,3900,MALE 156 | Chinstrap,Dream,51.3,19.2,193,3650,MALE 157 | Chinstrap,Dream,45.4,18.7,188,3525,FEMALE 158 | Chinstrap,Dream,52.7,19.8,197,3725,MALE 159 | Chinstrap,Dream,45.2,17.8,198,3950,FEMALE 160 | Chinstrap,Dream,46.1,18.2,178,3250,FEMALE 161 | Chinstrap,Dream,51.3,18.2,197,3750,MALE 162 | Chinstrap,Dream,46,18.9,195,4150,FEMALE 163 | Chinstrap,Dream,51.3,19.9,198,3700,MALE 164 | Chinstrap,Dream,46.6,17.8,193,3800,FEMALE 165 | Chinstrap,Dream,51.7,20.3,194,3775,MALE 166 | Chinstrap,Dream,47,17.3,185,3700,FEMALE 167 | Chinstrap,Dream,52,18.1,201,4050,MALE 168 | Chinstrap,Dream,45.9,17.1,190,3575,FEMALE 169 | Chinstrap,Dream,50.5,19.6,201,4050,MALE 170 | Chinstrap,Dream,50.3,20,197,3300,MALE 171 | Chinstrap,Dream,58,17.8,181,3700,FEMALE 172 | Chinstrap,Dream,46.4,18.6,190,3450,FEMALE 173 | Chinstrap,Dream,49.2,18.2,195,4400,MALE 174 | Chinstrap,Dream,42.4,17.3,181,3600,FEMALE 175 | Chinstrap,Dream,48.5,17.5,191,3400,MALE 176 | Chinstrap,Dream,43.2,16.6,187,2900,FEMALE 177 | Chinstrap,Dream,50.6,19.4,193,3800,MALE 178 | Chinstrap,Dream,46.7,17.9,195,3300,FEMALE 179 | Chinstrap,Dream,52,19,197,4150,MALE 180 | Chinstrap,Dream,50.5,18.4,200,3400,FEMALE 181 | Chinstrap,Dream,49.5,19,200,3800,MALE 182 | Chinstrap,Dream,46.4,17.8,191,3700,FEMALE 183 | Chinstrap,Dream,52.8,20,205,4550,MALE 184 | Chinstrap,Dream,40.9,16.6,187,3200,FEMALE 185 | Chinstrap,Dream,54.2,20.8,201,4300,MALE 186 | Chinstrap,Dream,42.5,16.7,187,3350,FEMALE 187 | Chinstrap,Dream,51,18.8,203,4100,MALE 188 | Chinstrap,Dream,49.7,18.6,195,3600,MALE 189 | Chinstrap,Dream,47.5,16.8,199,3900,FEMALE 190 | Chinstrap,Dream,47.6,18.3,195,3850,FEMALE 191 | Chinstrap,Dream,52,20.7,210,4800,MALE 192 | Chinstrap,Dream,46.9,16.6,192,2700,FEMALE 193 | Chinstrap,Dream,53.5,19.9,205,4500,MALE 194 | Chinstrap,Dream,49,19.5,210,3950,MALE 195 | Chinstrap,Dream,46.2,17.5,187,3650,FEMALE 196 | Chinstrap,Dream,50.9,19.1,196,3550,MALE 197 | Chinstrap,Dream,45.5,17,196,3500,FEMALE 198 | Chinstrap,Dream,50.9,17.9,196,3675,FEMALE 199 | Chinstrap,Dream,50.8,18.5,201,4450,MALE 200 | Chinstrap,Dream,50.1,17.9,190,3400,FEMALE 201 | Chinstrap,Dream,49,19.6,212,4300,MALE 202 | Chinstrap,Dream,51.5,18.7,187,3250,MALE 203 | Chinstrap,Dream,49.8,17.3,198,3675,FEMALE 204 | Chinstrap,Dream,48.1,16.4,199,3325,FEMALE 205 | Chinstrap,Dream,51.4,19,201,3950,MALE 206 | Chinstrap,Dream,45.7,17.3,193,3600,FEMALE 207 | Chinstrap,Dream,50.7,19.7,203,4050,MALE 208 | Chinstrap,Dream,42.5,17.3,187,3350,FEMALE 209 | Chinstrap,Dream,52.2,18.8,197,3450,MALE 210 | Chinstrap,Dream,45.2,16.6,191,3250,FEMALE 211 | Chinstrap,Dream,49.3,19.9,203,4050,MALE 212 | Chinstrap,Dream,50.2,18.8,202,3800,MALE 213 | Chinstrap,Dream,45.6,19.4,194,3525,FEMALE 214 | Chinstrap,Dream,51.9,19.5,206,3950,MALE 215 | Chinstrap,Dream,46.8,16.5,189,3650,FEMALE 216 | Chinstrap,Dream,45.7,17,195,3650,FEMALE 217 | Chinstrap,Dream,55.8,19.8,207,4000,MALE 218 | Chinstrap,Dream,43.5,18.1,202,3400,FEMALE 219 | Chinstrap,Dream,49.6,18.2,193,3775,MALE 220 | Chinstrap,Dream,50.8,19,210,4100,MALE 221 | Chinstrap,Dream,50.2,18.7,198,3775,FEMALE 222 | Gentoo,Biscoe,46.1,13.2,211,4500,FEMALE 223 | 
Gentoo,Biscoe,50,16.3,230,5700,MALE 224 | Gentoo,Biscoe,48.7,14.1,210,4450,FEMALE 225 | Gentoo,Biscoe,50,15.2,218,5700,MALE 226 | Gentoo,Biscoe,47.6,14.5,215,5400,MALE 227 | Gentoo,Biscoe,46.5,13.5,210,4550,FEMALE 228 | Gentoo,Biscoe,45.4,14.6,211,4800,FEMALE 229 | Gentoo,Biscoe,46.7,15.3,219,5200,MALE 230 | Gentoo,Biscoe,43.3,13.4,209,4400,FEMALE 231 | Gentoo,Biscoe,46.8,15.4,215,5150,MALE 232 | Gentoo,Biscoe,40.9,13.7,214,4650,FEMALE 233 | Gentoo,Biscoe,49,16.1,216,5550,MALE 234 | Gentoo,Biscoe,45.5,13.7,214,4650,FEMALE 235 | Gentoo,Biscoe,48.4,14.6,213,5850,MALE 236 | Gentoo,Biscoe,45.8,14.6,210,4200,FEMALE 237 | Gentoo,Biscoe,49.3,15.7,217,5850,MALE 238 | Gentoo,Biscoe,42,13.5,210,4150,FEMALE 239 | Gentoo,Biscoe,49.2,15.2,221,6300,MALE 240 | Gentoo,Biscoe,46.2,14.5,209,4800,FEMALE 241 | Gentoo,Biscoe,48.7,15.1,222,5350,MALE 242 | Gentoo,Biscoe,50.2,14.3,218,5700,MALE 243 | Gentoo,Biscoe,45.1,14.5,215,5000,FEMALE 244 | Gentoo,Biscoe,46.5,14.5,213,4400,FEMALE 245 | Gentoo,Biscoe,46.3,15.8,215,5050,MALE 246 | Gentoo,Biscoe,42.9,13.1,215,5000,FEMALE 247 | Gentoo,Biscoe,46.1,15.1,215,5100,MALE 248 | Gentoo,Biscoe,44.5,14.3,216,4100, 249 | Gentoo,Biscoe,47.8,15,215,5650,MALE 250 | Gentoo,Biscoe,48.2,14.3,210,4600,FEMALE 251 | Gentoo,Biscoe,50,15.3,220,5550,MALE 252 | Gentoo,Biscoe,47.3,15.3,222,5250,MALE 253 | Gentoo,Biscoe,42.8,14.2,209,4700,FEMALE 254 | Gentoo,Biscoe,45.1,14.5,207,5050,FEMALE 255 | Gentoo,Biscoe,59.6,17,230,6050,MALE 256 | Gentoo,Biscoe,49.1,14.8,220,5150,FEMALE 257 | Gentoo,Biscoe,48.4,16.3,220,5400,MALE 258 | Gentoo,Biscoe,42.6,13.7,213,4950,FEMALE 259 | Gentoo,Biscoe,44.4,17.3,219,5250,MALE 260 | Gentoo,Biscoe,44,13.6,208,4350,FEMALE 261 | Gentoo,Biscoe,48.7,15.7,208,5350,MALE 262 | Gentoo,Biscoe,42.7,13.7,208,3950,FEMALE 263 | Gentoo,Biscoe,49.6,16,225,5700,MALE 264 | Gentoo,Biscoe,45.3,13.7,210,4300,FEMALE 265 | Gentoo,Biscoe,49.6,15,216,4750,MALE 266 | Gentoo,Biscoe,50.5,15.9,222,5550,MALE 267 | Gentoo,Biscoe,43.6,13.9,217,4900,FEMALE 268 | Gentoo,Biscoe,45.5,13.9,210,4200,FEMALE 269 | Gentoo,Biscoe,50.5,15.9,225,5400,MALE 270 | Gentoo,Biscoe,44.9,13.3,213,5100,FEMALE 271 | Gentoo,Biscoe,45.2,15.8,215,5300,MALE 272 | Gentoo,Biscoe,46.6,14.2,210,4850,FEMALE 273 | Gentoo,Biscoe,48.5,14.1,220,5300,MALE 274 | Gentoo,Biscoe,45.1,14.4,210,4400,FEMALE 275 | Gentoo,Biscoe,50.1,15,225,5000,MALE 276 | Gentoo,Biscoe,46.5,14.4,217,4900,FEMALE 277 | Gentoo,Biscoe,45,15.4,220,5050,MALE 278 | Gentoo,Biscoe,43.8,13.9,208,4300,FEMALE 279 | Gentoo,Biscoe,45.5,15,220,5000,MALE 280 | Gentoo,Biscoe,43.2,14.5,208,4450,FEMALE 281 | Gentoo,Biscoe,50.4,15.3,224,5550,MALE 282 | Gentoo,Biscoe,45.3,13.8,208,4200,FEMALE 283 | Gentoo,Biscoe,46.2,14.9,221,5300,MALE 284 | Gentoo,Biscoe,45.7,13.9,214,4400,FEMALE 285 | Gentoo,Biscoe,54.3,15.7,231,5650,MALE 286 | Gentoo,Biscoe,45.8,14.2,219,4700,FEMALE 287 | Gentoo,Biscoe,49.8,16.8,230,5700,MALE 288 | Gentoo,Biscoe,46.2,14.4,214,4650, 289 | Gentoo,Biscoe,49.5,16.2,229,5800,MALE 290 | Gentoo,Biscoe,43.5,14.2,220,4700,FEMALE 291 | Gentoo,Biscoe,50.7,15,223,5550,MALE 292 | Gentoo,Biscoe,47.7,15,216,4750,FEMALE 293 | Gentoo,Biscoe,46.4,15.6,221,5000,MALE 294 | Gentoo,Biscoe,48.2,15.6,221,5100,MALE 295 | Gentoo,Biscoe,46.5,14.8,217,5200,FEMALE 296 | Gentoo,Biscoe,46.4,15,216,4700,FEMALE 297 | Gentoo,Biscoe,48.6,16,230,5800,MALE 298 | Gentoo,Biscoe,47.5,14.2,209,4600,FEMALE 299 | Gentoo,Biscoe,51.1,16.3,220,6000,MALE 300 | Gentoo,Biscoe,45.2,13.8,215,4750,FEMALE 301 | Gentoo,Biscoe,45.2,16.4,223,5950,MALE 302 | Gentoo,Biscoe,49.1,14.5,212,4625,FEMALE 303 | 
Gentoo,Biscoe,52.5,15.6,221,5450,MALE 304 | Gentoo,Biscoe,47.4,14.6,212,4725,FEMALE 305 | Gentoo,Biscoe,50,15.9,224,5350,MALE 306 | Gentoo,Biscoe,44.9,13.8,212,4750,FEMALE 307 | Gentoo,Biscoe,50.8,17.3,228,5600,MALE 308 | Gentoo,Biscoe,43.4,14.4,218,4600,FEMALE 309 | Gentoo,Biscoe,51.3,14.2,218,5300,MALE 310 | Gentoo,Biscoe,47.5,14,212,4875,FEMALE 311 | Gentoo,Biscoe,52.1,17,230,5550,MALE 312 | Gentoo,Biscoe,47.5,15,218,4950,FEMALE 313 | Gentoo,Biscoe,52.2,17.1,228,5400,MALE 314 | Gentoo,Biscoe,45.5,14.5,212,4750,FEMALE 315 | Gentoo,Biscoe,49.5,16.1,224,5650,MALE 316 | Gentoo,Biscoe,44.5,14.7,214,4850,FEMALE 317 | Gentoo,Biscoe,50.8,15.7,226,5200,MALE 318 | Gentoo,Biscoe,49.4,15.8,216,4925,MALE 319 | Gentoo,Biscoe,46.9,14.6,222,4875,FEMALE 320 | Gentoo,Biscoe,48.4,14.4,203,4625,FEMALE 321 | Gentoo,Biscoe,51.1,16.5,225,5250,MALE 322 | Gentoo,Biscoe,48.5,15,219,4850,FEMALE 323 | Gentoo,Biscoe,55.9,17,228,5600,MALE 324 | Gentoo,Biscoe,47.2,15.5,215,4975,FEMALE 325 | Gentoo,Biscoe,49.1,15,228,5500,MALE 326 | Gentoo,Biscoe,47.3,13.8,216,4725, 327 | Gentoo,Biscoe,46.8,16.1,215,5500,MALE 328 | Gentoo,Biscoe,41.7,14.7,210,4700,FEMALE 329 | Gentoo,Biscoe,53.4,15.8,219,5500,MALE 330 | Gentoo,Biscoe,43.3,14,208,4575,FEMALE 331 | Gentoo,Biscoe,48.1,15.1,209,5500,MALE 332 | Gentoo,Biscoe,50.5,15.2,216,5000,FEMALE 333 | Gentoo,Biscoe,49.8,15.9,229,5950,MALE 334 | Gentoo,Biscoe,43.5,15.2,213,4650,FEMALE 335 | Gentoo,Biscoe,51.5,16.3,230,5500,MALE 336 | Gentoo,Biscoe,46.2,14.1,217,4375,FEMALE 337 | Gentoo,Biscoe,55.1,16,230,5850,MALE 338 | Gentoo,Biscoe,44.5,15.7,217,4875, 339 | Gentoo,Biscoe,48.8,16.2,222,6000,MALE 340 | Gentoo,Biscoe,47.2,13.7,214,4925,FEMALE 341 | Gentoo,Biscoe,,,,, 342 | Gentoo,Biscoe,46.8,14.3,215,4850,FEMALE 343 | Gentoo,Biscoe,50.4,15.7,222,5750,MALE 344 | Gentoo,Biscoe,45.2,14.8,212,5200,FEMALE 345 | Gentoo,Biscoe,49.9,16.1,213,5400,MALE 346 | -------------------------------------------------------------------------------- /pyrichlet/_version.py: -------------------------------------------------------------------------------- 1 | 2 | # This file helps to compute a version number in source trees obtained from 3 | # git-archive tarball (such as those provided by githubs download-from-tag 4 | # feature). Distribution tarballs (built by setup.py sdist) and build 5 | # directories (produced by setup.py build) will contain a much shorter file 6 | # that just contains the computed version number. 7 | 8 | # This file is released into the public domain. Generated by 9 | # versioneer-0.21 (https://github.com/python-versioneer/python-versioneer) 10 | 11 | """Git implementation of _version.py.""" 12 | 13 | import errno 14 | import os 15 | import re 16 | import subprocess 17 | import sys 18 | from typing import Callable, Dict 19 | 20 | 21 | def get_keywords(): 22 | """Get the keywords needed to look up the version information.""" 23 | # these strings will be replaced by git during git-archive. 24 | # setup.py/versioneer.py will grep for the variable names, so they must 25 | # each be defined on a line of their own. _version.py will just call 26 | # get_keywords(). 
27 | git_refnames = " (HEAD -> main, tag: 0.0.9)" 28 | git_full = "1ea69fd5a26fc353a206cef75488fbcce87cce51" 29 | git_date = "2024-03-01 10:27:25 -0600" 30 | keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} 31 | return keywords 32 | 33 | 34 | class VersioneerConfig: 35 | """Container for Versioneer configuration parameters.""" 36 | 37 | 38 | def get_config(): 39 | """Create, populate and return the VersioneerConfig() object.""" 40 | # these strings are filled in when 'setup.py versioneer' creates 41 | # _version.py 42 | cfg = VersioneerConfig() 43 | cfg.VCS = "git" 44 | cfg.style = "pep440" 45 | cfg.tag_prefix = "" 46 | cfg.parentdir_prefix = "pyrichlet-" 47 | cfg.versionfile_source = "pyrichlet/_version.py" 48 | cfg.verbose = False 49 | return cfg 50 | 51 | 52 | class NotThisMethod(Exception): 53 | """Exception raised if a method is not valid for the current scenario.""" 54 | 55 | 56 | LONG_VERSION_PY: Dict[str, str] = {} 57 | HANDLERS: Dict[str, Dict[str, Callable]] = {} 58 | 59 | 60 | def register_vcs_handler(vcs, method): # decorator 61 | """Create decorator to mark a method as the handler of a VCS.""" 62 | def decorate(f): 63 | """Store f in HANDLERS[vcs][method].""" 64 | if vcs not in HANDLERS: 65 | HANDLERS[vcs] = {} 66 | HANDLERS[vcs][method] = f 67 | return f 68 | return decorate 69 | 70 | 71 | def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, 72 | env=None): 73 | """Call the given command(s).""" 74 | assert isinstance(commands, list) 75 | process = None 76 | for command in commands: 77 | try: 78 | dispcmd = str([command] + args) 79 | # remember shell=False, so use git.cmd on windows, not just git 80 | process = subprocess.Popen([command] + args, cwd=cwd, env=env, 81 | stdout=subprocess.PIPE, 82 | stderr=(subprocess.PIPE if hide_stderr 83 | else None)) 84 | break 85 | except OSError: 86 | e = sys.exc_info()[1] 87 | if e.errno == errno.ENOENT: 88 | continue 89 | if verbose: 90 | print("unable to run %s" % dispcmd) 91 | print(e) 92 | return None, None 93 | else: 94 | if verbose: 95 | print("unable to find command, tried %s" % (commands,)) 96 | return None, None 97 | stdout = process.communicate()[0].strip().decode() 98 | if process.returncode != 0: 99 | if verbose: 100 | print("unable to run %s (error)" % dispcmd) 101 | print("stdout was %s" % stdout) 102 | return None, process.returncode 103 | return stdout, process.returncode 104 | 105 | 106 | def versions_from_parentdir(parentdir_prefix, root, verbose): 107 | """Try to determine the version from the parent directory name. 108 | 109 | Source tarballs conventionally unpack into a directory that includes both 110 | the project name and a version string. 
We will also support searching up 111 | two directory levels for an appropriately named parent directory 112 | """ 113 | rootdirs = [] 114 | 115 | for _ in range(3): 116 | dirname = os.path.basename(root) 117 | if dirname.startswith(parentdir_prefix): 118 | return {"version": dirname[len(parentdir_prefix):], 119 | "full-revisionid": None, 120 | "dirty": False, "error": None, "date": None} 121 | rootdirs.append(root) 122 | root = os.path.dirname(root) # up a level 123 | 124 | if verbose: 125 | print("Tried directories %s but none started with prefix %s" % 126 | (str(rootdirs), parentdir_prefix)) 127 | raise NotThisMethod("rootdir doesn't start with parentdir_prefix") 128 | 129 | 130 | @register_vcs_handler("git", "get_keywords") 131 | def git_get_keywords(versionfile_abs): 132 | """Extract version information from the given file.""" 133 | # the code embedded in _version.py can just fetch the value of these 134 | # keywords. When used from setup.py, we don't want to import _version.py, 135 | # so we do it with a regexp instead. This function is not used from 136 | # _version.py. 137 | keywords = {} 138 | try: 139 | with open(versionfile_abs, "r") as fobj: 140 | for line in fobj: 141 | if line.strip().startswith("git_refnames ="): 142 | mo = re.search(r'=\s*"(.*)"', line) 143 | if mo: 144 | keywords["refnames"] = mo.group(1) 145 | if line.strip().startswith("git_full ="): 146 | mo = re.search(r'=\s*"(.*)"', line) 147 | if mo: 148 | keywords["full"] = mo.group(1) 149 | if line.strip().startswith("git_date ="): 150 | mo = re.search(r'=\s*"(.*)"', line) 151 | if mo: 152 | keywords["date"] = mo.group(1) 153 | except OSError: 154 | pass 155 | return keywords 156 | 157 | 158 | @register_vcs_handler("git", "keywords") 159 | def git_versions_from_keywords(keywords, tag_prefix, verbose): 160 | """Get version information from git keywords.""" 161 | if "refnames" not in keywords: 162 | raise NotThisMethod("Short version file found") 163 | date = keywords.get("date") 164 | if date is not None: 165 | # Use only the last line. Previous lines may contain GPG signature 166 | # information. 167 | date = date.splitlines()[-1] 168 | 169 | # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant 170 | # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 171 | # -like" string, which we must then edit to make compliant), because 172 | # it's been around since git-1.5.3, and it's too difficult to 173 | # discover which version we're using, or to work around using an 174 | # older one. 175 | date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) 176 | refnames = keywords["refnames"].strip() 177 | if refnames.startswith("$Format"): 178 | if verbose: 179 | print("keywords are unexpanded, not using") 180 | raise NotThisMethod("unexpanded keywords, not a git-archive tarball") 181 | refs = {r.strip() for r in refnames.strip("()").split(",")} 182 | # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of 183 | # just "foo-1.0". If we see a "tag: " prefix, prefer those. 184 | TAG = "tag: " 185 | tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} 186 | if not tags: 187 | # Either we're using git < 1.8.3, or there really are no tags. We use 188 | # a heuristic: assume all version tags have a digit. The old git %d 189 | # expansion behaves like git log --decorate=short and strips out the 190 | # refs/heads/ and refs/tags/ prefixes that would let us distinguish 191 | # between branches and tags. 
By ignoring refnames without digits, we 192 | # filter out many common branch names like "release" and 193 | # "stabilization", as well as "HEAD" and "master". 194 | tags = {r for r in refs if re.search(r'\d', r)} 195 | if verbose: 196 | print("discarding '%s', no digits" % ",".join(refs - tags)) 197 | if verbose: 198 | print("likely tags: %s" % ",".join(sorted(tags))) 199 | for ref in sorted(tags): 200 | # sorting will prefer e.g. "2.0" over "2.0rc1" 201 | if ref.startswith(tag_prefix): 202 | r = ref[len(tag_prefix):] 203 | # Filter out refs that exactly match prefix or that don't start 204 | # with a number once the prefix is stripped (mostly a concern 205 | # when prefix is '') 206 | if not re.match(r'\d', r): 207 | continue 208 | if verbose: 209 | print("picking %s" % r) 210 | return {"version": r, 211 | "full-revisionid": keywords["full"].strip(), 212 | "dirty": False, "error": None, 213 | "date": date} 214 | # no suitable tags, so version is "0+unknown", but full hex is still there 215 | if verbose: 216 | print("no suitable tags, using unknown + full revision id") 217 | return {"version": "0+unknown", 218 | "full-revisionid": keywords["full"].strip(), 219 | "dirty": False, "error": "no suitable tags", "date": None} 220 | 221 | 222 | @register_vcs_handler("git", "pieces_from_vcs") 223 | def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): 224 | """Get version from 'git describe' in the root of the source tree. 225 | 226 | This only gets called if the git-archive 'subst' keywords were *not* 227 | expanded, and _version.py hasn't already been rewritten with a short 228 | version string, meaning we're inside a checked out source tree. 229 | """ 230 | GITS = ["git"] 231 | TAG_PREFIX_REGEX = "*" 232 | if sys.platform == "win32": 233 | GITS = ["git.cmd", "git.exe"] 234 | TAG_PREFIX_REGEX = r"\*" 235 | 236 | _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, 237 | hide_stderr=True) 238 | if rc != 0: 239 | if verbose: 240 | print("Directory %s not under git control" % root) 241 | raise NotThisMethod("'git rev-parse --git-dir' returned error") 242 | 243 | # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] 244 | # if there isn't one, this yields HEX[-dirty] (no NUM) 245 | describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty", 246 | "--always", "--long", 247 | "--match", 248 | "%s%s" % (tag_prefix, TAG_PREFIX_REGEX)], 249 | cwd=root) 250 | # --long was added in git-1.5.5 251 | if describe_out is None: 252 | raise NotThisMethod("'git describe' failed") 253 | describe_out = describe_out.strip() 254 | full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) 255 | if full_out is None: 256 | raise NotThisMethod("'git rev-parse' failed") 257 | full_out = full_out.strip() 258 | 259 | pieces = {} 260 | pieces["long"] = full_out 261 | pieces["short"] = full_out[:7] # maybe improved later 262 | pieces["error"] = None 263 | 264 | branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], 265 | cwd=root) 266 | # --abbrev-ref was added in git-1.6.3 267 | if rc != 0 or branch_name is None: 268 | raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") 269 | branch_name = branch_name.strip() 270 | 271 | if branch_name == "HEAD": 272 | # If we aren't exactly on a branch, pick a branch which represents 273 | # the current commit. If all else fails, we are on a branchless 274 | # commit. 
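# Illustrative sketch of the parsing below (hypothetical output): in a
# detached-HEAD state, `git branch --contains` prints something like
#   * (HEAD detached at 1ea69fd)
#     main
# so the first "(...)" line is dropped, the two-character "* " or "  "
# prefixes are stripped, and "master" is preferred when present.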
275 | branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) 276 | # --contains was added in git-1.5.4 277 | if rc != 0 or branches is None: 278 | raise NotThisMethod("'git branch --contains' returned error") 279 | branches = branches.split("\n") 280 | 281 | # Remove the first line if we're running detached 282 | if "(" in branches[0]: 283 | branches.pop(0) 284 | 285 | # Strip off the leading "* " from the list of branches. 286 | branches = [branch[2:] for branch in branches] 287 | if "master" in branches: 288 | branch_name = "master" 289 | elif not branches: 290 | branch_name = None 291 | else: 292 | # Pick the first branch that is returned. Good or bad. 293 | branch_name = branches[0] 294 | 295 | pieces["branch"] = branch_name 296 | 297 | # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] 298 | # TAG might have hyphens. 299 | git_describe = describe_out 300 | 301 | # look for -dirty suffix 302 | dirty = git_describe.endswith("-dirty") 303 | pieces["dirty"] = dirty 304 | if dirty: 305 | git_describe = git_describe[:git_describe.rindex("-dirty")] 306 | 307 | # now we have TAG-NUM-gHEX or HEX 308 | 309 | if "-" in git_describe: 310 | # TAG-NUM-gHEX 311 | mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) 312 | if not mo: 313 | # unparsable. Maybe git-describe is misbehaving? 314 | pieces["error"] = ("unable to parse git-describe output: '%s'" 315 | % describe_out) 316 | return pieces 317 | 318 | # tag 319 | full_tag = mo.group(1) 320 | if not full_tag.startswith(tag_prefix): 321 | if verbose: 322 | fmt = "tag '%s' doesn't start with prefix '%s'" 323 | print(fmt % (full_tag, tag_prefix)) 324 | pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" 325 | % (full_tag, tag_prefix)) 326 | return pieces 327 | pieces["closest-tag"] = full_tag[len(tag_prefix):] 328 | 329 | # distance: number of commits since tag 330 | pieces["distance"] = int(mo.group(2)) 331 | 332 | # commit: short hex revision ID 333 | pieces["short"] = mo.group(3) 334 | 335 | else: 336 | # HEX: no tags 337 | pieces["closest-tag"] = None 338 | count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root) 339 | pieces["distance"] = int(count_out) # total number of commits 340 | 341 | # commit date: see ISO-8601 comment in git_versions_from_keywords() 342 | date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() 343 | # Use only the last line. Previous lines may contain GPG signature 344 | # information. 345 | date = date.splitlines()[-1] 346 | pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) 347 | 348 | return pieces 349 | 350 | 351 | def plus_or_dot(pieces): 352 | """Return a + if we don't already have one, else return a .""" 353 | if "+" in pieces.get("closest-tag", ""): 354 | return "." 355 | return "+" 356 | 357 | 358 | def render_pep440(pieces): 359 | """Build up version string, with post-release "local version identifier". 360 | 361 | Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you 362 | get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty 363 | 364 | Exceptions: 365 | 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] 366 | """ 367 | if pieces["closest-tag"]: 368 | rendered = pieces["closest-tag"] 369 | if pieces["distance"] or pieces["dirty"]: 370 | rendered += plus_or_dot(pieces) 371 | rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) 372 | if pieces["dirty"]: 373 | rendered += ".dirty" 374 | else: 375 | # exception #1 376 | rendered = "0+untagged.%d.g%s" % (pieces["distance"], 377 | pieces["short"]) 378 | if pieces["dirty"]: 379 | rendered += ".dirty" 380 | return rendered 381 | 382 | 383 | def render_pep440_branch(pieces): 384 | """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . 385 | 386 | The ".dev0" means not master branch. Note that .dev0 sorts backwards 387 | (a feature branch will appear "older" than the master branch). 388 | 389 | Exceptions: 390 | 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] 391 | """ 392 | if pieces["closest-tag"]: 393 | rendered = pieces["closest-tag"] 394 | if pieces["distance"] or pieces["dirty"]: 395 | if pieces["branch"] != "master": 396 | rendered += ".dev0" 397 | rendered += plus_or_dot(pieces) 398 | rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) 399 | if pieces["dirty"]: 400 | rendered += ".dirty" 401 | else: 402 | # exception #1 403 | rendered = "0" 404 | if pieces["branch"] != "master": 405 | rendered += ".dev0" 406 | rendered += "+untagged.%d.g%s" % (pieces["distance"], 407 | pieces["short"]) 408 | if pieces["dirty"]: 409 | rendered += ".dirty" 410 | return rendered 411 | 412 | 413 | def pep440_split_post(ver): 414 | """Split pep440 version string at the post-release segment. 415 | 416 | Returns the release segments before the post-release and the 417 | post-release version number (or None if no post-release segment is present). 418 | """ 419 | vc = str.split(ver, ".post") 420 | return vc[0], int(vc[1] or 0) if len(vc) == 2 else None 421 | 422 | 423 | def render_pep440_pre(pieces): 424 | """TAG[.postN.devDISTANCE] -- No -dirty. 425 | 426 | Exceptions: 427 | 1: no tags. 0.post0.devDISTANCE 428 | """ 429 | if pieces["closest-tag"]: 430 | if pieces["distance"]: 431 | # update the post release segment 432 | tag_version, post_version = pep440_split_post(pieces["closest-tag"]) 433 | rendered = tag_version 434 | if post_version is not None: 435 | rendered += ".post%d.dev%d" % (post_version+1, pieces["distance"]) 436 | else: 437 | rendered += ".post0.dev%d" % (pieces["distance"]) 438 | else: 439 | # no commits, use the tag as the version 440 | rendered = pieces["closest-tag"] 441 | else: 442 | # exception #1 443 | rendered = "0.post0.dev%d" % pieces["distance"] 444 | return rendered 445 | 446 | 447 | def render_pep440_post(pieces): 448 | """TAG[.postDISTANCE[.dev0]+gHEX] . 449 | 450 | The ".dev0" means dirty. Note that .dev0 sorts backwards 451 | (a dirty tree will appear "older" than the corresponding clean one), 452 | but you shouldn't be releasing software with -dirty anyways. 453 | 454 | Exceptions: 455 | 1: no tags.
0.postDISTANCE[.dev0] 456 | """ 457 | if pieces["closest-tag"]: 458 | rendered = pieces["closest-tag"] 459 | if pieces["distance"] or pieces["dirty"]: 460 | rendered += ".post%d" % pieces["distance"] 461 | if pieces["dirty"]: 462 | rendered += ".dev0" 463 | rendered += plus_or_dot(pieces) 464 | rendered += "g%s" % pieces["short"] 465 | else: 466 | # exception #1 467 | rendered = "0.post%d" % pieces["distance"] 468 | if pieces["dirty"]: 469 | rendered += ".dev0" 470 | rendered += "+g%s" % pieces["short"] 471 | return rendered 472 | 473 | 474 | def render_pep440_post_branch(pieces): 475 | """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . 476 | 477 | The ".dev0" means not master branch. 478 | 479 | Exceptions: 480 | 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] 481 | """ 482 | if pieces["closest-tag"]: 483 | rendered = pieces["closest-tag"] 484 | if pieces["distance"] or pieces["dirty"]: 485 | rendered += ".post%d" % pieces["distance"] 486 | if pieces["branch"] != "master": 487 | rendered += ".dev0" 488 | rendered += plus_or_dot(pieces) 489 | rendered += "g%s" % pieces["short"] 490 | if pieces["dirty"]: 491 | rendered += ".dirty" 492 | else: 493 | # exception #1 494 | rendered = "0.post%d" % pieces["distance"] 495 | if pieces["branch"] != "master": 496 | rendered += ".dev0" 497 | rendered += "+g%s" % pieces["short"] 498 | if pieces["dirty"]: 499 | rendered += ".dirty" 500 | return rendered 501 | 502 | 503 | def render_pep440_old(pieces): 504 | """TAG[.postDISTANCE[.dev0]] . 505 | 506 | The ".dev0" means dirty. 507 | 508 | Exceptions: 509 | 1: no tags. 0.postDISTANCE[.dev0] 510 | """ 511 | if pieces["closest-tag"]: 512 | rendered = pieces["closest-tag"] 513 | if pieces["distance"] or pieces["dirty"]: 514 | rendered += ".post%d" % pieces["distance"] 515 | if pieces["dirty"]: 516 | rendered += ".dev0" 517 | else: 518 | # exception #1 519 | rendered = "0.post%d" % pieces["distance"] 520 | if pieces["dirty"]: 521 | rendered += ".dev0" 522 | return rendered 523 | 524 | 525 | def render_git_describe(pieces): 526 | """TAG[-DISTANCE-gHEX][-dirty]. 527 | 528 | Like 'git describe --tags --dirty --always'. 529 | 530 | Exceptions: 531 | 1: no tags. HEX[-dirty] (note: no 'g' prefix) 532 | """ 533 | if pieces["closest-tag"]: 534 | rendered = pieces["closest-tag"] 535 | if pieces["distance"]: 536 | rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) 537 | else: 538 | # exception #1 539 | rendered = pieces["short"] 540 | if pieces["dirty"]: 541 | rendered += "-dirty" 542 | return rendered 543 | 544 | 545 | def render_git_describe_long(pieces): 546 | """TAG-DISTANCE-gHEX[-dirty]. 547 | 548 | Like 'git describe --tags --dirty --always --long'. 549 | The distance/hash is unconditional. 550 | 551 | Exceptions: 552 | 1: no tags.
HEX[-dirty] (note: no 'g' prefix) 553 | """ 554 | if pieces["closest-tag"]: 555 | rendered = pieces["closest-tag"] 556 | rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) 557 | else: 558 | # exception #1 559 | rendered = pieces["short"] 560 | if pieces["dirty"]: 561 | rendered += "-dirty" 562 | return rendered 563 | 564 | 565 | def render(pieces, style): 566 | """Render the given version pieces into the requested style.""" 567 | if pieces["error"]: 568 | return {"version": "unknown", 569 | "full-revisionid": pieces.get("long"), 570 | "dirty": None, 571 | "error": pieces["error"], 572 | "date": None} 573 | 574 | if not style or style == "default": 575 | style = "pep440" # the default 576 | 577 | if style == "pep440": 578 | rendered = render_pep440(pieces) 579 | elif style == "pep440-branch": 580 | rendered = render_pep440_branch(pieces) 581 | elif style == "pep440-pre": 582 | rendered = render_pep440_pre(pieces) 583 | elif style == "pep440-post": 584 | rendered = render_pep440_post(pieces) 585 | elif style == "pep440-post-branch": 586 | rendered = render_pep440_post_branch(pieces) 587 | elif style == "pep440-old": 588 | rendered = render_pep440_old(pieces) 589 | elif style == "git-describe": 590 | rendered = render_git_describe(pieces) 591 | elif style == "git-describe-long": 592 | rendered = render_git_describe_long(pieces) 593 | else: 594 | raise ValueError("unknown style '%s'" % style) 595 | 596 | return {"version": rendered, "full-revisionid": pieces["long"], 597 | "dirty": pieces["dirty"], "error": None, 598 | "date": pieces.get("date")} 599 | 600 | 601 | def get_versions(): 602 | """Get version information or return default if unable to do so.""" 603 | # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have 604 | # __file__, we can work backwards from there to the root. Some 605 | # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which 606 | # case we can only use expanded keywords. 607 | 608 | cfg = get_config() 609 | verbose = cfg.verbose 610 | 611 | try: 612 | return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, 613 | verbose) 614 | except NotThisMethod: 615 | pass 616 | 617 | try: 618 | root = os.path.realpath(__file__) 619 | # versionfile_source is the relative path from the top of the source 620 | # tree (where the .git directory might live) to this file. Invert 621 | # this to find the root from __file__. 
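# Illustrative sketch (hypothetical path): with versionfile_source
# "pyrichlet/_version.py", the loop below applies os.path.dirname once
# per path component, e.g.
# "/home/user/pyrichlet/pyrichlet/_version.py" -> "/home/user/pyrichlet",
# which is the project root.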
622 | for _ in cfg.versionfile_source.split('/'): 623 | root = os.path.dirname(root) 624 | except NameError: 625 | return {"version": "0+unknown", "full-revisionid": None, 626 | "dirty": None, 627 | "error": "unable to find root of source tree", 628 | "date": None} 629 | 630 | try: 631 | pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) 632 | return render(pieces, cfg.style) 633 | except NotThisMethod: 634 | pass 635 | 636 | try: 637 | if cfg.parentdir_prefix: 638 | return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) 639 | except NotThisMethod: 640 | pass 641 | 642 | return {"version": "0+unknown", "full-revisionid": None, 643 | "dirty": None, 644 | "error": "unable to compute version", "date": None} 645 | -------------------------------------------------------------------------------- /pyrichlet/mixture_models/_base.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import SpectralClustering 2 | from scipy.stats import multivariate_normal 3 | from collections import defaultdict 4 | import pandas as pd 5 | import numpy as np 6 | 7 | from abc import ABCMeta 8 | 9 | from . import _utils 10 | from ..exceptions import NotFittedError 11 | from ..weight_models import BaseWeight 12 | from ..utils.functions import density_students_t, density_normal 13 | from ..utils.validators import rng_parser 14 | 15 | 16 | class BaseGaussianMixture(metaclass=ABCMeta): 17 | """ 18 | Base class for Gaussian Mixture Models 19 | 20 | Warning: This class should not be used directly. Use derived classes 21 | instead. 22 | 23 | Parameters 24 | ---------- 25 | weight_model : BaseWeight, default=None 26 | The weighting model for the mixing components 27 | mu_prior : {float, array, np.array}, default=None 28 | The prior centering parameter of the prior normal - inverse Wishart 29 | distribution. If None, the mean of the observations to fit will be used 30 | lambda_prior : float, default=1 31 | The precision parameter of the prior normal - inverse Wishart 32 | distribution. 33 | psi_prior : {array, np.array, np.matrix}, default=None 34 | The inverse scale matrix of the prior normal - inverse Wishart 35 | distribution. If None, the sample variance-covariance matrix will be 36 | used. 37 | nu_prior : float, default=None 38 | The degrees of freedom of the prior normal - inverse Wishart 39 | distribution. If None, the dimension of the scale matrix will be used. 40 | total_iter : int, default=1000 41 | The total number of steps in the Gibbs sampler algorithm or the max 42 | number of steps for the variational algorithm. 43 | burn_in : int, default=100 44 | The number of steps in the Gibbs sampler to discard in expected a 45 | posteriori (EAP) estimations. 46 | subsample_steps : int, default=1 47 | The number of steps to draw before saving the realizations. The steps 48 | between savings will be discarded. 49 | show_progress : bool, default=False 50 | Whether to display the progress with tqdm. 51 | rng: {np.random.Generator, int}, default=None 52 | The PRNG to use for sampling. 
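Examples
--------
A minimal usage sketch (illustrative only, assuming the derived class
`DirichletProcessMixture` is exported at the package level):

>>> import numpy as np
>>> from pyrichlet import DirichletProcessMixture
>>> rng = np.random.default_rng(0)
>>> y = np.concatenate([rng.normal(-2, 1, (50, 2)),
...                     rng.normal(2, 1, (50, 2))])
>>> model = DirichletProcessMixture(alpha=1, total_iter=200,
...                                 burn_in=50, rng=0)
>>> model.fit_gibbs(y)
>>> clusters = model.gibbs_map_cluster()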
53 | """ 54 | 55 | def __init__(self, weight_model: BaseWeight, mu_prior=None, 56 | lambda_prior=1, psi_prior=None, nu_prior=None, 57 | total_iter=1000, burn_in=100, subsample_steps=1, 58 | show_progress=False, rng=None): 59 | 60 | assert total_iter > burn_in, ( 61 | "total_iter must be greater than burn_in period") 62 | self.rng = rng_parser(rng) 63 | self.burn_in = int(burn_in) 64 | self.total_iter = int(total_iter) 65 | self.subsample_steps = int(subsample_steps) 66 | 67 | self.mu_prior = mu_prior 68 | self.lambda_prior = lambda_prior 69 | self.psi_prior = psi_prior 70 | self.nu_prior = nu_prior 71 | 72 | self.weight_model = weight_model 73 | self.y = np.array([]) 74 | self._column_names = None 75 | 76 | # Variables used in Gibbs sampler 77 | self.d = np.array([]) 78 | self.theta = {} 79 | self.u = np.array([]) 80 | self.affinity_matrix = np.array([]) 81 | self.map_sim_params = None 82 | self.map_log_likelihood = -np.inf 83 | self.total_saved_steps = 0 84 | self.sim_params = [] 85 | self.n_groups = [] 86 | self.n_atoms = [] 87 | self.n_log_likelihood = [] 88 | 89 | # Extra variables used in variational methods 90 | self.var_k = None 91 | self.var_d = None 92 | self.var_theta = None 93 | 94 | # Fitting flags 95 | self.gibbs_fitted = False 96 | self.var_fitted = False 97 | self.var_converged = False 98 | 99 | self.show_progress = show_progress 100 | 101 | def fit_gibbs(self, y, init_groups=None, warm_start=False, 102 | show_progress=None, init_method="kmeans"): 103 | """ 104 | Fit posterior distribution using Gibbs sampling. 105 | 106 | This method does `self.total_iter` steps of the Gibbs sampler and 107 | stores the arising variables for a later computation of the expected a 108 | posteriori of the probability distribution density or of the clusters. 109 | 110 | Parameters 111 | ---------- 112 | y : {array-like} of shape (n_samples, n_features) 113 | The input sample. 114 | 115 | init_groups: int, default=None 116 | Maximum number of groups to assign in the initialization. If None, 117 | the initial number of groups is drawn from the weighting structure 118 | model's attribute `n`. 119 | This parameter is only used in k-means initialization. 120 | 121 | warm_start : bool, default=False 122 | Whether to continue the sampling process from a past run or start 123 | over. If False, the sampling will start from the prior and saved 124 | states will be deleted. 125 | 126 | show_progress: bool, default=None 127 | If show_progress is True, a progress bar from the tqdm library is 128 | displayed. 
129 | 130 | init_method: str, default="kmeans" 131 | "random": does a random initialization based on the prior models 132 | "kmeans": does a kmeans initialization 133 | "variational": fits the variational distribution and uses the MAP 134 | parameters as initialization 135 | """ 136 | self._initialize_common_params(y) 137 | if not warm_start: 138 | self._initialize_gibbs_params(init_groups=init_groups, 139 | method=init_method) 140 | self._update_map_params() 141 | if show_progress is not None: 142 | self.show_progress = show_progress 143 | 144 | # Iterate the Gibbs steps with or without tqdm 145 | burn_in_iterator = range(self.burn_in) 146 | range_iterator = range(self.total_iter - self.burn_in) 147 | if self.show_progress: 148 | from tqdm import tqdm 149 | burn_in_iterator = tqdm(burn_in_iterator) 150 | range_iterator = tqdm(range_iterator) 151 | print("Starting burn-in.") 152 | for _ in burn_in_iterator: 153 | self._gibbs_step() 154 | self._update_map_params() 155 | if self.show_progress: 156 | print("Finished burn-in.") 157 | print("Starting training.") 158 | for i in range_iterator: 159 | self._gibbs_step() 160 | self._update_map_params() 161 | if i % self.subsample_steps == 0: 162 | self._save_params() 163 | if self.show_progress: 164 | print("Finished training.") 165 | self.gibbs_fitted = True 166 | 167 | def fit_variational(self, y, n_groups=None, warm_start=False, 168 | show_progress=None, tol=1e-8, init_method='kmeans'): 169 | """ 170 | Fit posterior variational distribution using mean field theory. 171 | 172 | This method does up to `self.total_iter` steps of the gradient descent 173 | algorithm to fit the variational distributions of weights, atoms and 174 | assignations of the mixture. 175 | 176 | Parameters 177 | ---------- 178 | y : {array-like} of shape (n_samples, n_features) 179 | The input sample. 180 | 181 | n_groups : int, default=None 182 | The number of groups of the truncated variational distribution. 183 | If None, the number of groups will be deduced from the weighting 184 | structure if possible. 185 | 186 | warm_start : bool, default=False 187 | Whether to continue the sampling process from a past run or start 188 | over. If False, the sampling will start from the prior parameters 189 | and any previous calculations will be discarded. 190 | 191 | show_progress : bool, default=None 192 | If show_progress is True, a progress bar from the tqdm library is 193 | displayed. 194 | 195 | tol: float, default=1e-8 196 | The tolerance of change in the evidence lower bound (ELBO) between 197 | iterations. The process finishes when the change is less than 198 | `tol`.
199 | 200 | init_method : str, default="kmeans" 201 | "kmeans": initialize variational parameters using k-means algorithm 202 | "random": initialize variational parameters using a random 203 | assignment 204 | """ 205 | if show_progress is not None: 206 | self.show_progress = show_progress 207 | self._initialize_common_params(y) 208 | if not warm_start: 209 | if hasattr(self, 'n') and n_groups is None: 210 | var_k = self.n 211 | elif n_groups is None: 212 | raise AttributeError("n_groups must be a positive integer") 213 | else: 214 | var_k = n_groups 215 | self._initialize_variational_params(var_k=var_k, 216 | init_method=init_method) 217 | elbo = -np.inf 218 | elbo_diff = np.inf 219 | iterations = 0 220 | t = None 221 | if self.show_progress: 222 | from tqdm import tqdm 223 | t = tqdm() 224 | while elbo_diff > tol and iterations < self.total_iter: 225 | self._maximize_variational() 226 | prev_elbo = elbo 227 | elbo = self._calc_elbo() 228 | elbo_diff = abs(prev_elbo - elbo) 229 | iterations += 1 230 | if t is not None: 231 | t.update() 232 | self.var_fitted = True 233 | if iterations < self.total_iter: 234 | self.var_converged = True 235 | 236 | def gibbs_eap_density(self, y=None, dim=None, component=None, 237 | periods=None): 238 | """ 239 | Returns the (Gibbs fitted) expected a posteriori density at y 240 | 241 | This method must be called after fitting a dataset with `fit_gibbs`. 242 | It returns the density at `y` as defined by the average of the mixture 243 | at every saved Gibbs step. 244 | 245 | Parameters 246 | ---------- 247 | y : {array-like} of shape (n_samples, n_features), default=None 248 | The data points over which to evaluate the EAP density. If `None` 249 | the data used at fitting is used. 250 | dim: int, {array-like} default=None 251 | The desired dimension index for which to marginalize the density, 252 | if None, all dimensions are used. 253 | component: int default=None 254 | Only returns the scaled density for a particular component. 255 | periods : int, default=None 256 | The number of saved periods to use counting backwards from the 257 | last Gibbs step. If `None`, all saved periods are used. 258 | """ 259 | if not self.gibbs_fitted: 260 | raise NotFittedError("Object must be fitted with fit_gibbs method") 261 | _y = self._cast_observations(y) 262 | y_sim = [] 263 | if periods is None: 264 | sim_params = self.sim_params 265 | else: 266 | i_start = min(periods, len(self.sim_params)) 267 | sim_params = self.sim_params[-i_start:] 268 | for param in sim_params: 269 | y_sim.append(_utils.mixture_density(_y, 270 | param["w"], 271 | param["theta"], 272 | dim=dim, 273 | component=component)) 274 | return np.array(y_sim).mean(axis=0) 275 | 276 | def gibbs_map_density(self, y=None, dim=None, component=None): 277 | """ 278 | Returns the (Gibbs fitted) maximum a posteriori density at y 279 | 280 | This method must be called after fitting a dataset with `fit_gibbs`. 281 | It returns the density at `y` as defined by the random mixture within 282 | the Gibbs steps having the highest likelihood. 283 | 284 | Parameters 285 | ---------- 286 | y : {array-like} of shape (n_samples, n_features), default=None 287 | The data points over which to evaluate the MAP density. If `None` 288 | the data used at fitting is used. 289 | dim: int, {array-like} default=None 290 | The desired dimension index for which to marginalize the density, 291 | if None, all dimensions are used. 292 | component: int default=None 293 | Only returns the scaled density for a particular component. 
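Examples
--------
A minimal sketch, assuming `model` is a mixture already fitted with
`fit_gibbs` (the 1-D grid with `dim` mirrors the internal calls made
by `gibbs_map_pairplot`):

>>> import numpy as np
>>> grid = np.linspace(-5, 5, 200)
>>> marginal = model.gibbs_map_density(grid, dim=0)
>>> full_density = model.gibbs_map_density()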
294 | """ 295 | if not self.gibbs_fitted: 296 | raise NotFittedError("Object must be fitted with fit_gibbs method") 297 | _y = self._cast_observations(y) 298 | return _utils.mixture_density(_y, 299 | self.map_sim_params["w"], 300 | self.map_sim_params["theta"], 301 | dim=dim, 302 | component=component) 303 | 304 | def gibbs_eap_affinity_matrix(self, y=None): 305 | """ 306 | Returns the (Gibbs fitted) affinity matrix for the observations y 307 | 308 | This method must be called after fitting a dataset with `fit_gibbs`. 309 | It returns an affinity matrix for `y`. The entry (i,j) of the returned 310 | matrix denotes the proportion of draws where the observation i shared 311 | the same group as the observation j. 312 | 313 | Parameters 314 | ---------- 315 | y : {array-like} of shape (n_samples, n_features), default=None 316 | The data points for which to get an affinity matrix. If `None` 317 | the data used at fitting is used. 318 | """ 319 | if not self.gibbs_fitted: 320 | raise NotFittedError( 321 | "Object must be fitted with the fit_gibbs method") 322 | if y is None: 323 | return self.affinity_matrix / self.total_saved_steps 324 | _y = self._cast_observations(y) 325 | affinity_matrix = np.zeros((len(_y), len(_y))) 326 | for params in self.sim_params: 327 | grouping = _utils.cluster(_y, params["w"], params["theta"])[0] 328 | affinity_matrix += np.equal(grouping, grouping[:, None]) 329 | affinity_matrix /= len(self.sim_params) 330 | return affinity_matrix 331 | 332 | def gibbs_eap_spectral_consensus_cluster(self, y=None, n_clusters=1): 333 | """ 334 | Returns the (Gibbs fitted) expected a posteriori cluster for y 335 | 336 | This method must be called after fitting a dataset with 337 | `fit_gibbs`. 338 | It returns the EAP consensus clustering for the observations `y`. 339 | It uses the spectral clustering algorithm over the EAP affinity matrix 340 | as consensus algorithm. 341 | 342 | Parameters 343 | ---------- 344 | y : {array-like} of shape (n_samples, n_features), default=None 345 | The data points to cluster. If `None` 346 | the data used at fitting is used. 347 | 348 | n_clusters: int, default=1 349 | The number of clusters to output. 350 | """ 351 | if not self.gibbs_fitted: 352 | raise NotFittedError( 353 | "Object must be fitted with the fit_gibbs method") 354 | sc = SpectralClustering(n_clusters=n_clusters, affinity='precomputed') 355 | return sc.fit_predict(self.gibbs_eap_affinity_matrix(y)) 356 | 357 | def gibbs_map_cluster(self, y=None, full=False): 358 | """ 359 | Returns the (Gibbs fitted) maximum a posteriori cluster for y 360 | 361 | This method is called after fitting a dataset with `fit_gibbs`. 362 | It returns the clustering for `y` using the mixture within the Gibbs 363 | steps with the greatest likelihood. 364 | 365 | Parameters 366 | ---------- 367 | y : {array-like} of shape (n_samples, n_features), default=None 368 | The data points to cluster. If `None` 369 | the data used at fitting is used. 370 | full: bool, default=False 371 | if full is false, only a vector with the clustering output is 372 | returned. If true, a tuple with the clusters and assignation 373 | uncertainties is returned. 
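Examples
--------
A minimal sketch, assuming `model` is a mixture already fitted with
`fit_gibbs`:

>>> labels = model.gibbs_map_cluster()
>>> labels, uncertainty = model.gibbs_map_cluster(full=True)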
374 | """ 375 | if not self.gibbs_fitted: 376 | raise NotFittedError( 377 | "Object must be fitted with fit_gibbs method") 378 | _y = self._cast_observations(y) 379 | ret = _utils.cluster(_y, 380 | self.map_sim_params["w"], 381 | self.map_sim_params["theta"]) 382 | if not full: 383 | ret = ret[0] 384 | return ret 385 | 386 | def gibbs_map_pairplot(self): 387 | import matplotlib.pyplot as plt 388 | import matplotlib.ticker as ticker 389 | from matplotlib.ticker import FormatStrFormatter 390 | from scipy.stats import chi2 391 | from scipy.linalg import ldl 392 | 393 | names = self._column_names 394 | n_feats = self.y.shape[1] 395 | grp = self.gibbs_map_cluster() 396 | 397 | fig, axes = plt.subplots(nrows=n_feats, ncols=n_feats, sharex='col', 398 | figsize=(n_feats * 2, n_feats * 2)) 399 | fig.set_dpi(150) 400 | alpha = 0.05 401 | alpha_radius = np.sqrt(chi2.ppf(1 - alpha, 2)) 402 | color = plt.get_cmap('tab10') 403 | if max(grp) > 10: 404 | color = color(np.linspace(0, 1, len(np.unique(grp)) + 1)) 405 | else: 406 | color = color(np.linspace(0, 1, 10)) 407 | 408 | # We set labels for each subplot on the left and lower borders 409 | # and set a shared y-axis for all subplots except those in the diagonal 410 | for it in range(n_feats): 411 | ax = axes[-1, it] 412 | ax.set_xlabel(names[it]) 413 | ax = axes[it, 0] 414 | ax.set_ylabel(names[it]) 415 | for it2 in range(n_feats - 1): 416 | if it == it2: 417 | continue 418 | it3 = it2 + 1 419 | if it == it3: 420 | it3 += 1 421 | if it3 == n_feats: 422 | continue 423 | ax = axes[it, it2] 424 | ax.sharey(axes[it, it3]) 425 | 426 | # Iterates over all subplots and does a density plot over the diagonal 427 | # and scatters off the diagonal 428 | for it in range(n_feats): 429 | for it2 in range(n_feats): 430 | ax = axes[it2, it] 431 | ax.xaxis.set_major_locator(ticker.LinearLocator(numticks=4)) 432 | ax.yaxis.set_major_locator(ticker.LinearLocator(numticks=4)) 433 | ax.tick_params(axis='y', labelrotation=90) 434 | ax.tick_params(labelsize=8) 435 | ticks = ax.xaxis.get_major_ticks() 436 | ticks[0].label1.set_visible(False) 437 | ticks[-1].label1.set_visible(False) 438 | ticks = ax.yaxis.get_major_ticks() 439 | ticks[0].label1.set_visible(False) 440 | ticks[-1].label1.set_visible(False) 441 | ax.grid(color='lightgray', linestyle='--', alpha=0.5) 442 | if it == it2: 443 | y_min = self.y[:, it].min() 444 | y_max = self.y[:, it].max() 445 | y_ptp = y_max - y_min 446 | y_range = np.linspace(y_min - y_ptp / 4, 447 | y_max + y_ptp / 4, 448 | 100) 449 | for j in np.unique(grp): 450 | dens = self.gibbs_map_density(y_range, dim=it, 451 | component=j) 452 | # ax.plot(y_range, dens, c=color[j]) 453 | ax.fill_between(y_range, dens, interpolate=True, 454 | color=color[j], alpha=0.5) 455 | for j in range(len(self.map_sim_params['w'])): 456 | if j in grp: 457 | continue 458 | dens = self.gibbs_map_density(y_range, dim=it, 459 | component=j) 460 | if max(dens) > 1e-4: 461 | ax.fill_between(y_range, dens, interpolate=True, 462 | color='black', alpha=0.25) 463 | dens = self.gibbs_map_density(y_range, dim=it) 464 | ax.plot(y_range, dens, color='black') 465 | if max(dens) < 0.001 or min(dens) > 99999: 466 | ax.yaxis.set_major_formatter(FormatStrFormatter('%.1E')) 467 | else: 468 | ax.yaxis.set_major_formatter(FormatStrFormatter('%.2g')) 469 | else: 470 | ax.scatter(self.y[:, it], self.y[:, it2], s=10, 471 | c=color[grp], alpha=0.5) 472 | for j in np.unique(grp): 473 | mu = self.map_sim_params['theta'][j][0] 474 | mu = mu[[it, it2]] 475 | sigma = 
self.map_sim_params['theta'][j][1] 476 | sigma = sigma[[it, it2], :][:, [it, it2]] 477 | ldl_sigma = ldl(sigma) 478 | sig_sq = ldl_sigma[0] @ np.sqrt( 479 | ldl_sigma[1]) @ ldl_sigma[0].T 480 | circle_points = np.array( 481 | [np.cos(np.linspace(0, 2 * np.pi, 100)), 482 | np.sin(np.linspace(0, 2 * np.pi, 483 | 100))]).T * alpha_radius 484 | circle_points = np.matmul(circle_points, sig_sq) + mu 485 | ax.plot(circle_points[:, 0], circle_points[:, 1], 486 | linestyle='--', c=color[j]) 487 | # break 488 | 489 | def var_eap_density(self, y=None, dim=None, component=None): 490 | """ 491 | Returns the expected a posteriori density at y using variational 492 | inference 493 | 494 | This method is called after fitting a dataset with 495 | `fit_variational`. 496 | It returns the density at `y` as described by the fitted variational 497 | distributions using the expected density at each point. 498 | 499 | Parameters 500 | ---------- 501 | y : {array-like} of shape (n_samples, n_features), default=None 502 | The points at which to draw the variational EAP density. If `None` 503 | the data used at fitting is used. 504 | dim: int, {array-like} default=None 505 | The desired dimension index for which to marginalize the density, 506 | if None, all dimensions are used. 507 | component: int default=None 508 | Only returns the scaled density for a particular component. 509 | """ 510 | if not self.var_fitted: 511 | raise NotFittedError("Object must be fitted with fit_variational" 512 | " method") 513 | _y = self._cast_observations(y) 514 | if dim is None: 515 | dim = np.arange(_y.shape[1]) 516 | if isinstance(dim, int): 517 | len_dim = 1 518 | else: 519 | len_dim = len(dim) 520 | f_x = np.zeros(len(_y)) 521 | if component is None: 522 | iterator = enumerate(self.var_theta) 523 | else: 524 | iterator = [(component, self.var_theta[component])] 525 | for j, vt_j in iterator: 526 | v_mu_j, v_lambda_j, v_precision_j, v_scale_j = vt_j 527 | v_mu_j = v_mu_j[dim] 528 | v_precision_j = v_precision_j[:, dim][dim, :] 529 | f_x += density_students_t( 530 | _y, v_mu_j, 531 | v_precision_j * (v_scale_j + 1 - 532 | len_dim) * v_lambda_j / (1 + v_lambda_j), 533 | v_scale_j + 1 - len_dim 534 | ) * self.weight_model.variational_mean_w_j(j) 535 | return f_x 536 | 537 | def var_map_density(self, y=None, dim=None, component=None): 538 | """ 539 | Returns the maximum a posteriori density at y using variational 540 | inference 541 | 542 | This method is called after fitting a dataset with `fit_variational`. 543 | It returns the density at `y` as described by the fitted variational 544 | distributions using the maximum likelihood density at each point. 545 | 546 | Parameters 547 | ---------- 548 | y : {array-like} of shape (n_samples, n_features), default=None 549 | The points at which to draw the variational MAP density. If `None` 550 | the data used at fitting is used. 
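dim: int, {array-like} default=None
    The desired dimension index for which to marginalize the density,
    if None, all dimensions are used.
component: int default=None
    Only returns the scaled density for a particular component.

Examples
--------
A minimal sketch, assuming `model` is a mixture already fitted with
`fit_variational`:

>>> import numpy as np
>>> grid = np.linspace(-5, 5, 200)
>>> marginal = model.var_map_density(grid, dim=0)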
551 | """ 552 | if not self.var_fitted: 553 | raise NotFittedError("Object must be fitted with fit_variational" 554 | " method") 555 | _y = self._cast_observations(y) 556 | if isinstance(dim, int): 557 | len_dim = 1 558 | else: 559 | len_dim = len(dim) 560 | f_x = np.zeros(len(_y)) 561 | if component is None: 562 | iterator = enumerate(self.var_theta) 563 | else: 564 | iterator = [(component, self.var_theta[component])] 565 | for j, vt_j in iterator: 566 | v_mu_j, v_lambda_j, v_precision_j, v_scale_j = vt_j 567 | v_mu_j = v_mu_j[dim] 568 | v_precision_j = v_precision_j[:, dim][dim, :] 569 | map_mu = v_mu_j 570 | map_precision = (v_scale_j - len_dim) * v_precision_j 571 | f_x += density_normal(_y, map_mu, map_precision 572 | ) * self.weight_model.variational_mode_w_j(j) 573 | return f_x 574 | 575 | def var_eap_affinity_matrix(self, y=None): 576 | """ 577 | Returns the (Variational fitted) affinity matrix for the observations y 578 | 579 | This init_method must be called after fitting a dataset with 580 | `fit_variational`. 581 | It returns an affinity matrix for `y`. The entry (it,it2) of the 582 | returned matrix denotes the variational probability of draws in the 583 | assignation of y[it] and y[it2]. 584 | 585 | Parameters 586 | ---------- 587 | y : {array-like} of shape (n_samples, n_features), default=None 588 | The data points for which to get an affinity matrix. If `None` 589 | the data used at fitting is used. 590 | """ 591 | if not self.var_fitted: 592 | raise NotFittedError( 593 | "Object must be fitted with the fit_variational method") 594 | if y is None: 595 | var_d = self.var_d 596 | else: 597 | _y = self._cast_observations(y) 598 | dim = y.shape[1] 599 | var_d = np.zeros((self.var_k, y.shape[0]), dtype=np.float64) 600 | for j, vt_j in enumerate(self.var_theta): 601 | v_mu_j, v_lambda_j, v_precision_j, v_scale_j = vt_j 602 | log_d_ji = self.weight_model.variational_mean_log_w_j(j) 603 | log_d_ji += _utils.e_log_norm_wishart(v_precision_j, 604 | v_scale_j) / 2 605 | log_d_ji -= dim * np.log(2 * np.pi) / 2 606 | log_d_ji -= dim / v_lambda_j / 2 607 | log_d_ji -= (v_scale_j * 608 | ((y - v_mu_j).T * ( 609 | v_precision_j @ (y - v_mu_j).T)).sum( 610 | 0) 611 | ) / 2 612 | var_d[j, :] = log_d_ji 613 | var_d -= var_d.max(axis=0, initial=-np.inf) 614 | var_d = np.exp(var_d) 615 | var_d += np.finfo(np.float64).eps 616 | var_d /= var_d.sum(axis=0) 617 | affinity_matrix = var_d.T @ var_d 618 | return affinity_matrix 619 | 620 | def var_eap_spectral_consensus_cluster(self, y=None, n_clusters=1): 621 | """ 622 | Returns the (Variational fitted) expected a posteriori cluster for y 623 | 624 | This init_method must be called after fitting a dataset with 625 | `fit_variational`. 626 | It returns the EAP consensus clustering for the observations `y`. 627 | It uses the spectral clustering algorithm over the EAP affinity matrix 628 | as the consensus algorithm. 629 | 630 | Parameters 631 | ---------- 632 | y : {array-like} of shape (n_samples, n_features), default=None 633 | The data points to cluster. If `None` 634 | the data used at fitting is used. 635 | n_clusters: int, default=1 636 | The number of clusters to output. 
637 | """ 638 | if not self.var_fitted: 639 | raise NotFittedError( 640 | "Object must be fitted with fit_variational init_method") 641 | sc = SpectralClustering(n_clusters=n_clusters, affinity='precomputed') 642 | return sc.fit_predict(self.var_eap_affinity_matrix(y)) 643 | 644 | def var_map_cluster(self, y=None, full=False): 645 | """ 646 | Returns the maximum a posteriori clustering for y using variational 647 | inference 648 | 649 | This method is called after fitting a dataset with `fit_variational`. 650 | It returns a clustering for `y` using the fitted variational 651 | distributions and the assignations with greater likelihood. 652 | 653 | Parameters 654 | ---------- 655 | y : {array-like} of shape (n_samples, n_features), default=None 656 | The points to cluster using the MAP assignations. If `None` 657 | the data used at fitting is used. 658 | full: bool, default=False 659 | If False (default), only the maximum a posteriori clustering is 660 | returned. If True, the variational assignation probabilty is also 661 | returned. 662 | """ 663 | if not self.var_fitted: 664 | raise NotFittedError("Object must be fitted with fit_variational" 665 | " method") 666 | if y is None: 667 | if not full: 668 | return self.var_d.argmax(0) 669 | else: 670 | d = self.var_d.argmax(0) 671 | return d, 1 - self.var_d[d, range(len(d))] 672 | _y = self._cast_observations(y) 673 | dim = _y.shape[1] 674 | var_d = np.zeros((self.var_k, _y.shape[0]), dtype=np.float64) 675 | for j, vt_j in enumerate(self.var_theta): 676 | v_mu_j, v_lambda_j, v_precision_j, v_scale_j = vt_j 677 | log_d_ji = self.weight_model.variational_mean_log_w_j(j) 678 | log_d_ji += _utils.e_log_norm_wishart(v_precision_j, v_scale_j) / 2 679 | log_d_ji -= dim / (2 * v_lambda_j) 680 | log_d_ji -= (v_scale_j / 2 * 681 | ((_y - v_mu_j).T * ( 682 | v_precision_j @ (_y - v_mu_j).T)).sum(0) 683 | ) 684 | var_d[j, :] = log_d_ji 685 | var_d -= var_d.mean(0) 686 | var_d = np.exp(var_d) 687 | var_d += np.finfo(np.float64).eps 688 | var_d /= var_d.sum(0) 689 | if not full: 690 | return var_d.argmax(0) 691 | else: 692 | d = var_d.argmax(0) 693 | return d, var_d[d] 694 | 695 | def get_n_groups(self): 696 | return self.n_groups 697 | 698 | def get_n_theta(self): 699 | return self.n_atoms 700 | 701 | def get_sim_params(self): 702 | return self.sim_params 703 | 704 | def _initialize_common_params(self, y): 705 | """ 706 | Initialize the prior variables if not given 707 | """ 708 | if isinstance(y, pd.DataFrame): 709 | self.y = y.to_numpy() 710 | self._column_names = y.columns 711 | elif isinstance(y, pd.Series): 712 | self.y = y.to_numpy() 713 | if y.name is None: 714 | self._column_names = np.array([0]) 715 | else: 716 | self._column_names = np.array([y.name]) 717 | elif isinstance(y, list): 718 | self.y = np.array(y) 719 | if self.y.ndim == 1: 720 | self.y = self.y.reshape(-1, 1) 721 | self._column_names = np.array([0]) 722 | else: 723 | self._column_names = np.arange(self.y.shape[1]) 724 | elif isinstance(y, np.ndarray): 725 | self.y = np.copy(y) 726 | if self.y.ndim == 1: 727 | self.y = self.y.reshape(-1, 1) 728 | self._column_names = np.array([0]) 729 | else: 730 | self._column_names = np.arange(self.y.shape[1]) 731 | else: 732 | raise TypeError('Invalid type for variable y') 733 | 734 | if self.mu_prior is None: 735 | self.mu_prior = self.y.mean(axis=0) 736 | else: 737 | self.mu_prior = np.array(self.mu_prior) 738 | if self.psi_prior is None: 739 | self.psi_prior = np.atleast_2d(np.cov(self.y.T)) 740 | else: 741 | self.psi_prior = 

    def _initialize_gibbs_params(self, init_groups=None, method="kmeans"):
        """
        Initialize the Gibbs sampler latent variables

        This method randomly initializes the number of groups, the mean and
        variance variables and the assignation vector.

        Parameters
        ----------
        init_groups : int, default=None
            Maximum number of groups to assign in the initialization. If None,
            the initial number of groups is drawn from the weighting structure
            model's attribute `n`.
            This parameter is only used in k-means initialization.
        method : str
            "kmeans": does a k-means initialization
            "random": does a random initialization based on the prior models
            "variational": fits the variational distribution and uses the MAP
                parameters as initialization
        """

        def atom_generator():
            mu, sigma = _utils.random_normal_invw(
                mu=self.mu_prior,
                lam=self.lambda_prior,
                psi=self.psi_prior,
                nu=self.nu_prior,
                rng=self.rng
            )
            return np.atleast_1d(mu), np.atleast_2d(sigma)

        self.sim_params = []
        self.n_groups = []
        self.n_atoms = []
        self.total_saved_steps = 0

        # Atoms are drawn lazily from the prior on first access
        self.theta = defaultdict(atom_generator)
        self.affinity_matrix = np.zeros((len(self.y), len(self.y)))
        self.u = self.rng.uniform(0 + np.finfo(np.float64).eps, 1,
                                  len(self.y))
        if init_groups is None:
            self.weight_model.tail(1 - min(self.u))
        else:
            self.weight_model.complete(init_groups)

        if method == "kmeans":
            if hasattr(self, 'n'):
                n = self.n
            else:
                if init_groups is not None:
                    n = init_groups
                else:
                    n = 2
            self.d = _utils.kmeans_cluster_size_biased(self.y, n, self.rng)
        elif method == "random":
            self.d = self.weight_model.random_assignment(len(self.y))
        elif method == "variational":
            self.fit_variational(self.y,
                                 n_groups=self.weight_model.get_size())
            self.d = self.var_map_cluster()
        else:
            raise AttributeError("method param must be one of 'kmeans', "
                                 "'random', 'variational'")
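
    # The `defaultdict(atom_generator)` above draws a fresh atom from the
    # Normal-inverse-Wishart prior the first time an unseen component index
    # is touched. A minimal standalone sketch of the same pattern (the
    # `draw_atom` factory here is hypothetical):
    #
    #     from collections import defaultdict
    #     import numpy as np
    #
    #     rng = np.random.default_rng(0)
    #
    #     def draw_atom():
    #         return rng.normal(), 1.0  # stand-in for a (mu, sigma) draw
    #
    #     theta = defaultdict(draw_atom)
    #     mu_5, sigma_5 = theta[5]  # key 5 is created on first access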

    def _initialize_variational_params(self, var_k, init_method="kmeans"):
        """
        Initialize the parameters for the variational distributions

        This method randomly initializes the parameters for the assignation
        vector distribution, assigns the variational Normal-Wishart parameters
        and fits the weight_model.

        Parameters
        ----------
        var_k : int
            Maximum number of groups to assign in the initialization. If None,
            the number of groups drawn from the weight model is not capped.
        init_method : str
            "kmeans": initialize variational parameters using the k-means
                algorithm
            "random": initialize variational parameters using a random
                assignment
        """
        self.var_k = var_k
        self.var_theta = []

        if init_method == "kmeans":
            d = _utils.kmeans_cluster_size_biased(self.y, self.var_k,
                                                  self.rng)
            var_d = np.zeros((self.var_k, self.y.shape[0]),
                             dtype=np.float64)
            var_d[d, range(len(d))] = 1
            self.var_d = var_d
            self.weight_model.fit_variational(np.empty(shape=(self.var_k, 0)))
            self._update_var_theta()
            self._update_var_d()
        elif init_method == "random":
            for _ in range(self.var_k):
                mu_j, temp_psi = _utils.random_normal_invw(
                    mu=self.mu_prior,
                    lam=self.lambda_prior,
                    psi=self.psi_prior,
                    nu=self.nu_prior,
                    rng=self.rng
                )
                mu_j = np.atleast_1d(mu_j)
                temp_psi = np.atleast_2d(temp_psi)
                self.var_theta.append([mu_j, self.lambda_prior,
                                       np.linalg.inv(temp_psi),
                                       self.nu_prior])
            self.var_d = np.tile(1 / self.var_k,
                                 (self.var_k, self.y.shape[0]))
        else:
            raise AttributeError("init_method param must be one of 'kmeans', "
                                 "'random'")
        self.weight_model.fit_variational(self.var_d)

    def _get_run_params(self):
        return {"w": self.weight_model.get_weights(),
                "theta": dict(self.theta),
                "u": self.u,
                "d": self.d}

    def _save_params(self):
        self.sim_params.append(self._get_run_params())
        self.n_groups.append(len(np.unique(self.d)))
        self.n_atoms.append(len(self.theta))
        self.affinity_matrix += np.equal(self.d, self.d[:, None])
        self.total_saved_steps += 1
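
    # `np.equal(self.d, self.d[:, None])` above builds the (n, n) boolean
    # co-assignment matrix of the current Gibbs state; summed over saved
    # steps and divided by `total_saved_steps` it estimates pairwise
    # co-clustering probabilities. A minimal sketch:
    #
    #     import numpy as np
    #
    #     d = np.array([0, 0, 1])
    #     np.equal(d, d[:, None])
    #     # array([[ True,  True, False],
    #     #        [ True,  True, False],
    #     #        [False, False,  True]])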

    def _update_map_params(self):
        """Compute the log-likelihood and parameters of the run and update
        the MAP parameters if the likelihood improves
        """
        run_log_likelihood = self._run_log_likelihood()
        if self.map_log_likelihood < run_log_likelihood:
            self.map_log_likelihood = run_log_likelihood
            self.map_sim_params = self._get_run_params()
        elif self.map_log_likelihood == -np.inf:
            # Save the params to get something to compare against
            self.map_sim_params = self._get_run_params()

    def _update_weights(self):
        self.weight_model.fit(self.d)
        w = self.weight_model.random()
        self.u = self.rng.uniform(0 + np.finfo(np.float64).eps,
                                  w[self.d] + np.finfo(np.float64).eps)
        self.weight_model.tail(1 - min(self.u))

    def _update_atoms(self):
        for j in np.unique(self.d):
            mask_j = self.d == j
            posterior_params = _utils.posterior_norm_invw_params(
                self.y[mask_j],
                mu=self.mu_prior,
                lam=self.lambda_prior,
                psi=self.psi_prior,
                nu=self.nu_prior)
            temp_mu, temp_sigma = _utils.random_normal_invw(
                mu=posterior_params["mu"],
                lam=posterior_params["lambda"],
                psi=posterior_params["psi"],
                nu=posterior_params["nu"],
                rng=self.rng)
            temp_mu = np.atleast_1d(temp_mu)
            temp_sigma = np.atleast_2d(temp_sigma)
            self.theta[j] = (temp_mu, temp_sigma)

    def _update_d(self):
        log_prob = self._d_log_likelihood_vector()
        self.d = _utils.gumbel_max_sampling(log_prob, rng=self.rng)

    def _gibbs_step(self):
        self._update_atoms()
        self._update_weights()
        self._update_d()

    def _d_log_likelihood_vector(self):
        with np.errstate(divide='ignore'):
            log_probability = np.array(
                [multivariate_normal.logpdf(self.y,
                                            self.theta[j][0],
                                            self.theta[j][1],
                                            1)
                 for j in range(self.weight_model.get_size())]
            )
            log_probability += np.log(np.greater.outer(
                self.weight_model.get_weights(),
                self.u))
        return log_probability

    def _run_log_likelihood(self):
        ret = 0
        ret += self._y_log_likelihood()
        ret += self._d_log_likelihood()
        return ret

    def _y_log_likelihood(self):
        """returns the loglikelihood of f(y|d, w, theta)"""
        ret = 0
        with np.errstate(divide='ignore'):
            for j in np.unique(self.d):
                ret += np.sum(multivariate_normal.logpdf(self.y[self.d == j],
                                                         self.theta[j][0],
                                                         self.theta[j][1],
                                                         1))
        return ret

    def _d_log_likelihood(self):
        """returns the loglikelihood of f(d|w)"""
        return self.weight_model.assignation_log_likelihood(self.d)

    def _w_log_likelihood(self):
        """returns the loglikelihood of f(w)"""
        return self.weight_model.weighting_log_likelihood()

    def _theta_log_likelihood(self):
        """returns the loglikelihood of f(theta)"""
        res = 0
        for j in np.unique(self.d):
            mu, sigma = self.theta[j]
            res += _utils.log_likelihood_normal_invw(
                mu=mu,
                sigma=sigma,
                mu0=self.mu_prior,
                lam0=self.lambda_prior,
                psi0=self.psi_prior,
                nu0=self.nu_prior
            )
        return res
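
    # `_update_d` samples assignations with the Gumbel-max trick: adding
    # i.i.d. Gumbel(0, 1) noise to unnormalized log-probabilities and taking
    # the argmax is equivalent to sampling from the corresponding categorical
    # distribution. A minimal standalone sketch:
    #
    #     import numpy as np
    #
    #     rng = np.random.default_rng(0)
    #     log_prob = np.log(np.array([0.2, 0.5, 0.3]))
    #     gumbel = rng.gumbel(size=log_prob.shape)
    #     sample = np.argmax(log_prob + gumbel)  # ~ Categorical(.2, .5, .3)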

    def _maximize_variational(self):
        self._update_var_d()
        self._update_var_w()
        self._update_var_theta()

    def _calc_elbo(self):
        ret = 0
        ret += self._e_q_log_p_x()
        ret += self._e_q_log_p_d__w()
        ret += self._e_log_p_w()
        ret += self._e_log_p_theta()
        ret -= self._e_log_q_d()
        ret -= self._e_log_q_w()
        ret -= self._e_log_q_theta()
        return ret

    def _update_var_w(self):
        self.weight_model.fit_variational(self.var_d)

    def _update_var_theta(self):
        var_theta = []
        for vd_j in self.var_d:
            n_j = vd_j.sum()
            x_bar_j = (vd_j / n_j) @ self.y
            # responsibility-weighted scatter about x_bar_j (n_j * S_j)
            ns_j = np.einsum('i,ij,ik->jk',
                             vd_j, (self.y - x_bar_j), (self.y - x_bar_j))
            v_lambda_j = self.lambda_prior + n_j
            v_scale_j = self.nu_prior + n_j
            v_mu_j = (self.lambda_prior * self.mu_prior +
                      n_j * x_bar_j) / v_lambda_j
            v_precision_j = (self.psi_prior + ns_j +
                             self.lambda_prior * n_j / (self.lambda_prior +
                                                        n_j) *
                             np.outer(x_bar_j - self.mu_prior,
                                      x_bar_j - self.mu_prior))
            v_precision_j = np.linalg.inv(v_precision_j)
            var_theta.append([v_mu_j, v_lambda_j, v_precision_j, v_scale_j])
        self.var_theta = var_theta

    def _update_var_d(self):
        dim = self.y.shape[1]
        var_d = np.zeros((self.var_k, self.y.shape[0]), dtype=np.float64)
        for j, vt_j in enumerate(self.var_theta):
            v_mu_j, v_lambda_j, v_precision_j, v_scale_j = vt_j
            log_d_ji = self.weight_model.variational_mean_log_w_j(j)
            log_d_ji += _utils.e_log_norm_wishart(v_precision_j, v_scale_j) / 2
            log_d_ji -= dim * np.log(2 * np.pi) / 2
            log_d_ji -= dim / v_lambda_j / 2
            log_d_ji -= (v_scale_j *
                         ((self.y - v_mu_j).T * (
                                 v_precision_j @ (self.y - v_mu_j).T)).sum(0)
                         ) / 2
            var_d[j, :] = log_d_ji
        var_d -= var_d.max(axis=0, initial=-np.inf)
        var_d = np.exp(var_d)
        var_d += np.finfo(np.float64).eps
        var_d /= var_d.sum(axis=0)
        self.var_d = var_d

    def _e_q_log_p_x(self):
        dim = self.y.shape[1]
        res = 0
        for vd_j, vt_j in zip(self.var_d, self.var_theta):
            n_j = vd_j.sum()
            x_bar_j = (vd_j / n_j) @ self.y
            s_j = (vd_j / n_j * (self.y - x_bar_j).T) @ (self.y - x_bar_j)
            v_mu_j, v_lambda_j, v_precision_j, v_scale_j = vt_j
            res += _utils.e_log_norm_wishart(v_precision_j, v_scale_j)
            res -= dim / v_lambda_j
            res -= v_scale_j * np.einsum('ij,ji->', s_j, v_precision_j)
            res -= (v_scale_j *
                    (x_bar_j - v_mu_j) @ v_precision_j @ (x_bar_j - v_mu_j)
                    )
        res -= dim * np.log(2 * np.pi) * self.var_k
        res /= 2
        return res

    def _e_q_log_p_d__w(self):
        return self.weight_model.variational_mean_log_p_d__w(
            variational_d=self.var_d
        )

    def _e_log_p_w(self):
        return self.weight_model.variational_mean_log_p_w()

    def _e_log_p_theta(self):
        dim = self.y.shape[1]
        res = 0
        for vt_j in self.var_theta:
            v_mu_j, v_lambda_j, v_precision_j, v_scale_j = vt_j
            res += (self.nu_prior -
                    dim) * _utils.e_log_norm_wishart(v_precision_j, v_scale_j)
            res -= dim * self.lambda_prior / v_lambda_j
            res -= v_scale_j * np.einsum('ij,ji->',
                                         self.psi_prior, v_precision_j)
            res -= (self.lambda_prior * v_scale_j *
                    (v_mu_j - self.mu_prior) @ v_precision_j @ (v_mu_j -
                                                                self.mu_prior)
                    )
        res += dim * (np.log(self.lambda_prior / (2 * np.pi))) * self.var_k
        res /= 2
        res += self.var_k * _utils.log_wishart_normalization_term(
            np.linalg.inv(self.psi_prior), self.nu_prior
        )
        return res
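
    # `_update_var_d` normalizes the responsibilities in log space:
    # subtracting the per-observation maximum before exponentiating is the
    # standard log-sum-exp guard against underflow. A minimal sketch:
    #
    #     import numpy as np
    #
    #     log_r = np.array([[-1000.0, -1001.0], [-1002.0, -1000.5]])
    #     log_r -= log_r.max(axis=0)  # naive np.exp(log_r) would underflow
    #     r = np.exp(log_r)
    #     r /= r.sum(axis=0)          # columns now sum to one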

    def _e_log_q_d(self):
        return np.sum(self.var_d * np.log(self.var_d))

    def _e_log_q_w(self):
        return self.weight_model.variational_mean_log_p_w()

    def _e_log_q_theta(self):
        dim = self.y.shape[1]
        res = 0
        for vt_j in self.var_theta:
            _, v_lambda_j, v_precision_j, v_scale_j = vt_j
            res += _utils.e_log_norm_wishart(v_precision_j, v_scale_j)
            res += dim * (np.log(v_lambda_j / (2 * np.pi)))
            res -= _utils.entropy_wishart(v_precision_j, v_scale_j)
        res += dim * self.var_k
        res /= 2
        return res

    def _cast_observations(self, y):
        if y is None:
            _y = self.y
        else:
            if isinstance(y, pd.DataFrame):
                _y = y.to_numpy()
            elif isinstance(y, pd.Series):
                _y = y.to_numpy().reshape(-1, 1)
            elif isinstance(y, list):
                _y = np.array(y)
                if _y.ndim == 1:
                    _y = _y.reshape(-1, 1)
            elif isinstance(y, np.ndarray):
                if y.ndim == 1:
                    _y = y.reshape(-1, 1)
                else:
                    _y = y
            elif isinstance(y, (int, float)):
                _y = np.array([[y]])
            else:
                raise TypeError("Invalid type for variable y")
        return _y

--------------------------------------------------------------------------------