├── .editorconfig ├── .github ├── actions │ └── setup-python-env │ │ └── action.yml └── workflows │ ├── build-and-inspect-package.yml │ ├── main.yml │ ├── on-release-main.yml │ └── validate-codecov-config.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── codecov.yaml ├── docs ├── _static │ ├── gmmx-logo.png │ ├── time-vs-n-components-fit.png │ ├── time-vs-n-components-predict.png │ ├── time-vs-n-features-fit.png │ ├── time-vs-n-features-predict.png │ ├── time-vs-n-samples-fit.png │ └── time-vs-n-samples-predict.png ├── development.md ├── index.md └── modules.md ├── examples ├── benchmarks │ └── run-benchmark.py ├── moon-density-estimation.py └── simple-1d-fit.py ├── gmmx ├── __init__.py ├── fit.py ├── gmm.py └── utils.py ├── mkdocs.yml ├── pyproject.toml ├── tests └── test_gmm.py └── tox.ini /.editorconfig: -------------------------------------------------------------------------------- 1 | max_line_length = 88 2 | 3 | [*.json] 4 | indent_style = space 5 | indent_size = 4 6 | -------------------------------------------------------------------------------- /.github/actions/setup-python-env/action.yml: -------------------------------------------------------------------------------- 1 | name: "Setup Python Environment" 2 | description: "Set up Python environment for the given Python version" 3 | 4 | inputs: 5 | python-version: 6 | description: "Python version to use" 7 | required: true 8 | default: "3.12" 9 | uv-version: 10 | description: "uv version to use" 11 | required: true 12 | default: "0.4.6" 13 | 14 | runs: 15 | using: "composite" 16 | steps: 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version: ${{ inputs.python-version }} 20 | 21 | - name: Install uv 22 | uses: astral-sh/setup-uv@v2 23 | with: 24 | version: ${{ inputs.uv-version }} 25 | enable-cache: "true" 26 | cache-suffix: ${{ matrix.python-version }} 27 | 28 | - name: Install Python dependencies 29 | run: uv sync 30 | shell: bash 31 | -------------------------------------------------------------------------------- /.github/workflows/build-and-inspect-package.yml: -------------------------------------------------------------------------------- 1 | name: build-and-inspect-package 2 | 3 | on: 4 | pull_request: 5 | paths: [pyproject.toml] 6 | push: 7 | branches: [main] 8 | 9 | jobs: 10 | build-and-inspect-package: 11 | name: Build & inspect package. 
12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - uses: hynek/build-and-inspect-python-package@v2 16 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Main 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | types: [opened, synchronize, reopened, ready_for_review] 9 | 10 | jobs: 11 | quality: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Check out 15 | uses: actions/checkout@v4 16 | 17 | - uses: actions/cache@v4 18 | with: 19 | path: ~/.cache/pre-commit 20 | key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }} 21 | 22 | - name: Set up the environment 23 | uses: ./.github/actions/setup-python-env 24 | 25 | - name: Run checks 26 | run: make check 27 | 28 | tests-and-type-check: 29 | runs-on: ubuntu-latest 30 | strategy: 31 | matrix: 32 | python-version: ["3.9", "3.10", "3.11", "3.12"] 33 | fail-fast: false 34 | defaults: 35 | run: 36 | shell: bash 37 | steps: 38 | - name: Check out 39 | uses: actions/checkout@v4 40 | 41 | - name: Set up the environment 42 | uses: ./.github/actions/setup-python-env 43 | with: 44 | python-version: ${{ matrix.python-version }} 45 | 46 | - name: Run tests 47 | run: uv run python -m pytest tests --cov --cov-config=pyproject.toml --cov-report=xml 48 | 49 | - name: Check typing 50 | run: uv run mypy 51 | 52 | - name: Upload coverage reports to Codecov with GitHub Action on Python 3.11 53 | uses: codecov/codecov-action@v5 54 | if: ${{ matrix.python-version == '3.11' }} 55 | with: 56 | token: ${{ secrets.CODECOV_TOKEN }} 57 | 58 | check-docs: 59 | runs-on: ubuntu-latest 60 | steps: 61 | - name: Check out 62 | uses: actions/checkout@v4 63 | 64 | - name: Set up the environment 65 | uses: ./.github/actions/setup-python-env 66 | 67 | - name: Check if documentation can be built 68 | run: uv run mkdocs build -s 69 | -------------------------------------------------------------------------------- /.github/workflows/on-release-main.yml: -------------------------------------------------------------------------------- 1 | name: release-main 2 | 3 | on: 4 | release: 5 | types: [published] 6 | branches: [main] 7 | 8 | jobs: 9 | set-version: 10 | runs-on: ubuntu-24.04 11 | steps: 12 | - uses: actions/checkout@v4 13 | 14 | - name: Export tag 15 | id: vars 16 | run: echo tag=${GITHUB_REF#refs/*/} >> $GITHUB_OUTPUT 17 | if: ${{ github.event_name == 'release' }} 18 | 19 | - name: Update project version 20 | run: | 21 | sed -i "s/^version = \".*\"/version = \"$RELEASE_VERSION\"/" pyproject.toml 22 | env: 23 | RELEASE_VERSION: ${{ steps.vars.outputs.tag }} 24 | if: ${{ github.event_name == 'release' }} 25 | 26 | - name: Upload updated pyproject.toml 27 | uses: actions/upload-artifact@v4 28 | with: 29 | name: pyproject-toml 30 | path: pyproject.toml 31 | publish: 32 | runs-on: ubuntu-latest 33 | needs: [set-version] 34 | steps: 35 | - name: Check out 36 | uses: actions/checkout@v4 37 | 38 | - name: Set up the environment 39 | uses: ./.github/actions/setup-python-env 40 | 41 | - name: Download updated pyproject.toml 42 | uses: actions/download-artifact@v4 43 | with: 44 | name: pyproject-toml 45 | - name: Build package 46 | run: uvx --from build pyproject-build --installer uv 47 | 48 | - name: Publish package 49 | run: uvx twine upload dist/* 50 | env: 51 | TWINE_USERNAME: __token__ 52 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 53 | 54 | deploy-docs: 55 | needs: publish 56 | 
runs-on: ubuntu-latest 57 | steps: 58 | - name: Check out 59 | uses: actions/checkout@v4 60 | 61 | - name: Set up the environment 62 | uses: ./.github/actions/setup-python-env 63 | 64 | - name: Deploy documentation 65 | run: uv run mkdocs gh-deploy --force 66 | env: 67 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 68 | -------------------------------------------------------------------------------- /.github/workflows/validate-codecov-config.yml: -------------------------------------------------------------------------------- 1 | name: validate-codecov-config 2 | 3 | on: 4 | pull_request: 5 | paths: [codecov.yaml] 6 | push: 7 | branches: [main] 8 | 9 | jobs: 10 | validate-codecov-config: 11 | runs-on: ubuntu-22.04 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Validate codecov configuration 15 | run: curl -sSL --fail-with-body --data-binary @codecov.yaml https://codecov.io/validate 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | docs/source 2 | 3 | # From https://raw.githubusercontent.com/github/gitignore/main/Python.gitignore 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 90 | __pypackages__/ 91 | 92 | # Celery stuff 93 | celerybeat-schedule 94 | celerybeat.pid 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | 126 | # pytype static type analyzer 127 | .pytype/ 128 | 129 | # Cython debug symbols 130 | cython_debug/ 131 | 132 | # Vscode config files 133 | .vscode/ 134 | 135 | # PyCharm 136 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 137 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 138 | # and can be added to the global gitignore or merged into this file. For a more nuclear 139 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 140 | #.idea/ 141 | 142 | examples/benchmarks/results 143 | 144 | uv.lock 145 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: "v4.4.0" 4 | hooks: 5 | - id: check-case-conflict 6 | - id: check-merge-conflict 7 | - id: check-toml 8 | - id: check-yaml 9 | - id: end-of-file-fixer 10 | - id: trailing-whitespace 11 | 12 | - repo: https://github.com/astral-sh/ruff-pre-commit 13 | rev: "v0.6.3" 14 | hooks: 15 | - id: ruff 16 | args: [--exit-non-zero-on-fix] 17 | - id: ruff-format 18 | 19 | - repo: https://github.com/pre-commit/mirrors-prettier 20 | rev: "v3.0.3" 21 | hooks: 22 | - id: prettier 23 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Donath" 5 | given-names: "Axel" 6 | orcid: "https://orcid.org/0000-0003-4568-7005" 7 | title: "GMMX: Gaussian Mixture Models in Jax" 8 | version: v0.1 9 | doi: 10.5281/zenodo.14515326 10 | date-released: 2024-12-18 11 | url: "https://github.com/adonath/gmmx" 12 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to `gmmx` 2 | 3 | Contributions are welcome, and they are greatly appreciated! 4 | Every little bit helps, and credit will always be given. 5 | 6 | You can contribute in many ways: 7 | 8 | # Types of Contributions 9 | 10 | ## Report Bugs 11 | 12 | Report bugs at https://github.com/adonath/gmmx/issues 13 | 14 | If you are reporting a bug, please include: 15 | 16 | - Your operating system name and version. 17 | - Any details about your local setup that might be helpful in troubleshooting. 18 | - Detailed steps to reproduce the bug. 19 | 20 | ## Fix Bugs 21 | 22 | Look through the GitHub issues for bugs. 23 | Anything tagged with "bug" and "help wanted" is open to whoever wants to implement a fix for it. 24 | 25 | ## Implement Features 26 | 27 | Look through the GitHub issues for features.
28 | Anything tagged with "enhancement" and "help wanted" is open to whoever wants to implement it. 29 | 30 | ## Write Documentation 31 | 32 | `gmmx` could always use more documentation, whether as part of the official docs, in docstrings, or even on the web in blog posts, articles, and such. 33 | 34 | ## Submit Feedback 35 | 36 | The best way to send feedback is to file an issue at https://github.com/adonath/gmmx/issues. 37 | 38 | If you are proposing a new feature: 39 | 40 | - Explain in detail how it would work. 41 | - Keep the scope as narrow as possible, to make it easier to implement. 42 | - Remember that this is a volunteer-driven project, and that contributions 43 | are welcome :) 44 | 45 | # Get Started! 46 | 47 | Ready to contribute? Here's how to set up `gmmx` for local development. 48 | Please note this documentation assumes you already have `uv` and `Git` installed and ready to go. 49 | 50 | 1. Fork the `gmmx` repo on GitHub. 51 | 52 | 2. Clone your fork locally: 53 | 54 | ```bash 55 | cd 56 | git clone git@github.com:YOUR_NAME/gmmx.git 57 | ``` 58 | 59 | 3. Now we need to install the environment. Navigate into the directory: 60 | 61 | ```bash 62 | cd gmmx 63 | ``` 64 | 65 | Then, install and activate the environment with: 66 | 67 | ```bash 68 | uv sync 69 | ``` 70 | 71 | 4. Install pre-commit to run linters/formatters at commit time: 72 | 73 | ```bash 74 | uv run pre-commit install 75 | ``` 76 | 77 | 5. Create a branch for local development: 78 | 79 | ```bash 80 | git checkout -b name-of-your-bugfix-or-feature 81 | ``` 82 | 83 | Now you can make your changes locally. 84 | 85 | 6. Don't forget to add test cases for your added functionality to the `tests` directory. 86 | 87 | 7. When you're done making changes, check that your changes pass the formatting tests. 88 | 89 | ```bash 90 | make check 91 | ``` 92 | 93 | Now, validate that all unit tests are passing: 94 | 95 | ```bash 96 | make test 97 | ``` 98 | 99 | 8. Before raising a pull request you should also run tox. 100 | This will run the tests across different versions of Python: 101 | 102 | ```bash 103 | tox 104 | ``` 105 | 106 | This requires you to have multiple versions of Python installed. 107 | This step is also triggered in the CI/CD pipeline, so you could also choose to skip this step locally. 108 | 109 | 9. Commit your changes and push your branch to GitHub: 110 | 111 | ```bash 112 | git add . 113 | git commit -m "Your detailed description of your changes." 114 | git push origin name-of-your-bugfix-or-feature 115 | ``` 116 | 117 | 10. Submit a pull request through the GitHub website. 118 | 119 | # Pull Request Guidelines 120 | 121 | Before you submit a pull request, check that it meets these guidelines: 122 | 123 | 1. The pull request should include tests. 124 | 125 | 2. If the pull request adds functionality, the docs should be updated. 126 | Put your new functionality into a function with a docstring, and add the feature to the list in `README.md`. 127 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2024 Axel Donath 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this 7 | list of conditions and the following disclaimer. 8 | 9 | 2.
Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | 3. Neither the name of the copyright holder nor the names of its 14 | contributors may be used to endorse or promote products derived from 15 | this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 21 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 23 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 24 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 25 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: install 2 | install: ## Install the virtual environment and install the pre-commit hooks 3 | @echo "🚀 Creating virtual environment using uv" 4 | @uv sync 5 | @uv run pre-commit install 6 | 7 | .PHONY: check 8 | check: ## Run code quality tools. 9 | @echo "🚀 Checking lock file consistency with 'pyproject.toml'" 10 | @uv lock --locked 11 | @echo "🚀 Linting code: Running pre-commit" 12 | @uv run pre-commit run -a 13 | @echo "🚀 Static type checking: Running mypy" 14 | @uv run mypy 15 | 16 | .PHONY: test 17 | test: ## Test the code with pytest 18 | @echo "🚀 Testing code: Running pytest" 19 | @uv run python -m pytest --cov --cov-config=pyproject.toml --cov-report=xml 20 | 21 | .PHONY: build 22 | build: clean-build ## Build wheel file 23 | @echo "🚀 Creating wheel file" 24 | @uvx --from build pyproject-build --installer uv 25 | 26 | .PHONY: clean-build 27 | clean-build: ## Clean build artifacts 28 | @echo "🚀 Removing build artifacts" 29 | @uv run python -c "import shutil; import os; shutil.rmtree('dist') if os.path.exists('dist') else None" 30 | 31 | .PHONY: publish 32 | publish: ## Publish a release to PyPI. 33 | @echo "🚀 Publishing." 34 | @uvx twine upload --repository-url https://upload.pypi.org/legacy/ dist/* 35 | 36 | .PHONY: build-and-publish 37 | build-and-publish: build publish ## Build and publish. 
38 | 39 | .PHONY: docs-test 40 | docs-test: ## Test if documentation can be built without warnings or errors 41 | @uv run mkdocs build -s 42 | 43 | .PHONY: docs 44 | docs: ## Build and serve the documentation 45 | @uv run mkdocs serve 46 | 47 | .PHONY: help 48 | help: 49 | @uv run python -c "import re; \ 50 | [[print(f'\033[36m{m[0]:<20}\033[0m {m[1]}') for m in re.findall(r'^([a-zA-Z_-]+):.*?## (.*)$$', open(makefile).read(), re.M)] for makefile in ('$(MAKEFILE_LIST)').strip().split()]" 51 | 52 | .DEFAULT_GOAL := help 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GMMX: Gaussian Mixture Models in Jax 2 | 3 | [![Release](https://img.shields.io/github/v/release/adonath/gmmx)](https://img.shields.io/github/v/release/adonath/gmmx) 4 | [![Build status](https://img.shields.io/github/actions/workflow/status/adonath/gmmx/main.yml?branch=main)](https://github.com/adonath/gmmx/actions/workflows/main.yml?query=branch%3Amain) 5 | [![codecov](https://codecov.io/gh/adonath/gmmx/branch/main/graph/badge.svg)](https://codecov.io/gh/adonath/gmmx) 6 | [![Commit activity](https://img.shields.io/github/commit-activity/m/adonath/gmmx)](https://img.shields.io/github/commit-activity/m/adonath/gmmx) 7 | [![License](https://img.shields.io/github/license/adonath/gmmx)](https://img.shields.io/github/license/adonath/gmmx) 8 | [![DOI](https://zenodo.org/badge/879790145.svg)](https://doi.org/10.5281/zenodo.14515326) 9 | 10 |

11 | ![GMMX Logo](docs/_static/gmmx-logo.png) 12 |

13 | 14 | A minimal implementation of Gaussian Mixture Models in Jax 15 | 16 | - **Github repository**: <https://github.com/adonath/gmmx> 17 | - **Documentation** 18 | 19 | ## Installation 20 | 21 | `gmmx` can be installed via pip: 22 | 23 | ```bash 24 | pip install gmmx 25 | ``` 26 | 27 | Or alternatively you can use `conda/mamba`: 28 | 29 | ```bash 30 | conda install gmmx 31 | ``` 32 | 33 | ## Usage 34 | 35 | ```python 36 | import jax 37 | from gmmx import GaussianMixtureModelJax, EMFitter 38 | # Create a Gaussian Mixture Model with 16 components and 32 features 39 | gmm = GaussianMixtureModelJax.create(n_components=16, n_features=32) 40 | 41 | # Draw samples from the model 42 | key = jax.random.PRNGKey(0) 43 | x = gmm.sample(key, n_samples=10_000) 44 | 45 | # Fit the model to the data 46 | em_fitter = EMFitter(tol=1e-3, max_iter=100) 47 | gmm_fitted = em_fitter.fit(x=x, gmm=gmm) 48 | ```
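To inspect the result, the fitted model can be used to assign each sample to its most likely component. A short sketch, based on the API used in the bundled examples (the fit result exposes the updated model as `.gmm`, and `predict` returns the index of the most likely component for each sample):

```python
labels = gmm_fitted.gmm.predict(x=x)
```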
49 | 50 | If you use the code in a scientific publication, please cite the Zenodo DOI from the badge above. 51 | 52 | ## Why Gaussian Mixture Models? 53 | 54 | What are Gaussian Mixture Models (GMM) useful for in the age of deep learning? GMMs might have come out of fashion for classification tasks, but they still 55 | have a few properties that make them useful in certain scenarios: 56 | 57 | - They are universal approximators, meaning that given enough components they can approximate any distribution. 58 | - Their likelihood can be evaluated in closed form, which makes them useful for generative modeling. 59 | - They are rather fast to train and evaluate. 60 | 61 | I strongly recommend reading [In Depth: Gaussian Mixture Models](https://jakevdp.github.io/PythonDataScienceHandbook/05.12-gaussian-mixtures.html) from the Python Data Science Handbook for a more in-depth introduction to GMMs and their 62 | application as density estimators. 63 | 64 | One such application in my research is image reconstruction, where GMMs can be used to model the distribution and pixel correlations of local (patch-based) 65 | image features. This can be useful for tasks like image denoising or inpainting. One method I have used them for is [Jolideco](https://github.com/jolideco/jolideco). 66 | Speeding up the training on O(10^6) patches was the main motivation for `gmmx`. 67 | 68 | ## Benchmarks 69 | 70 | Here are some results from the benchmarks in the [examples/benchmarks](https://github.com/adonath/gmmx/tree/main/examples/benchmarks) folder comparing against Scikit-Learn. The benchmarks were run on an "Intel(R) Xeon(R) Gold 6338" CPU and a single "NVIDIA L40S" GPU. 71 | 72 | ### Prediction Time 73 | 74 | | Time vs. Number of Components | Time vs. Number of Samples | Time vs. Number of Features | 75 | | --- | --- | --- | 76 | | ![Time vs. Number of Components](https://raw.githubusercontent.com/adonath/gmmx/main/docs/_static/time-vs-n-components-predict.png) | ![Time vs. Number of Samples](https://raw.githubusercontent.com/adonath/gmmx/main/docs/_static/time-vs-n-samples-predict.png) | ![Time vs. Number of Features](https://raw.githubusercontent.com/adonath/gmmx/main/docs/_static/time-vs-n-features-predict.png) | 77 | 78 | For prediction the speedup is around 5-6x for varying numbers of components and features, and ~50x on the GPU. When varying the number of samples, the cross-over point is around O(10^3) samples. 79 | 80 | ### Training Time 81 | 82 | | Time vs. Number of Components | Time vs. Number of Samples | Time vs. Number of Features | 83 | | --- | --- | --- | 84 | | ![Time vs. Number of Components](https://raw.githubusercontent.com/adonath/gmmx/main/docs/_static/time-vs-n-components-fit.png) | ![Time vs. Number of Samples](https://raw.githubusercontent.com/adonath/gmmx/main/docs/_static/time-vs-n-samples-fit.png) | ![Time vs. Number of Features](https://raw.githubusercontent.com/adonath/gmmx/main/docs/_static/time-vs-n-features-fit.png) | 85 | 86 | For training the speedup is ~5-6x on the same architecture and ~50x on the GPU. In the benchmark I have forced both fitters to run exactly the same number of iterations. In general there is no guarantee that GMMX converges to the same solution as Scikit-Learn, but the tests in the `tests` folder compare the results of the two implementations and show good agreement. 87 | -------------------------------------------------------------------------------- /codecov.yaml: -------------------------------------------------------------------------------- 1 | coverage: 2 | range: 70..100 3 | round: down 4 | precision: 1 5 | status: 6 | project: 7 | default: 8 | target: 90% 9 | threshold: 0.5% 10 | -------------------------------------------------------------------------------- /docs/_static/gmmx-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adonath/gmmx/778f718fe7b302ff7b1e18ddb8f959a3dea5ddda/docs/_static/gmmx-logo.png -------------------------------------------------------------------------------- /docs/_static/time-vs-n-components-fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adonath/gmmx/778f718fe7b302ff7b1e18ddb8f959a3dea5ddda/docs/_static/time-vs-n-components-fit.png -------------------------------------------------------------------------------- /docs/_static/time-vs-n-components-predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adonath/gmmx/778f718fe7b302ff7b1e18ddb8f959a3dea5ddda/docs/_static/time-vs-n-components-predict.png -------------------------------------------------------------------------------- /docs/_static/time-vs-n-features-fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adonath/gmmx/778f718fe7b302ff7b1e18ddb8f959a3dea5ddda/docs/_static/time-vs-n-features-fit.png -------------------------------------------------------------------------------- /docs/_static/time-vs-n-features-predict.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/adonath/gmmx/778f718fe7b302ff7b1e18ddb8f959a3dea5ddda/docs/_static/time-vs-n-features-predict.png -------------------------------------------------------------------------------- /docs/_static/time-vs-n-samples-fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adonath/gmmx/778f718fe7b302ff7b1e18ddb8f959a3dea5ddda/docs/_static/time-vs-n-samples-fit.png -------------------------------------------------------------------------------- /docs/_static/time-vs-n-samples-predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adonath/gmmx/778f718fe7b302ff7b1e18ddb8f959a3dea5ddda/docs/_static/time-vs-n-samples-predict.png -------------------------------------------------------------------------------- /docs/development.md: -------------------------------------------------------------------------------- 1 | ## Getting started with your project 2 | 3 | ### 1. Create a New Repository 4 | 5 | First, create a repository on GitHub with the same name as this project, and then run the following commands: 6 | 7 | ```bash 8 | git init -b main 9 | git add . 10 | git commit -m "init commit" 11 | git remote add origin git@github.com:adonath/gmmx.git 12 | git push -u origin main 13 | ``` 14 | 15 | ### 2. Set Up Your Development Environment 16 | 17 | Then, install the environment and the pre-commit hooks with 18 | 19 | ```bash 20 | make install 21 | ``` 22 | 23 | This will also generate your `uv.lock` file 24 | 25 | ### 3. Run the pre-commit hooks 26 | 27 | Initially, the CI/CD pipeline might be failing due to formatting issues. To resolve those run: 28 | 29 | ```bash 30 | uv run pre-commit run -a 31 | ``` 32 | 33 | ### 4. Commit the changes 34 | 35 | Lastly, commit the changes made by the two steps above to your repository. 36 | 37 | ```bash 38 | git add . 39 | git commit -m 'Fix formatting issues' 40 | git push origin main 41 | ``` 42 | 43 | You are now ready to start development on your project! 44 | The CI/CD pipeline will be triggered when you open a pull request, merge to main, or when you create a new release. 45 | 46 | To finalize the set-up for publishing to PyPI, see [here](https://fpgmaas.github.io/cookiecutter-uv/features/publishing/#set-up-for-pypi). 47 | For activating the automatic documentation with MkDocs, see [here](https://fpgmaas.github.io/cookiecutter-uv/features/mkdocs/#enabling-the-documentation-on-github). 48 | To enable the code coverage reports, see [here](https://fpgmaas.github.io/cookiecutter-uv/features/codecov/). 49 | 50 | ## Releasing a new version 51 | 52 | - Create an API Token on [PyPI](https://pypi.org/). 53 | - Add the API Token to your projects secrets with the name `PYPI_TOKEN` by visiting [this page](https://github.com/adonath/gmmx/settings/secrets/actions/new). 54 | - Create a [new release](https://github.com/adonath/gmmx/releases/new) on Github. 55 | - Create a new tag in the form `*.*.*`. 56 | 57 | For more details, see [here](https://fpgmaas.github.io/cookiecutter-uv/features/cicd/#how-to-trigger-a-release). 58 | 59 | --- 60 | 61 | Repository initiated with [fpgmaas/cookiecutter-uv](https://github.com/fpgmaas/cookiecutter-uv). 
62 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # GMMX: Gaussian Mixture Models in Jax 2 | 3 | [![Release](https://img.shields.io/github/v/release/adonath/gmmx)](https://img.shields.io/github/v/release/adonath/gmmx) 4 | [![Build status](https://img.shields.io/github/actions/workflow/status/adonath/gmmx/main.yml?branch=main)](https://github.com/adonath/gmmx/actions/workflows/main.yml?query=branch%3Amain) 5 | [![Commit activity](https://img.shields.io/github/commit-activity/m/adonath/gmmx)](https://img.shields.io/github/commit-activity/m/adonath/gmmx) 6 | [![License](https://img.shields.io/github/license/adonath/gmmx)](https://img.shields.io/github/license/adonath/gmmx) 7 | 8 | A minimal implementation of Gaussian Mixture Models in Jax. 9 | -------------------------------------------------------------------------------- /docs/modules.md: -------------------------------------------------------------------------------- 1 | ::: gmmx.gmm 2 | -------------------------------------------------------------------------------- /examples/benchmarks/run-benchmark.py: -------------------------------------------------------------------------------- 1 | import getpass 2 | import importlib 3 | import json 4 | import logging 5 | import platform 6 | import sys 7 | import timeit 8 | from dataclasses import asdict, dataclass, field 9 | from functools import partial 10 | from itertools import product 11 | from pathlib import Path 12 | from typing import Optional 13 | 14 | import jax 15 | import matplotlib.pyplot as plt 16 | import numpy as np 17 | from jax.lib import xla_bridge 18 | from scipy import stats 19 | 20 | from gmmx import EMFitter, GaussianMixtureModelJax 21 | 22 | jax.config.update("jax_enable_x64", True) 23 | 24 | 25 | PATH = Path(__file__).parent 26 | PATH_RESULTS = PATH / "results" 27 | RANDOM_STATE = np.random.RandomState(81737) 28 | N_AVERAGE = 10 29 | DPI = 180 30 | KEY = jax.random.PRNGKey(81737) 31 | 32 | MAX_ITER = 20 33 | TOL = 0 34 | 35 | PATH_TEMPLATE = "{user}-{machine}-{system}-{cpu}-{device-platform}" 36 | 37 | 38 | logging.basicConfig(level=logging.INFO) 39 | log = logging.getLogger(__name__) 40 | 41 | 42 | Array = list[float] 43 | 44 | 45 | def get_provenance(): 46 | """Compute provenance info about software and data used.""" 47 | env = { 48 | "user": getpass.getuser(), 49 | "machine": platform.machine(), 50 | "system": platform.system(), 51 | "cpu": platform.processor(), 52 | "device-platform": xla_bridge.get_backend().platform, 53 | } 54 | 55 | software = { 56 | "python-executable": sys.executable, 57 | "python-version": platform.python_version(), 58 | "jax-version": str(importlib.import_module("jax").__version__), 59 | "numpy-version": str(importlib.import_module("numpy").__version__), 60 | "sklearn-version": str(importlib.import_module("sklearn").__version__), 61 | } 62 | 63 | return { 64 | "env": env, 65 | "software": software, 66 | } 67 | 68 | 69 | def gpu_is_available(): 70 | """Check if a GPU is available""" 71 | try: 72 | jax.devices("gpu") 73 | except RuntimeError: 74 | return False 75 | 76 | return True 77 | 78 | 79 | @dataclass 80 | class BenchmarkResult: 81 | """Benchmark result""" 82 | 83 | n_samples: Array 84 | n_components: Array 85 | n_features: Array 86 | time_sklearn: Array 87 | time_jax: Array 88 | time_jax_gpu: Optional[Array] = None 89 | provenance: dict = field(default_factory=get_provenance) 90 | 91 | def write_json(self, 
path): 92 | """Write the benchmark result to a JSON file""" 93 | path = Path(path) 94 | path.parent.mkdir(parents=True, exist_ok=True) 95 | 96 | with path.open("w") as f: 97 | json.dump(asdict(self), f) 98 | 99 | @classmethod 100 | def read(cls, path): 101 | """Read the benchmark result from a JSON file""" 102 | with Path(path).open("r") as f: 103 | data = json.load(f) 104 | return cls(**data) 105 | 106 | 107 | @dataclass 108 | class BenchmarkSpec: 109 | """Benchmark specification""" 110 | 111 | filename_result: str 112 | n_components_grid: Array 113 | n_samples_grid: Array 114 | n_features_grid: Array 115 | x_axis: str 116 | title: str 117 | func_sklearn: Optional[callable] = None 118 | func_jax: Optional[callable] = None 119 | 120 | @property 121 | def path(self): 122 | """Absolute path to the benchmark result""" 123 | return ( 124 | PATH_RESULTS 125 | / PATH_TEMPLATE.format(**get_provenance()["env"]) 126 | / self.filename_result 127 | ) 128 | 129 | 130 | def predict_sklearn(gmm, x): 131 | """Predict the responsibilities""" 132 | return partial(gmm.predict, X=x) 133 | 134 | 135 | def predict_jax(gmm, x): 136 | """Measure the time to predict the responsibilities""" 137 | 138 | def func(): 139 | return gmm.predict(x=x).block_until_ready() 140 | 141 | return func 142 | 143 | 144 | def fit_sklearn(gmm, x): 145 | """Measure the time to fit the model""" 146 | 147 | def func(): 148 | result = gmm.fit(x) 149 | return result 150 | 151 | return func 152 | 153 | 154 | def fit_jax(gmm, x): 155 | """Measure the time to fit the model""" 156 | 157 | def func(): 158 | fitter = EMFitter(tol=TOL, max_iter=MAX_ITER) 159 | result = fitter.fit(x=x, gmm=gmm) 160 | return result 161 | 162 | return func 163 | 164 | 165 | SPECS_PREDICT = [ 166 | BenchmarkSpec( 167 | filename_result="time-vs-n-components-predict.json", 168 | n_components_grid=(2 ** np.arange(1, 7)).tolist(), 169 | n_samples_grid=[100_000], 170 | n_features_grid=[64], 171 | x_axis="n_components", 172 | title="Time vs Number of components", 173 | func_sklearn=predict_sklearn, 174 | func_jax=predict_jax, 175 | ), 176 | BenchmarkSpec( 177 | filename_result="time-vs-n-features-predict.json", 178 | n_components_grid=[128], 179 | n_samples_grid=[100_000], 180 | n_features_grid=(2 ** np.arange(1, 7)).tolist(), 181 | x_axis="n_features", 182 | title="Time vs Number of features", 183 | func_sklearn=predict_sklearn, 184 | func_jax=predict_jax, 185 | ), 186 | BenchmarkSpec( 187 | filename_result="time-vs-n-samples-predict.json", 188 | n_components_grid=[128], 189 | n_samples_grid=(2 ** np.arange(5, 18)).tolist(), 190 | n_features_grid=[64], 191 | x_axis="n_samples", 192 | title="Time vs Number of samples", 193 | func_sklearn=predict_sklearn, 194 | func_jax=predict_jax, 195 | ), 196 | ] 197 | 198 | SPECS_FIT = [ 199 | BenchmarkSpec( 200 | filename_result="time-vs-n-components-fit.json", 201 | n_components_grid=(2 ** np.arange(1, 6)).tolist(), 202 | n_samples_grid=[65536], 203 | n_features_grid=[32], 204 | x_axis="n_components", 205 | title="Time vs Number of components", 206 | func_sklearn=fit_sklearn, 207 | func_jax=fit_jax, 208 | ), 209 | BenchmarkSpec( 210 | filename_result="time-vs-n-features-fit.json", 211 | n_components_grid=[64], 212 | n_samples_grid=[65536], 213 | n_features_grid=(2 ** np.arange(1, 6)).tolist(), 214 | x_axis="n_features", 215 | title="Time vs Number of features", 216 | func_sklearn=fit_sklearn, 217 | func_jax=fit_jax, 218 | ), 219 | BenchmarkSpec( 220 | filename_result="time-vs-n-samples-fit.json", 221 | n_components_grid=[64], 222 | 
n_samples_grid=(2 ** np.arange(8, 17)).tolist(), 223 | n_features_grid=[32], 224 | x_axis="n_samples", 225 | title="Time vs Number of samples", 226 | func_sklearn=fit_sklearn, 227 | func_jax=fit_jax, 228 | ), 229 | ] 230 | 231 | 232 | def create_random_gmm(n_components, n_features, random_state=RANDOM_STATE, device=None): 233 | """Create a random Gaussian mixture model""" 234 | means = random_state.uniform(-10, 10, (n_components, n_features)) 235 | 236 | # the Wishart distribution creates a positive semi-definite matrix 237 | covariances = stats.wishart.rvs( 238 | df=n_features, 239 | scale=np.eye(n_features) / n_features, 240 | size=n_components, 241 | random_state=random_state, 242 | ) 243 | 244 | weights = random_state.uniform(0, 1, n_components) 245 | weights /= weights.sum() 246 | 247 | return GaussianMixtureModelJax.from_squeezed( 248 | means=jax.device_put(means, device=device), 249 | covariances=jax.device_put(covariances, device=device), 250 | weights=jax.device_put(weights, device=device), 251 | ) 252 | 253 | 254 | def create_random_data(n_samples, n_features, random_state=RANDOM_STATE, device=None): 255 | """Create random data""" 256 | return random_state.uniform(-10, 10, (n_samples, n_features)) 257 | 258 | 259 | def get_meta_str(result, x_axis): 260 | """Get the metadata for the benchmark result""" 261 | if x_axis == "n_samples": 262 | meta = {"n_components": result.n_components, "n_features": result.n_features} 263 | 264 | elif x_axis == "n_components": 265 | meta = {"n_samples": result.n_samples, "n_features": result.n_features} 266 | 267 | elif x_axis == "n_features": 268 | meta = {"n_samples": result.n_samples, "n_components": result.n_components} 269 | else: 270 | message = f"Invalid x_axis: {x_axis}" 271 | raise ValueError(message) 272 | 273 | return ", ".join(f"{k}={v[0]}" for k, v in meta.items()) 274 | 275 | 276 | def plot_result(result, x_axis, filename, title=""): 277 | """Plot the benchmark result""" 278 | log.info(f"Plotting {filename}") 279 | fig, ax = plt.subplots(1, 1, figsize=(6, 4)) 280 | 281 | x = getattr(result, x_axis) 282 | meta = get_meta_str(result, x_axis) 283 | 284 | color = "#F1C44D" 285 | ax.plot(x, result.time_sklearn, label=f"sklearn ({meta})", color=color) 286 | ax.scatter(x, result.time_sklearn, color=color) 287 | 288 | color = "#405087" 289 | ax.plot(x, result.time_jax, label=f"jax ({meta})", color=color, zorder=3) 290 | ax.scatter(x, result.time_jax, color=color, zorder=3) 291 | 292 | if result.time_jax_gpu: 293 | color = "#E58336" 294 | ax.plot( 295 | x, result.time_jax_gpu, label=f"jax-gpu ({meta})", color=color, zorder=5 296 | ) 297 | ax.scatter(x, result.time_jax_gpu, color=color, zorder=5) 298 | 299 | ax.set_title(title) 300 | ax.set_xlabel(x_axis) 301 | ax.set_ylabel("Time (s)") 302 | ax.semilogx() 303 | ax.semilogy() 304 | ax.legend() 305 | 306 | log.info(f"Writing {filename}") 307 | plt.savefig(filename, dpi=DPI) 308 | 309 | 310 | def measure_time(func): 311 | """Measure the time to run a function""" 312 | timer = timeit.Timer(func)
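# Note: Timer.timeit(N_AVERAGE) returns the total elapsed time for all
# N_AVERAGE runs, not the per-run mean; both implementations are measured
# the same way, so the relative comparison is unaffected.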
313 | return timer.timeit(N_AVERAGE) 314 | 315 | 316 | def measure_time_sklearn_vs_jax( 317 | n_components_grid, 318 | n_samples_grid, 319 | n_features_grid, 320 | init_func_sklearn=predict_sklearn, 321 | init_func_jax=predict_jax, 322 | ): 323 | """Measure the run times for sklearn and jax""" 324 | time_sklearn, time_jax, time_jax_gpu = [], [], [] 325 | 326 | for n_component, n_samples, n_features in product( 327 | n_components_grid, n_samples_grid, n_features_grid 328 | ): 329 | log.info( 330 | f"Running n_components={n_component}, n_samples={n_samples}, n_features={n_features}" 331 | ) 332 | gmm = create_random_gmm(n_component, n_features, device=jax.devices("cpu")[0]) 333 | x, _ = gmm.to_sklearn(random_state=RANDOM_STATE).sample(n_samples) 334 | 335 | func_sklearn = init_func_sklearn(gmm.to_sklearn(tol=TOL, max_iter=MAX_ITER), x) 336 | func_jax = init_func_jax(gmm, jax.device_put(x, device=jax.devices("cpu")[0])) 337 | 338 | time_sklearn.append(measure_time(func_sklearn)) 339 | time_jax.append(measure_time(func_jax)) 340 | 341 | if gpu_is_available(): 342 | gmm_gpu = create_random_gmm( 343 | n_component, n_features, device=jax.devices("gpu")[0] 344 | ) 345 | x_gpu = jax.device_put(x, device=jax.devices("gpu")[0]) 346 | func_jax = init_func_jax(gmm_gpu, x_gpu) 347 | time_jax_gpu.append(measure_time(func_jax)) 348 | 349 | return BenchmarkResult( 350 | n_samples=n_samples_grid, 351 | n_components=n_components_grid, 352 | n_features=n_features_grid, 353 | time_sklearn=time_sklearn, 354 | time_jax=time_jax, 355 | time_jax_gpu=time_jax_gpu or None, 356 | ) 357 | 358 | 359 | def run_benchmark_from_spec(spec): 360 | """Run a benchmark from a specification""" 361 | if not spec.path.exists(): 362 | result = measure_time_sklearn_vs_jax( 363 | spec.n_components_grid, 364 | spec.n_samples_grid, 365 | spec.n_features_grid, 366 | init_func_sklearn=spec.func_sklearn, 367 | init_func_jax=spec.func_jax, 368 | ) 369 | result.write_json(spec.path) 370 | 371 | result = BenchmarkResult.read(spec.path) 372 | plot_result( 373 | result, 374 | x_axis=spec.x_axis, 375 | filename=spec.path.with_suffix(".png"), 376 | title=spec.title, 377 | ) 378 | 379 | 380 | if __name__ == "__main__": 381 | for spec in SPECS_PREDICT: 382 | run_benchmark_from_spec(spec) 383 | 384 | for spec in SPECS_FIT: 385 | run_benchmark_from_spec(spec) 386 | -------------------------------------------------------------------------------- /examples/moon-density-estimation.py: -------------------------------------------------------------------------------- 1 | # Density Estimation of Moon Data. This example is adapted from the "In Depth: Gaussian Mixture Models" chapter of 2 | # the Python Data Science Handbook by Jake VanderPlas.
The original code can be found 3 | # at https://jakevdp.github.io/PythonDataScienceHandbook/05.12-gaussian-mixtures.html 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from matplotlib.patches import Ellipse 8 | from sklearn.datasets import make_moons 9 | 10 | from gmmx import EMFitter, GaussianMixtureModelJax 11 | 12 | 13 | def draw_ellipse(position, covariance, ax=None, **kwargs): 14 | """Draw an ellipse with a given position and covariance""" 15 | ax = ax or plt.gca() 16 | 17 | # Convert covariance to principal axes 18 | if covariance.shape == (2, 2): 19 | U, s, Vt = np.linalg.svd(covariance) 20 | angle = np.degrees(np.arctan2(U[1, 0], U[0, 0])) 21 | width, height = 2 * np.sqrt(s) 22 | else: 23 | angle = 0 24 | width, height = 2 * np.sqrt(covariance) 25 | 26 | # Draw the Ellipse 27 | for nsig in range(1, 4): 28 | ax.add_patch( 29 | Ellipse( 30 | xy=position, 31 | width=nsig * width, 32 | height=nsig * height, 33 | angle=angle, 34 | **kwargs, 35 | ) 36 | ) 37 | 38 | 39 | def plot_gmm(gmm, X, label=True, ax=None): 40 | """Plot the GMM""" 41 | ax = ax or plt.gca() 42 | 43 | labels = gmm.predict(X) 44 | 45 | if label: 46 | ax.scatter(X[:, 0], X[:, 1], c=labels, s=10, cmap="viridis", zorder=2) 47 | else: 48 | ax.scatter(X[:, 0], X[:, 1], s=10, zorder=2) 49 | ax.axis("equal") 50 | 51 | w_factor = 0.2 / gmm.weights_numpy.max() 52 | for pos, covar, w in zip( 53 | gmm.means_numpy, gmm.covariances.values_numpy, gmm.weights_numpy 54 | ): 55 | draw_ellipse(pos, covar, alpha=w * w_factor, ax=ax) 56 | 57 | 58 | def fit_and_plot_gmm(n_components, ax=None): 59 | """Fit and plot a GMM""" 60 | ax = ax or plt.gca() 61 | x, y = make_moons(200, noise=0.05, random_state=0) 62 | ax.scatter(x[:, 0], x[:, 1]) 63 | ax.text( 64 | 0.95, 65 | 0.9, 66 | f"N Components: {n_components}", 67 | ha="right", 68 | va="bottom", 69 | transform=ax.transAxes, 70 | ) 71 | ax.set_xticks([]) 72 | ax.set_yticks([]) 73 | 74 | gmm = GaussianMixtureModelJax.from_k_means(x, n_components=n_components) 75 | 76 | fitter = EMFitter(tol=1e-4, max_iter=100) 77 | result = fitter.fit(x=x, gmm=gmm) 78 | 79 | plot_gmm(result.gmm, x, ax=ax) 80 | return ax 81 | 82 | 83 | if __name__ == "__main__": 84 | fig, axes = plt.subplots(4, 4, figsize=(9, 9)) 85 | 86 | for idx, ax in enumerate(axes.flat): 87 | ax = fit_and_plot_gmm(idx + 1, ax=ax) 88 | 89 | plt.tight_layout() 90 | plt.show() 91 | -------------------------------------------------------------------------------- /examples/simple-1d-fit.py: -------------------------------------------------------------------------------- 1 | # A simple example of fitting a Gaussian Mixture Model (GMM) to a 1D dataset. 
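# It draws samples from a known two-component mixture and compares the gmmx EM fit
# against scikit-learn's GaussianMixture started from the same initial parameters.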
2 | 3 | import jax 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from scipy.stats import norm 7 | from sklearn.mixture import GaussianMixture 8 | 9 | from gmmx import EMFitter, GaussianMixtureModelJax 10 | 11 | 12 | def plot_gmm(ax, weights, means, covariances, **kwargs): 13 | """Plot a GMM""" 14 | x = np.linspace(-5, 15, 1000) 15 | y = np.zeros_like(x) 16 | 17 | for w, m, c in zip(weights, means, covariances): 18 | y += w * norm.pdf(x, m, np.sqrt(c)) 19 | 20 | ax.plot(x, y, **kwargs) 21 | 22 | 23 | def fit_and_plot_jax(ax, x, **kwargs): 24 | """Fit and plot Jax GMM""" 25 | gmm = GaussianMixtureModelJax.from_squeezed( 26 | means=np.array([[1.0], [8.0]]), 27 | covariances=np.array([[[1.0]], [[1.0]]]), 28 | weights=np.array([0.5, 0.5]), 29 | ) 30 | 31 | fitter = EMFitter(max_iter=100, tol=0.1) 32 | result = fitter.fit(x=x, gmm=gmm) 33 | 34 | plot_gmm( 35 | ax=ax, 36 | weights=result.gmm.weights.flatten(), 37 | means=result.gmm.means.flatten(), 38 | covariances=result.gmm.covariances.values.flatten(), 39 | **kwargs, 40 | ) 41 | 42 | 43 | def fit_and_plot_sklearn(ax, x, **kwargs): 44 | """Fit and plot sklearn GMM""" 45 | gmm_sk = GaussianMixture( 46 | n_components=2, 47 | max_iter=100, 48 | tol=0.1, 49 | means_init=np.array([[1.0], [8.0]]), 50 | weights_init=np.array([0.5, 0.5]), 51 | precisions_init=np.array([[[1.0]], [[1.0]]]), 52 | ) 53 | gmm_sk.fit(x) 54 | 55 | plot_gmm( 56 | ax=ax, 57 | weights=gmm_sk.weights_, 58 | means=gmm_sk.means_.flatten(), 59 | covariances=gmm_sk.covariances_.flatten(), 60 | **kwargs, 61 | ) 62 | 63 | 64 | if __name__ == "__main__": 65 | gmm_jax = GaussianMixtureModelJax.from_squeezed( 66 | means=np.array([[0], [10]]), 67 | covariances=np.array([[[2]], [[1]]]), 68 | weights=np.array([0.2, 0.8]), 69 | ) 70 | 71 | n_samples = 100_000 72 | 73 | key = jax.random.PRNGKey(0) 74 | x = gmm_jax.sample(key, n_samples=n_samples) 75 | 76 | ax = plt.subplot(111) 77 | ax.hist(x, bins=100, density=True) 78 | fit_and_plot_jax(ax=ax, x=x, label="gmmx", ls="-") 79 | fit_and_plot_sklearn(ax=ax, x=x, label="sklearn", ls="--") 80 | ax.set_ylabel("PDF (normalized)") 81 | ax.set_xlabel("x") 82 | 83 | plt.legend() 84 | plt.show() 85 | -------------------------------------------------------------------------------- /gmmx/__init__.py: -------------------------------------------------------------------------------- 1 | from .fit import EMFitter as EMFitter 2 | from .fit import EMFitterResult as EMFitterResult 3 | from .gmm import GaussianMixtureModelJax as GaussianMixtureModelJax 4 | from .gmm import GaussianMixtureSKLearn as GaussianMixtureSKLearn 5 | -------------------------------------------------------------------------------- /gmmx/fit.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import jax 4 | from jax import numpy as jnp 5 | 6 | from .gmm import Axis, GaussianMixtureModelJax 7 | from .utils import register_dataclass_jax 8 | 9 | __all__ = ["EMFitter", "EMFitterResult"] 10 | 11 | 12 | @register_dataclass_jax( 13 | data_fields=[ 14 | "x", 15 | "gmm", 16 | "n_iter", 17 | "log_likelihood", 18 | "log_likelihood_diff", 19 | "converged", 20 | ] 21 | ) 22 | @dataclass 23 | class EMFitterResult: 24 | """Expectation-Maximization Fitter Result 25 | 26 | Attributes 27 | ---------- 28 | x : jax.array 29 | Feature vectors 30 | gmm : GaussianMixtureModelJax 31 | Gaussian mixture model instance. 
32 | n_iter : int 33 | Number of iterations 34 | log_likelihood : jax.array 35 | Log-likelihood of the data 36 | log_likelihood_diff : jax.array 37 | Difference in log-likelihood with respect to the previous iteration 38 | converged : bool 39 | Whether the algorithm converged 40 | """ 41 | 42 | x: jax.Array 43 | gmm: GaussianMixtureModelJax 44 | n_iter: int 45 | log_likelihood: jax.Array 46 | log_likelihood_diff: jax.Array 47 | converged: bool 48 | 49 | 50 | @register_dataclass_jax(meta_fields=["max_iter", "tol", "reg_covar"]) 51 | @dataclass 52 | class EMFitter: 53 | """Expectation-Maximization Fitter 54 | 55 | Attributes 56 | ---------- 57 | max_iter : int 58 | Maximum number of iterations 59 | tol : float 60 | Tolerance on the absolute change of the mean log-likelihood between iterations 61 | reg_covar : float 62 | Regularization added to the diagonal of the covariance matrices 63 | """ 64 | 65 | max_iter: int = 100 66 | tol: float = 1e-3 67 | reg_covar: float = 1e-6 68 | 69 | def e_step( 70 | self, x: jax.Array, gmm: GaussianMixtureModelJax 71 | ) -> tuple[jax.Array, jax.Array]: 72 | """Expectation step 73 | 74 | Parameters 75 | ---------- 76 | x : jax.array 77 | Feature vectors 78 | gmm : GaussianMixtureModelJax 79 | Gaussian mixture model instance. 80 | 81 | Returns 82 | ------- 83 | log_likelihood : jax.array 84 | Mean log-likelihood of the data log_resp : jax.array Logarithm of the responsibilities 85 | """ 86 | log_prob = gmm.log_prob(x) 87 | log_prob_norm = jax.scipy.special.logsumexp( 88 | log_prob, axis=Axis.components, keepdims=True 89 | ) 90 | log_resp = log_prob - log_prob_norm 91 | return jnp.mean(log_prob_norm), log_resp 92 | 93 | def m_step( 94 | self, x: jax.Array, gmm: GaussianMixtureModelJax, log_resp: jax.Array 95 | ) -> GaussianMixtureModelJax: 96 | """Maximization step 97 | 98 | Parameters 99 | ---------- 100 | x : jax.array 101 | Feature vectors 102 | gmm : GaussianMixtureModelJax 103 | Gaussian mixture model instance. 104 | log_resp : jax.array 105 | Logarithm of the responsibilities 106 | 107 | Returns 108 | ------- 109 | gmm : GaussianMixtureModelJax 110 | Updated Gaussian mixture model instance. 111 | """ 112 | x = jnp.expand_dims(x, axis=(Axis.components, Axis.features_covar)) 113 | return gmm.from_responsibilities( 114 | x, 115 | jnp.exp(log_resp), 116 | reg_covar=self.reg_covar, 117 | covariance_type=gmm.covariances.type, 118 | ) 119 |
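# For reference, a sketch of the standard EM updates the M-step delegates to
# `GaussianMixtureModelJax.from_responsibilities` (with responsibilities r_nk
# and N_k = sum_n r_nk):
#   w_k = N_k / n_samples
#   mu_k = (sum_n r_nk * x_n) / N_k
#   Sigma_k = (sum_n r_nk * (x_n - mu_k)(x_n - mu_k)^T) / N_k + reg_covar * I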
120 | @jax.jit 121 | def fit(self, x: jax.Array, gmm: GaussianMixtureModelJax) -> EMFitterResult: 122 | """Fit the model to the data 123 | 124 | Parameters 125 | ---------- 126 | x : jax.array 127 | Feature vectors 128 | gmm : GaussianMixtureModelJax 129 | Gaussian mixture model instance. 130 | 131 | Returns 132 | ------- 133 | result : EMFitterResult 134 | Fitting result 135 | """ 136 | 137 | def em_step( 138 | args: tuple[jax.Array, GaussianMixtureModelJax, int, jax.Array, jax.Array], 139 | ) -> tuple: 140 | """EM step function""" 141 | x, gmm, n_iter, log_likelihood_prev, _ = args 142 | log_likelihood, log_resp = self.e_step(x, gmm) 143 | gmm = self.m_step(x, gmm, log_resp) 144 | return ( 145 | x, 146 | gmm, 147 | n_iter + 1, 148 | log_likelihood, 149 | jnp.abs(log_likelihood - log_likelihood_prev), 150 | ) 151 | 152 | def em_cond( 153 | args: tuple[jax.Array, GaussianMixtureModelJax, int, jax.Array, jax.Array], 154 | ) -> jax.Array: 155 | """EM stop condition function""" 156 | _, _, n_iter, _, log_likelihood_diff = args 157 | return (n_iter < self.max_iter) & (log_likelihood_diff >= self.tol) 158 | 159 | result = jax.lax.while_loop( 160 | cond_fun=em_cond, 161 | body_fun=em_step, 162 | init_val=(x, gmm, 0, jnp.asarray(jnp.inf), jnp.array(jnp.inf)), 163 | ) 164 | return EMFitterResult(*result, converged=result[2] < self.max_iter) 165 | -------------------------------------------------------------------------------- /gmmx/gmm.py: -------------------------------------------------------------------------------- 1 | """Some notes on the implementation: 2 | 3 | I have not tried to keep the implementation close to the sklearn implementation. 4 | I have rather tried to realize my own best practices for code structure and 5 | clarity. Here are some more detailed thoughts: 6 | 7 | 1. **Use dataclasses for the model representation**: this reduces the amount of 8 | boilerplate code for initialization and in combination with the `register_dataclass_jax` 9 | decorator it integrates seamlessly with JAX. 10 | 11 | 2. **Split up the different covariance types into different classes**: this avoids 12 | the need for multiple blocks of if-else statements. 13 | 14 | 3. **Use a registry for the covariance types**: This allows for easy extensibility 15 | by the user. 16 | 17 | 4. **Remove Python loops**: I have not checked the reason why the sklearn implementation 18 | still uses Python loops, but my guess is that it is simpler(?) and when there are 19 | operations such as matmul and Cholesky decomposition, the Python loop does not become 20 | the bottleneck. In JAX, however, it is usually better to avoid Python loops and let 21 | the JAX compiler take care of the optimization instead. 22 | 23 | 5. **Rely on the same internal array dimension and axis order**: 24 | Internally all(!) involved arrays (even 1d weights) are represented as 4d arrays 25 | with the axes (batch, components, features, features_covar). This makes it much 26 | easier to write array operations and rely on broadcasting. This minimizes the 27 | amount of in-line reshaping and in-line extension of dimensions. If you think 28 | about it, this is most likely the way array programming was meant to be used 29 | in the first place. Yet, I have rarely seen this in practice, probably because people 30 | struggle with the additional dimensions in the beginning. However, once you get 31 | used to it, it is much easier to write and understand the code! The only downside 32 | is that the user has to face the additional "empty" dimensions when directly working 33 | with the arrays. For convenience I have introduced properties that return the arrays 34 | with the empty dimensions removed. Another downside may be that you have to use `keepdims=True` 35 | more often, but there I would even argue that the default behavior in the array libraries 36 | should change. 37 |
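For illustration, the mapping between the squeezed user-facing shapes and the
internal 4d layout looks roughly as follows (a sketch, inferred from the shape
checks and `from_squeezed` methods below):

    weights:            (n_components,)                        -> (1, n_components, 1, 1)
    means:              (n_components, n_features)              -> (1, n_components, n_features, 1)
    covariances (full): (n_components, n_features, n_features)  -> (1, n_components, n_features, n_features)
    covariances (diag): (n_components, n_features)              -> (1, n_components, n_features, 1)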
38 | 6. **"Poor-people's" named axes**: The axis order convention is defined in the 39 | code in the `Axis` enum, which maps the name to the integer dimension. Later I 40 | can use, e.g. `Axis.batch` to refer to the batch axis in the code. This is the 41 | simplest way to come close to named axes in any array library! So you can use 42 | e.g. `jnp.sum(x, axis=Axis.components)` to sum over the components axis. I found 43 | this to be a very powerful concept that improves the code clarity a lot, yet I 44 | have not seen it often in other libraries. Of course there is `einops` but the 45 | simple enum works just fine in many cases! 46 | 47 | """ 48 | 49 | from __future__ import annotations 50 | 51 | from dataclasses import dataclass, field 52 | from enum import Enum 53 | from functools import partial 54 | from typing import Any, ClassVar, Union 55 | 56 | import jax 57 | import numpy as np 58 | from jax import numpy as jnp 59 | from jax import scipy as jsp 60 | 61 | from gmmx.utils import register_dataclass_jax 62 | 63 | __all__ = [ 64 | "Axis", 65 | "CovarianceType", 66 | "DiagCovariances", 67 | "FullCovariances", 68 | "GaussianMixtureModelJax", 69 | "GaussianMixtureSKLearn", 70 | ] 71 | 72 | 73 | AnyArray = Union[np.typing.NDArray, jax.Array] 74 | Device = Union[str, None] 75 | 76 | 77 | class CovarianceType(str, Enum): 78 | """Covariance type""" 79 | 80 | full = "full" 81 | diag = "diag" 82 | 83 | 84 | class Axis(int, Enum): 85 | """Internal axis order""" 86 | 87 | batch = 0 88 | components = 1 89 | features = 2 90 | features_covar = 3 91 | 92 | 93 | def check_shape(array: jax.Array, expected: tuple[int | None, ...]) -> None: 94 | """Check shape of array""" 95 | if array.dtype != jnp.float32: 96 | message = f"Expected float32, got {array.dtype}" 97 | raise ValueError(message) 98 | 99 | if len(array.shape) != len(expected): 100 | message = f"Expected shape {expected}, got {array.shape}" 101 | raise ValueError(message) 102 | 103 | for n, m in zip(array.shape, expected): 104 | if m is not None and n != m: 105 | message = f"Expected shape {expected}, got {array.shape}" 106 | raise ValueError(message) 107 | 108 | 109 | @register_dataclass_jax(data_fields=["values"]) 110 | @dataclass 111 | class FullCovariances: 112 | """Full covariance matrix 113 | 114 | Attributes 115 | ---------- 116 | values : jax.array 117 | Covariance values. Expected shape is (1, n_components, n_features, n_features) 118 | """ 119 | 120 | values: jax.Array 121 | type: ClassVar[CovarianceType] = CovarianceType.full 122 | 123 | def __post_init__(self) -> None: 124 | check_shape(self.values, (1, None, None, None)) 125 | 126 | @classmethod 127 | def from_squeezed(cls, values: AnyArray) -> FullCovariances: 128 | """Create a covariance matrix from squeezed array 129 | 130 | Parameters 131 | ---------- 132 | values : jax.Array or np.array 133 | Covariance values. Expected shape is (n_components, n_features, n_features) 134 | 135 | Returns 136 | ------- 137 | covariances : FullCovariances 138 | Covariance matrix instance.
139 | """ 140 | if values.ndim != 3: 141 | message = f"Expected array of shape (n_components, n_features, n_features), got {values.shape}" 142 | raise ValueError(message) 143 | 144 | return cls(values=jnp.expand_dims(values, axis=Axis.batch)) 145 | 146 | @property 147 | def values_numpy(self) -> np.typing.NDArray: 148 | """Covariance as numpy array""" 149 | return np.squeeze(np.asarray(self.values), axis=Axis.batch) 150 | 151 | @property 152 | def values_dense(self) -> jax.Array: 153 | """Covariance as dense matrix""" 154 | return self.values 155 | 156 | @property 157 | def precisions_cholesky_numpy(self) -> np.typing.NDArray: 158 | """Compute precision matrices""" 159 | return np.squeeze(np.asarray(self.precisions_cholesky), axis=Axis.batch) 160 | 161 | @classmethod 162 | def create( 163 | cls, n_components: int, n_features: int, device: Device = None 164 | ) -> FullCovariances: 165 | """Create covariance matrix 166 | 167 | By default the covariance matrix is set to the identity matrix. 168 | 169 | Parameters 170 | ---------- 171 | n_components : int 172 | Number of components 173 | n_features : int 174 | Number of features 175 | device : str, optional 176 | Device, by default None 177 | 178 | Returns 179 | ------- 180 | covariances : FullCovariances 181 | Covariance matrix instance. 182 | """ 183 | identity = jnp.expand_dims( 184 | jnp.eye(n_features), axis=(Axis.batch, Axis.components) 185 | ) 186 | 187 | values = jnp.repeat(identity, n_components, axis=Axis.components) 188 | values = jax.device_put(values, device=device) 189 | return cls(values=values) 190 | 191 | def log_prob(self, x: jax.Array, means: jax.Array) -> jax.Array: 192 | """Compute log likelihood from the covariance for a given feature vector 193 | 194 | Parameters 195 | ---------- 196 | x : jax.array 197 | Feature vectors 198 | means : jax.array 199 | Means of the components 200 | 201 | Returns 202 | ------- 203 | log_prob : jax.array 204 | Log likelihood 205 | """ 206 | precisions_cholesky = self.precisions_cholesky 207 | 208 | y = jnp.matmul(x.mT, precisions_cholesky) - jnp.matmul( 209 | means.mT, precisions_cholesky 210 | ) 211 | return jnp.sum( 212 | jnp.square(y), 213 | axis=(Axis.features, Axis.features_covar), 214 | keepdims=True, 215 | ) 216 | 217 | @classmethod 218 | def from_responsibilities( 219 | cls, 220 | x: jax.Array, 221 | means: jax.Array, 222 | resp: jax.Array, 223 | nk: jax.Array, 224 | reg_covar: float, 225 | ) -> FullCovariances: 226 | """Estimate updated covariance matrix from data 227 | 228 | Parameters 229 | ---------- 230 | x : jax.array 231 | Feature vectors 232 | means : jax.array 233 | Means of the components 234 | resp : jax.array 235 | Responsibilities 236 | nk : jax.array 237 | Number of samples in each component 238 | reg_covar : float 239 | Regularization for the covariance matrix 240 | 241 | Returns 242 | ------- 243 | covariances : FullCovariances 244 | Updated covariance matrix instance. 
245 | """ 246 | diff = x - means 247 | axes = (Axis.features_covar, Axis.components, Axis.features, Axis.batch) 248 | diff = jnp.transpose(diff, axes=axes) 249 | resp = jnp.transpose(resp, axes=axes) 250 | values = jnp.matmul(resp * diff, diff.mT) / nk 251 | idx = jnp.arange(x.shape[Axis.features]) 252 | values = values.at[:, :, idx, idx].add(reg_covar) 253 | return cls(values=values) 254 | 255 | @property 256 | def n_components(self) -> int: 257 | """Number of components""" 258 | return self.values.shape[Axis.components] 259 | 260 | @property 261 | def n_features(self) -> int: 262 | """Number of features""" 263 | return self.values.shape[Axis.features] 264 | 265 | @property 266 | def n_parameters(self) -> int: 267 | """Number of parameters""" 268 | return int(self.n_components * self.n_features * (self.n_features + 1) / 2.0) 269 | 270 | @property 271 | def log_det_cholesky(self) -> jax.Array: 272 | """Log determinant of the cholesky decomposition""" 273 | diag = jnp.trace( 274 | jnp.log(self.precisions_cholesky), 275 | axis1=Axis.features, 276 | axis2=Axis.features_covar, 277 | ) 278 | return jnp.expand_dims(diag, axis=(Axis.features, Axis.features_covar)) 279 | 280 | @property 281 | def precisions_cholesky(self) -> jax.Array: 282 | """Compute precision matrices""" 283 | cov_chol = jsp.linalg.cholesky(self.values, lower=True) 284 | 285 | identity = jnp.expand_dims( 286 | jnp.eye(self.n_features), axis=(Axis.batch, Axis.components) 287 | ) 288 | 289 | b = jnp.repeat(identity, self.n_components, axis=Axis.components) 290 | precisions_chol = jsp.linalg.solve_triangular(cov_chol, b, lower=True) 291 | return precisions_chol.mT 292 | 293 | @classmethod 294 | def from_precisions(cls, precisions: AnyArray) -> FullCovariances: 295 | """Create covariance matrix from precision matrices""" 296 | values = jsp.linalg.inv(precisions) 297 | return cls.from_squeezed(values=values) 298 | 299 | 300 | @register_dataclass_jax(data_fields=["values"]) 301 | @dataclass 302 | class DiagCovariances: 303 | """Diagonal covariance matrices""" 304 | 305 | values: jax.Array 306 | type: ClassVar[CovarianceType] = CovarianceType.diag 307 | 308 | def __post_init__(self) -> None: 309 | check_shape(self.values, (1, None, None, 1)) 310 | 311 | @property 312 | def values_dense(self) -> jax.Array: 313 | """Covariance as dense matrix""" 314 | values = jnp.zeros((1, self.n_components, self.n_features, self.n_features)) 315 | idx = jnp.arange(self.n_features) 316 | covar_diag = jnp.squeeze(self.values, axis=(Axis.batch, Axis.features_covar)) 317 | return values.at[:, :, idx, idx].set(covar_diag) 318 | 319 | @classmethod 320 | def from_squeezed(cls, values: AnyArray) -> DiagCovariances: 321 | """Create a diagonal covariance matrix from squeezed array 322 | 323 | Parameters 324 | ---------- 325 | values : jax.Array ot np.array 326 | Covariance values. Expected shape is (n_components, n_features) 327 | 328 | Returns 329 | ------- 330 | covariances : FullCovariances 331 | Covariance matrix instance. 
332 | """ 333 | if values.ndim != 2: 334 | message = f"Expected array of shape (n_components, n_features), got {values.shape}" 335 | raise ValueError(message) 336 | 337 | return cls( 338 | values=jnp.expand_dims(values, axis=(Axis.batch, Axis.features_covar)) 339 | ) 340 | 341 | @property 342 | def n_components(self) -> int: 343 | """Number of components""" 344 | return self.values.shape[Axis.components] 345 | 346 | @property 347 | def n_features(self) -> int: 348 | """Number of features""" 349 | return self.values.shape[Axis.features] 350 | 351 | @property 352 | def n_parameters(self) -> int: 353 | """Number of parameters""" 354 | return int(self.n_components * self.n_features) 355 | 356 | @classmethod 357 | def from_responsibilities( 358 | cls, 359 | x: jax.Array, 360 | means: jax.Array, 361 | resp: jax.Array, 362 | nk: jax.Array, 363 | reg_covar: float, 364 | ) -> DiagCovariances: 365 | """Estimate updated covariance matrix from data 366 | 367 | Parameters 368 | ---------- 369 | x : jax.array 370 | Feature vectors 371 | means : jax.array 372 | Means of the components 373 | resp : jax.array 374 | Responsibilities 375 | nk : jax.array 376 | Number of samples in each component 377 | reg_covar : float 378 | Regularization for the covariance matrix 379 | 380 | Returns 381 | ------- 382 | covariances : FullCovariances 383 | Updated covariance matrix instance. 384 | """ 385 | x_squared_mean = jnp.sum(resp * x**2, axis=Axis.batch, keepdims=True) / nk 386 | values = x_squared_mean - means**2 + reg_covar 387 | return cls(values=values) 388 | 389 | @property 390 | def precisions_cholesky_sparse(self) -> jax.Array: 391 | """Compute precision matrices""" 392 | return jnp.sqrt(1.0 / self.values).mT 393 | 394 | @property 395 | def precisions_cholesky_numpy(self) -> np.typing.NDArray: 396 | """Compute precision matrices""" 397 | return np.squeeze( 398 | np.asarray(self.precisions_cholesky_sparse), 399 | axis=(Axis.batch, Axis.features), 400 | ) 401 | 402 | @property 403 | def values_numpy(self) -> np.typing.NDArray: 404 | """Covariance as numpy array""" 405 | return np.squeeze( 406 | np.asarray(self.values), axis=(Axis.batch, Axis.features_covar) 407 | ) 408 | 409 | @property 410 | def log_det_cholesky(self) -> jax.Array: 411 | """Log determinant of the cholesky decomposition""" 412 | return jnp.sum( 413 | jnp.log(self.precisions_cholesky_sparse), 414 | axis=(Axis.features, Axis.features_covar), 415 | keepdims=True, 416 | ) 417 | 418 | def log_prob(self, x: jax.Array, means: jax.Array) -> jax.Array: 419 | """Compute log likelihood from the covariance for a given feature vector""" 420 | precisions_cholesky = self.precisions_cholesky_sparse 421 | y = (x.mT * precisions_cholesky) - (means.mT * precisions_cholesky) 422 | return jnp.sum( 423 | jnp.square(y), 424 | axis=(Axis.features, Axis.features_covar), 425 | keepdims=True, 426 | ) 427 | 428 | @classmethod 429 | def from_precisions(cls, precisions: AnyArray) -> DiagCovariances: 430 | """Create covariance matrix from precision matrices""" 431 | values = 1.0 / precisions 432 | return cls.from_squeezed(values=values) 433 | 434 | 435 | COVARIANCE: dict[CovarianceType, Any] = { 436 | FullCovariances.type: FullCovariances, 437 | DiagCovariances.type: DiagCovariances, 438 | } 439 | 440 | # keep this mapping separate, as names in sklearn might change 441 | SKLEARN_COVARIANCE_TYPE: dict[Any, str] = { 442 | FullCovariances: "full", 443 | DiagCovariances: "diag", 444 | } 445 | 446 | 447 | @register_dataclass_jax(data_fields=["weights", "means", "covariances"]) 448 
447 | @register_dataclass_jax(data_fields=["weights", "means", "covariances"])
448 | @dataclass
449 | class GaussianMixtureModelJax:
450 |     """Gaussian Mixture Model
451 | 
452 |     Attributes
453 |     ----------
454 |     weights : jax.array
455 |         Weights of each component. Expected shape is (1, n_components, 1, 1)
456 |     means : jax.array
457 |         Mean of each component. Expected shape is (1, n_components, n_features, 1)
458 |     covariances : FullCovariances or DiagCovariances
459 |         Covariances of the components. For the full case the expected shape of the values is (1, n_components, n_features, n_features)
460 |     """
461 | 
462 |     weights: jax.Array
463 |     means: jax.Array
464 |     covariances: FullCovariances
465 | 
466 |     def __post_init__(self) -> None:
467 |         check_shape(self.weights, (1, None, 1, 1))
468 |         check_shape(self.means, (1, None, None, 1))
469 | 
470 |     @property
471 |     def weights_numpy(self) -> np.typing.NDArray:
472 |         """Weights as numpy array"""
473 |         return np.squeeze(
474 |             np.asarray(self.weights),
475 |             axis=(Axis.batch, Axis.features, Axis.features_covar),
476 |         )
477 | 
478 |     @property
479 |     def means_numpy(self) -> np.typing.NDArray:
480 |         """Means as numpy array"""
481 |         return np.squeeze(
482 |             np.asarray(self.means), axis=(Axis.batch, Axis.features_covar)
483 |         )
484 | 
485 |     @classmethod
486 |     def create(
487 |         cls,
488 |         n_components: int,
489 |         n_features: int,
490 |         covariance_type: CovarianceType = CovarianceType.full,
491 |         device: Device = None,
492 |     ) -> GaussianMixtureModelJax:
493 |         """Create a GMM from configuration
494 | 
495 |         Parameters
496 |         ----------
497 |         n_components : int
498 |             Number of components
499 |         n_features : int
500 |             Number of features
501 |         covariance_type : str, optional
502 |             Covariance type, by default "full"
503 |         device : str, optional
504 |             Device, by default None
505 | 
506 |         Returns
507 |         -------
508 |         gmm : GaussianMixtureModelJax
509 |             Gaussian mixture model instance.
510 |         """
511 |         covariance_type = CovarianceType(covariance_type)
512 | 
513 |         weights = jnp.ones((1, n_components, 1, 1)) / n_components
514 |         means = jnp.zeros((1, n_components, n_features, 1))
515 |         covariances = COVARIANCE[covariance_type].create(
516 |             n_components, n_features, device=device
517 |         )
518 |         return cls(
519 |             weights=jax.device_put(weights, device=device),
520 |             means=jax.device_put(means, device=device),
521 |             covariances=covariances,
522 |         )
523 | 
524 |     @classmethod
525 |     def from_squeezed(
526 |         cls,
527 |         means: AnyArray,
528 |         covariances: AnyArray,
529 |         weights: AnyArray,
530 |         covariance_type: CovarianceType | str = CovarianceType.full,
531 |     ) -> GaussianMixtureModelJax:
532 |         """Create a Jax GMM from squeezed arrays
533 | 
534 |         Parameters
535 |         ----------
536 |         means : jax.Array or np.array
537 |             Mean of each component. Expected shape is (n_components, n_features)
538 |         covariances : jax.Array or np.array
539 |             Covariance of each component. Expected shape is (n_components, n_features, n_features)
540 |         weights : jax.Array or np.array
541 |             Weights of each component. Expected shape is (n_components,)
542 |         covariance_type : str, optional
543 |             Covariance type, by default "full"
544 | 
545 |         Returns
546 |         -------
547 |         gmm : GaussianMixtureModelJax
548 |             Gaussian mixture model instance.
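
        Examples
        --------
        A minimal sketch with made-up values:

        >>> import numpy as np
        >>> gmm = GaussianMixtureModelJax.from_squeezed(
        ...     means=np.zeros((2, 3), dtype=np.float32),
        ...     covariances=np.stack([np.eye(3, dtype=np.float32)] * 2),
        ...     weights=np.array([0.5, 0.5], dtype=np.float32),
        ... )
        >>> gmm.n_components, gmm.n_features
        (2, 3)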
549 | """ 550 | covariance_type = CovarianceType(covariance_type) 551 | 552 | means = jnp.expand_dims(means, axis=(Axis.batch, Axis.features_covar)) 553 | weights = jnp.expand_dims( 554 | weights, axis=(Axis.batch, Axis.features, Axis.features_covar) 555 | ) 556 | 557 | covariances = COVARIANCE[covariance_type].from_squeezed(values=covariances) 558 | return cls(weights=weights, means=means, covariances=covariances) # type: ignore [arg-type] 559 | 560 | @classmethod 561 | def from_responsibilities( 562 | cls, 563 | x: jax.Array, 564 | resp: jax.Array, 565 | reg_covar: float, 566 | covariance_type: CovarianceType = CovarianceType.full, 567 | ) -> GaussianMixtureModelJax: 568 | """Update parameters 569 | 570 | Parameters 571 | ---------- 572 | x : jax.array 573 | Feature vectors 574 | resp : jax.array 575 | Responsibilities 576 | reg_covar : float 577 | Regularization for the covariance matrix 578 | covariance_type : str, optional 579 | Covariance type, by default "full" 580 | 581 | Returns 582 | ------- 583 | gmm : GaussianMixtureModelJax 584 | Updated Gaussian mixture model 585 | """ 586 | covariance_type = CovarianceType(covariance_type) 587 | 588 | # I don't like the hard-coded 10 here, but it is the same as in sklearn 589 | nk = ( 590 | jnp.sum(resp, axis=Axis.batch, keepdims=True) 591 | + 10 * jnp.finfo(resp.dtype).eps 592 | ) 593 | means = jnp.matmul(resp.T, x.T.mT).T / nk 594 | covariances = COVARIANCE[covariance_type].from_responsibilities( 595 | x=x, means=means, resp=resp, nk=nk, reg_covar=reg_covar 596 | ) 597 | return cls(weights=nk / nk.sum(), means=means, covariances=covariances) 598 | 599 | @classmethod 600 | def from_k_means( 601 | cls, 602 | x: AnyArray, 603 | n_components: int, 604 | reg_covar: float = 1e-6, 605 | covariance_type: CovarianceType = CovarianceType.full, 606 | **kwargs: dict, 607 | ) -> GaussianMixtureModelJax: 608 | """Init from k-means clustering 609 | 610 | Parameters 611 | ---------- 612 | x : jax.array 613 | Feature vectors 614 | n_components : int 615 | Number of components 616 | reg_covar : float, optional 617 | Regularization for the covariance matrix, by default 1e6 618 | covariance_type : str, optional 619 | Covariance type, by default "full" 620 | **kwargs : dict 621 | Additional arguments passed to `~sklearn.cluster.KMeans` 622 | 623 | Returns 624 | ------- 625 | gmm : GaussianMixtureModelJax 626 | Gaussian mixture model instance. 
627 | """ 628 | from sklearn.cluster import KMeans # type: ignore [import-untyped] 629 | 630 | n_samples = x.shape[Axis.batch] 631 | 632 | resp = jnp.zeros((n_samples, n_components)) 633 | 634 | kwargs.setdefault("n_init", 10) # type: ignore [arg-type] 635 | label = KMeans(n_clusters=n_components, **kwargs).fit(x).labels_ 636 | 637 | idx = jnp.arange(n_samples) 638 | resp = resp.at[idx, label].set(1.0) 639 | 640 | xp = jnp.expand_dims(x, axis=(Axis.components, Axis.features_covar)) 641 | resp = jnp.expand_dims(resp, axis=(Axis.features, Axis.features_covar)) 642 | return cls.from_responsibilities( 643 | xp, resp, reg_covar=reg_covar, covariance_type=covariance_type 644 | ) 645 | 646 | @property 647 | def n_features(self) -> int: 648 | """Number of features""" 649 | return self.covariances.n_features 650 | 651 | @property 652 | def n_components(self) -> int: 653 | """Number of components""" 654 | return self.covariances.n_components 655 | 656 | @property 657 | def n_parameters(self) -> int: 658 | """Number of parameters""" 659 | return int( 660 | self.n_components 661 | + self.n_components * self.n_features 662 | + self.covariances.n_parameters 663 | - 1 664 | ) 665 | 666 | @property 667 | def log_weights(self) -> jax.Array: 668 | """Log weights (~jax.ndarray)""" 669 | return jnp.log(self.weights) 670 | 671 | @jax.jit 672 | def log_prob(self, x: jax.Array) -> jax.Array: 673 | """Compute log likelihood for given feature vector 674 | 675 | Parameters 676 | ---------- 677 | x : jax.array 678 | Feature vectors 679 | 680 | Returns 681 | ------- 682 | log_prob : jax.array 683 | Log likelihood 684 | """ 685 | x = jnp.expand_dims(x, axis=(Axis.components, Axis.features_covar)) 686 | log_prob = self.covariances.log_prob(x, self.means) 687 | two_pi = jnp.array(2 * jnp.pi) 688 | 689 | value = ( 690 | -0.5 * (self.n_features * jnp.log(two_pi) + log_prob) 691 | + self.covariances.log_det_cholesky 692 | + self.log_weights 693 | ) 694 | return value 695 | 696 | def to_sklearn(self, **kwargs: dict[str, Any]) -> Any: 697 | """Convert to sklearn GaussianMixture 698 | 699 | The methods sets the weights, means, precisions_cholesky and covariances_ attributes, 700 | however sklearn will overvwrite them when fitting the model. 701 | 702 | Parameters 703 | ---------- 704 | **kwargs : dict 705 | Additional arguments passed to `~sklearn.mixture.GaussianMixture` 706 | 707 | Returns 708 | ------- 709 | gmm : `~sklearn.mixture.GaussianMixture` 710 | Gaussian mixture model instance. 
711 | """ 712 | from sklearn.mixture import GaussianMixture # type: ignore [import-untyped] 713 | 714 | kwargs.setdefault("warm_start", True) # type: ignore [arg-type] 715 | gmm = GaussianMixture( 716 | n_components=self.n_components, 717 | covariance_type=SKLEARN_COVARIANCE_TYPE[type(self.covariances)], 718 | **kwargs, 719 | ) 720 | # This does a warm start at the given parameters 721 | gmm.converged_ = True 722 | gmm.lower_bound_ = -np.inf 723 | gmm.weights_ = self.weights_numpy 724 | gmm.means_ = self.means_numpy 725 | gmm.precisions_cholesky_ = self.covariances.precisions_cholesky_numpy 726 | gmm.covariances_ = self.covariances.values_numpy 727 | return gmm 728 | 729 | @jax.jit 730 | def predict(self, x: jax.Array) -> jax.Array: 731 | """Predict the component index for each sample 732 | 733 | Parameters 734 | ---------- 735 | x : jax.array 736 | Feature vectors 737 | 738 | Returns 739 | ------- 740 | predictions : jax.array 741 | Predicted component index 742 | """ 743 | log_prob = self.log_prob(x) 744 | predictions = jnp.argmax(log_prob, axis=Axis.components, keepdims=True) 745 | return jnp.squeeze(predictions, axis=(Axis.features, Axis.features_covar)) 746 | 747 | @jax.jit 748 | def predict_proba(self, x: jax.Array) -> jax.Array: 749 | """Predict the probability of each sample belonging to each component 750 | 751 | Parameters 752 | ---------- 753 | x : jax.array 754 | Feature vectors 755 | 756 | Returns 757 | ------- 758 | probabilities : jax.array 759 | Predicted probabilities 760 | """ 761 | log_prob = self.log_prob(x) 762 | log_prob_norm = jax.scipy.special.logsumexp( 763 | log_prob, axis=Axis.components, keepdims=True 764 | ) 765 | return jnp.exp(log_prob - log_prob_norm) 766 | 767 | @jax.jit 768 | def score_samples(self, x: jax.Array) -> jax.Array: 769 | """Compute the weighted log probabilities for each sample 770 | 771 | Parameters 772 | ---------- 773 | x : jax.array 774 | Feature vectors 775 | 776 | Returns 777 | ------- 778 | log_prob : jax.array 779 | Log probabilities 780 | """ 781 | log_prob = self.log_prob(x) 782 | log_prob_norm = jax.scipy.special.logsumexp( 783 | log_prob, axis=Axis.components, keepdims=True 784 | ) 785 | return log_prob_norm 786 | 787 | @jax.jit 788 | def score(self, x: jax.Array) -> jax.Array: 789 | """Compute the log likelihood of the data 790 | 791 | Parameters 792 | ---------- 793 | x : jax.array 794 | Feature vectors 795 | 796 | Returns 797 | ------- 798 | log_likelihood : float 799 | Log-likelihood of the data 800 | """ 801 | log_prob = self.score_samples(x) 802 | return jnp.mean(log_prob) 803 | 804 | @jax.jit 805 | def aic(self, x: jax.Array) -> jax.Array: 806 | """Compute the Akaike Information Criterion 807 | 808 | Parameters 809 | ---------- 810 | x : jax.array 811 | Feature vectors 812 | 813 | Returns 814 | ------- 815 | aic : jax.array 816 | Akaike Information Criterion 817 | """ 818 | return -2 * self.score(x) * x.shape[Axis.batch] + 2 * self.n_parameters # type: ignore [no-any-return] 819 | 820 | @jax.jit 821 | def bic(self, x: jax.Array) -> jax.Array: 822 | """Compute the Bayesian Information Criterion 823 | 824 | Parameters 825 | ---------- 826 | x : jax.array 827 | Feature vectors 828 | 829 | Returns 830 | ------- 831 | bic : jax.array 832 | Bayesian Information Criterion 833 | """ 834 | return -2 * self.score(x) * x.shape[Axis.batch] + self.n_parameters * jnp.log( # type: ignore [no-any-return] 835 | x.shape[Axis.batch] 836 | ) 837 | 838 | @partial(jax.jit, static_argnames=["n_samples"]) 839 | def sample(self, key: jax.Array, 
840 |         """Sample from the model
841 | 
842 |         Parameters
843 |         ----------
844 |         key : jax.random.PRNGKey
845 |             Random key
846 |         n_samples : int
847 |             Number of samples
848 | 
849 |         Returns
850 |         -------
851 |         samples : jax.array
852 |             Samples
853 |         """
854 |         key, subkey = jax.random.split(key)
855 | 
856 |         selected = jax.random.choice(
857 |             key,
858 |             jnp.arange(self.n_components),
859 |             p=self.weights.flatten(),
860 |             shape=(n_samples,),
861 |         )
862 | 
863 |         # TODO: this blows up the memory, as the arrays are copied. However,
864 |         # there is no simple way to handle the varying number of samples per component:
865 |         # JAX does not support ragged arrays, and the size parameter in random methods
866 |         # has to be static. One possibility would be to pad the arrays to the maximum
867 |         # number of samples per component, but this might be inefficient as well.
868 |         means = jnp.take(self.means, selected, axis=Axis.components)
869 |         covar = jnp.take(self.covariances.values_dense, selected, axis=Axis.components)
870 | 
871 |         samples = jax.random.multivariate_normal(
872 |             subkey,
873 |             jnp.squeeze(means, axis=(Axis.batch, Axis.features_covar)),
874 |             jnp.squeeze(covar, axis=Axis.batch),
875 |             shape=(n_samples,),
876 |         )
877 | 
878 |         return samples
879 | 
880 | 
881 | def check_model_fitted(
882 |     instance: GaussianMixtureSKLearn,
883 | ) -> GaussianMixtureModelJax:
884 |     """Check if the model is fitted"""
885 |     if instance._gmm is None:
886 |         message = "Model not initialized. Call `fit` first."
887 |         raise ValueError(message)
888 | 
889 |     return instance._gmm
890 | 
891 | 
892 | INIT_METHODS = {
893 |     "kmeans": GaussianMixtureModelJax.from_k_means,
894 | }
895 | 
896 | 
897 | @dataclass
898 | class GaussianMixtureSKLearn:
899 |     """Scikit-learn compatible API for Gaussian Mixture Model
900 | 
901 |     See docs at https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html
902 |     """
903 | 
904 |     n_components: int
905 |     covariance_type: str = "full"
906 |     tol: float = 1e-3
907 |     reg_covar: float = 1e-6
908 |     max_iter: int = 100
909 |     n_init: int = 1
910 |     init_params: str = "kmeans"
911 |     weights_init: AnyArray | None = None
912 |     means_init: AnyArray | None = None
913 |     precisions_init: AnyArray | None = None
914 |     random_state: np.random.RandomState | None = None
915 |     warm_start: bool = False
916 |     _gmm: GaussianMixtureModelJax | None = field(init=False, repr=False, default=None)
917 | 
918 |     def __post_init__(self) -> None:
919 |         from sklearn.utils import check_random_state  # type: ignore [import-untyped]
920 | 
921 |         if self.n_init > 1:
922 |             raise NotImplementedError("n_init > 1 is not supported yet.")
923 | 
924 |         self.random_state = check_random_state(self.random_state)
925 | 
926 |     @property
927 |     def weights_(self) -> np.typing.NDArray:
928 |         """Weights of each component"""
929 |         return check_model_fitted(self).weights_numpy
930 | 
931 |     @property
932 |     def means_(self) -> np.typing.NDArray:
933 |         """Means of each component"""
934 |         return check_model_fitted(self).means_numpy
935 | 
936 |     @property
937 |     def precisions_cholesky_(self) -> np.typing.NDArray:
938 |         """Precision matrices of each component"""
939 |         return check_model_fitted(self).covariances.precisions_cholesky_numpy
940 | 
941 |     @property
942 |     def covariances_(self) -> np.typing.NDArray:
943 |         """Covariances of each component"""
944 |         return check_model_fitted(self).covariances.values_numpy
945 | 
946 |     def _initialize_gmm(self, x: AnyArray) -> None:
947 |         init_from_data = ( 948
|             self.weights_init is None
949 |             or self.means_init is None
950 |             or self.precisions_init is None
951 |         )
952 | 
953 |         if init_from_data:
954 |             kwargs = {
955 |                 "x": x,
956 |                 "n_components": self.n_components,
957 |                 "covariance_type": self.covariance_type,
958 |                 "random_state": self.random_state,
959 |             }
960 |             self._gmm = INIT_METHODS[self.init_params](**kwargs)  # type: ignore [arg-type]
961 |         else:
962 |             covar = COVARIANCE[CovarianceType(self.covariance_type)]
963 | 
964 |             self._gmm = GaussianMixtureModelJax.from_squeezed(
965 |                 means=self.means_init,  # type: ignore [arg-type]
966 |                 covariances=covar.from_precisions(
967 |                     self.precisions_init.astype(np.float32)  # type: ignore [union-attr]
968 |                 ).values_numpy,
969 |                 weights=self.weights_init,  # type: ignore [arg-type]
970 |                 covariance_type=self.covariance_type,
971 |             )
972 | 
973 |     def fit(self, X: AnyArray) -> GaussianMixtureSKLearn:
974 |         """Fit the model"""
975 |         from gmmx.fit import EMFitter
976 | 
977 |         do_init = not (self.warm_start and hasattr(self, "converged_"))
978 | 
979 |         if do_init:
980 |             self._initialize_gmm(x=X)
981 | 
982 |         fitter = EMFitter(
983 |             tol=self.tol,
984 |             reg_covar=self.reg_covar,
985 |             max_iter=self.max_iter,
986 |         )
987 |         result = fitter.fit(X, self._gmm)
988 |         self._gmm = result.gmm
989 |         self.converged_ = result.converged
990 |         return self
991 | 
992 |     def predict(self, X: AnyArray) -> np.typing.NDArray:
993 |         """Predict the component index for each sample"""
994 |         return np.squeeze(check_model_fitted(self).predict(X), axis=Axis.components)  # type: ignore [no-any-return]
995 | 
996 |     def fit_predict(self) -> np.typing.NDArray:
997 |         """Fit the model and predict the component index for each sample"""
998 |         raise NotImplementedError
999 | 
1000 |     def predict_proba(self, X: AnyArray) -> np.typing.NDArray:
1001 |         """Predict the probability of each sample belonging to each component"""
1002 |         return np.squeeze(  # type: ignore [no-any-return]
1003 |             check_model_fitted(self).predict_proba(X),
1004 |             axis=(Axis.features, Axis.features_covar),
1005 |         )
1006 | 
1007 |     def sample(self, n_samples: int) -> np.typing.NDArray:
1008 |         """Sample from the model"""
1009 |         key = jax.random.key(self.random_state.randint(2**32 - 1))  # type: ignore [union-attr]
1010 |         return np.asarray(check_model_fitted(self).sample(key=key, n_samples=n_samples))
1011 | 
1012 |     def score(self, X: AnyArray) -> np.typing.NDArray:
1013 |         """Compute the log likelihood of the data"""
1014 |         return np.asarray(check_model_fitted(self).score(X))
1015 | 
1016 |     def score_samples(self, X: AnyArray) -> np.typing.NDArray:
1017 |         """Compute the weighted log probabilities for each sample"""
1018 |         return np.squeeze(  # type: ignore [no-any-return]
1019 |             (check_model_fitted(self).score_samples(X)),
1020 |             axis=(Axis.components, Axis.features, Axis.features_covar),
1021 |         )
1022 | 
1023 |     def bic(self, X: AnyArray) -> np.typing.NDArray:
1024 |         """Compute the Bayesian Information Criterion"""
1025 |         return np.asarray(check_model_fitted(self).bic(X))
1026 | 
1027 |     def aic(self, X: AnyArray) -> np.typing.NDArray:
1028 |         """Compute the Akaike Information Criterion"""
1029 |         return np.asarray(check_model_fitted(self).aic(X))
1030 | 
--------------------------------------------------------------------------------
/gmmx/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional
2 | 
3 | import jax
4 | 
5 | 
6 | class register_dataclass_jax:
7 |     """Decorator to register a dataclass with JAX."""
8 | 
9 |     def __init__(
10 |         self,
11 |         data_fields: Optional[list] = None,
12 |         meta_fields: Optional[list] = None,
13 |     ) -> None:
14 |         self.data_fields = data_fields or []
15 |         self.meta_fields = meta_fields or []
16 | 
17 |     def __call__(self, cls: Any) -> Any:
18 |         jax.tree_util.register_dataclass(
19 |             cls,
20 |             data_fields=self.data_fields,
21 |             meta_fields=self.meta_fields,
22 |         )
23 |         return cls
24 | 
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: gmmx
2 | repo_url: https://github.com/adonath/gmmx
3 | site_url: https://adonath.github.io/gmmx
4 | site_description: A minimal implementation of Gaussian Mixture Models in Jax
5 | site_author: Axel Donath
6 | edit_uri: edit/main/docs/
7 | repo_name: adonath/gmmx
8 | copyright: Maintained by Axel Donath.
9 | 
10 | nav:
11 |   - Home: index.md
12 |   - Modules: modules.md
13 | plugins:
14 |   - search
15 |   - mkdocstrings:
16 |       handlers:
17 |         python:
18 |           paths: [gmmx]
19 | theme:
20 |   name: material
21 |   feature:
22 |     tabs: true
23 |   palette:
24 |     - media: "(prefers-color-scheme: light)"
25 |       scheme: default
26 |       primary: white
27 |       accent: deep orange
28 |       toggle:
29 |         icon: material/brightness-7
30 |         name: Switch to dark mode
31 |     - media: "(prefers-color-scheme: dark)"
32 |       scheme: slate
33 |       primary: black
34 |       accent: deep orange
35 |       toggle:
36 |         icon: material/brightness-4
37 |         name: Switch to light mode
38 |   icon:
39 |     repo: fontawesome/brands/github
40 | 
41 | extra:
42 |   social:
43 |     - icon: fontawesome/brands/github
44 |       link: https://github.com/adonath/gmmx
45 |     - icon: fontawesome/brands/python
46 |       link: https://pypi.org/project/gmmx
47 | 
48 | markdown_extensions:
49 |   - toc:
50 |       permalink: true
51 |   - pymdownx.arithmatex:
52 |       generic: true
53 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "gmmx"
3 | version = "0.0.1"
4 | description = "A minimal implementation of Gaussian Mixture Models in Jax"
5 | authors = [{ name = "Axel Donath", email = "mail@axeldonath.com" }]
6 | readme = "README.md"
7 | keywords = ['python']
8 | requires-python = ">=3.9,<4.0"
9 | classifiers = [
10 |     "Intended Audience :: Developers",
11 |     "Programming Language :: Python",
12 |     "Programming Language :: Python :: 3",
13 |     "Programming Language :: Python :: 3.9",
14 |     "Programming Language :: Python :: 3.10",
15 |     "Programming Language :: Python :: 3.11",
16 |     "Programming Language :: Python :: 3.12",
17 |     "Programming Language :: Python :: 3.13",
18 |     "Topic :: Software Development :: Libraries :: Python Modules",
19 | ]
20 | dependencies = [
21 |     "jax>=0.4.30",
22 |     "numpy>=1.26.0",
23 | ]
24 | 
25 | [project.urls]
26 | Homepage = "https://adonath.github.io/software#gmmx"
27 | Repository = "https://github.com/adonath/gmmx"
28 | Documentation = "https://adonath.github.io/gmmx/"
29 | 
30 | [tool.uv]
31 | dev-dependencies = [
32 |     "pytest>=7.2.0",
33 |     "pre-commit>=2.20.0",
34 |     "tox-uv>=1.11.3",
35 |     "mypy>=0.991",
36 |     "pytest-cov>=4.0.0",
37 |     "ruff>=0.6.9",
38 |     "mkdocs>=1.4.2",
39 |     "mkdocs-material>=8.5.10",
40 |     "mkdocstrings[python]>=0.26.1",
41 |     "scikit-learn>=1.0",
42 | ]
43 | 
44 | [build-system]
45 | requires = ["setuptools >= 61.0"]
46 | build-backend = "setuptools.build_meta"
47 | 
48 | [tool.setuptools.packages.find]
49 | include = ["gmmx"]
50 | exclude = ["tests", "examples"]
51 | 
52 | [tool.mypy]
53 | files = ["gmmx"]
54 | disallow_untyped_defs = true
55 | disallow_any_unimported = true
56 | no_implicit_optional = true
57 | check_untyped_defs = true
58 | warn_return_any = true
59 | warn_unused_ignores = true
60 | show_error_codes = true
61 | 
62 | [tool.pytest.ini_options]
63 | testpaths = ["tests"]
64 | 
65 | [tool.ruff]
66 | target-version = "py39"
67 | line-length = 88
68 | fix = true
69 | 
70 | [tool.ruff.lint]
71 | select = [
72 |     # flake8-2020
73 |     "YTT",
74 |     # flake8-bandit
75 |     "S",
76 |     # flake8-bugbear
77 |     "B",
78 |     # flake8-builtins
79 |     "A",
80 |     # flake8-comprehensions
81 |     "C4",
82 |     # flake8-debugger
83 |     "T10",
84 |     # flake8-simplify
85 |     "SIM",
86 |     # isort
87 |     "I",
88 |     # mccabe
89 |     "C90",
90 |     # pycodestyle
91 |     "E", "W",
92 |     # pyflakes
93 |     "F",
94 |     # pygrep-hooks
95 |     "PGH",
96 |     # pyupgrade
97 |     "UP",
98 |     # ruff
99 |     "RUF",
100 |     # tryceratops
101 |     "TRY",
102 | ]
103 | ignore = [
104 |     # LineTooLong
105 |     "E501",
106 |     # DoNotAssignLambda
107 |     "E731",
108 | ]
109 | 
110 | [tool.ruff.lint.per-file-ignores]
111 | "tests/*" = ["S101"]
112 | 
113 | [tool.ruff.format]
114 | preview = true
115 | 
116 | [tool.coverage.report]
117 | skip_empty = true
118 | 
119 | [tool.coverage.run]
120 | branch = true
121 | source = ["gmmx"]
122 | 
--------------------------------------------------------------------------------
/tests/test_gmm.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import numpy as np
3 | import pytest
4 | from jax import numpy as jnp
5 | from numpy.testing import assert_allclose
6 | 
7 | from gmmx import EMFitter, GaussianMixtureModelJax, GaussianMixtureSKLearn
8 | 
9 | TEST_COVARIANCES = {
10 |     "full": np.array([
11 |         [[1, 0.5, 0.5], [0.5, 1, 0.5], [0.5, 0.5, 1]],
12 |         [[1, 0.5, 0.5], [0.5, 1, 0.5], [0.5, 0.5, 1]],
13 |     ]),
14 |     "diag": np.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]),
15 | }
16 | 
17 | TEST_PRECISIONS = {
18 |     "full": np.linalg.inv(TEST_COVARIANCES["full"]),
19 |     "diag": 1 / TEST_COVARIANCES["diag"],
20 | }
21 | 
22 | MEANS = np.array([[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]])
23 | WEIGHTS = np.array([0.2, 0.8])
24 | 
25 | 
26 | @pytest.fixture(params=["full", "diag"])
27 | def gmm_jax(request):
28 |     covariances = TEST_COVARIANCES[request.param]
29 |     return GaussianMixtureModelJax.from_squeezed(
30 |         means=MEANS,
31 |         covariances=covariances,
32 |         weights=WEIGHTS,
33 |         covariance_type=request.param,
34 |     )
35 | 
36 | 
37 | def test_simple(gmm_jax):
38 |     assert gmm_jax.n_features == 3
39 |     assert gmm_jax.n_components == 2
40 | 
41 |     expected = {"full": 19, "diag": 13}
42 |     assert gmm_jax.n_parameters == expected[gmm_jax.covariances.type.value]
43 | 
44 | 
45 | def test_create():
46 |     gmm = GaussianMixtureModelJax.create(n_components=3, n_features=2)
47 |     assert gmm.n_features == 2
48 |     assert gmm.n_components == 3
49 |     assert gmm.n_parameters == 17
50 | 
51 | 
52 | def test_init_incorrect():
53 |     with pytest.raises(ValueError):
54 |         GaussianMixtureModelJax(
55 |             means=jnp.zeros((2, 3)),
56 |             covariances=jnp.zeros((2, 3, 3)),
57 |             weights=jnp.zeros((2,)),
58 |         )
59 | 
60 |     with pytest.raises(ValueError):
61 |         GaussianMixtureModelJax(
62 |             means=jnp.zeros((1, 2, 3, 1)),
63 |             covariances=jnp.zeros((1, 2, 3, 3)),
64 |             weights=jnp.zeros((1, 1, 4, 1)),
65 |         )
66 | 
67 | 
68 | def test_against_sklearn(gmm_jax):
69 |     x = np.array([
70 |         [1, 2, 3],
71 |         [1, 4, 2],
72 |         [1, 0, 6],
73 |         [4, 2, 4],
74 |         [4, 4, 4],
75 |         [4, 0, 2],
76 |     ])
77 | 
78 |     gmm = gmm_jax.to_sklearn()
79 |     result_ref = gmm._estimate_weighted_log_prob(X=x)
80 |     result = gmm_jax.log_prob(x=jnp.asarray(x))[:, :, 0, 0]
81 | 
82 |     assert_allclose(np.asarray(result), result_ref, rtol=1e-6)
83 | 
84 |     assert gmm_jax.n_parameters == gmm._n_parameters()
85 | 
86 | 
87 | @pytest.mark.parametrize(
88 |     "method", ["aic", "bic", "predict", "predict_proba", "score", "score_samples"]
89 | )
90 | def test_against_sklearn_all(gmm_jax, method):
91 |     gmm = gmm_jax.to_sklearn()
92 |     x = np.array([
93 |         [1, 2, 3],
94 |         [1, 4, 2],
95 |         [1, 0, 6],
96 |         [4, 2, 4],
97 |         [4, 4, 4],
98 |         [4, 0, 2],
99 |     ])
100 |     result_sklearn = getattr(gmm, method)(x)
101 |     result_jax = getattr(gmm_jax, method)(jnp.asarray(x))
102 |     assert_allclose(np.squeeze(result_jax), result_sklearn, rtol=1e-5)
103 | 
104 | 
105 | def test_sample(gmm_jax):
106 |     key = jax.random.PRNGKey(0)
107 |     samples = gmm_jax.sample(key, 2)
108 | 
109 |     assert samples.shape == (2, 3)
110 | 
111 |     expected = {"full": -0.458194, "diag": -1.525666}
112 |     assert_allclose(samples[0, 0], expected[gmm_jax.covariances.type.value], rtol=1e-6)
113 | 
114 | 
115 | def test_predict(gmm_jax):
116 |     x = np.array([
117 |         [1, 2, 3],
118 |         [1, 4, 2],
119 |         [1, 0, 6],
120 |         [4, 2, 4],
121 |         [4, 4, 4],
122 |         [4, 0, 2],
123 |     ])
124 | 
125 |     result = gmm_jax.predict(x=jnp.asarray(x))
126 | 
127 |     assert result.shape == (6, 1)
128 |     assert_allclose(result[0], 0, rtol=1e-6)
129 | 
130 | 
131 | def test_fit(gmm_jax):
132 |     random_state = np.random.RandomState(827392)
133 |     x, _ = gmm_jax.to_sklearn(random_state=random_state).sample(16_000)
134 | 
135 |     fitter = EMFitter(tol=1e-6)
136 |     result = fitter.fit(x=x, gmm=gmm_jax)
137 | 
138 |     # The number of iterations is not deterministic across architectures
139 |     covar_str = gmm_jax.covariances.type.value
140 |     expected = {"full": [4, 7], "diag": [7]}
141 |     assert int(result.n_iter) in expected[covar_str]
142 | 
143 |     expected = {"full": -4.3686, "diag": -5.422534}
144 |     assert_allclose(result.log_likelihood, expected[covar_str], rtol=2e-4)
145 | 
146 |     expected = {"full": 9.536743e-07, "diag": 9.536743e-07}
147 |     assert_allclose(result.log_likelihood_diff, expected[covar_str], atol=fitter.tol)
148 | 
149 |     assert_allclose(result.gmm.weights_numpy, [0.2, 0.8], rtol=0.05)
150 | 
151 | 
152 | def test_fit_against_sklearn(gmm_jax):
153 |     # Fitting is hard to test, especially since we cannot guarantee that the fit converges to the same solution.
154 |     # However, the "global" likelihood (summed across all components) for a given feature vector
155 |     # should be similar for both implementations
156 |     random_state = np.random.RandomState(82792)
157 |     x, _ = gmm_jax.to_sklearn(random_state=random_state).sample(16_000)
158 | 
159 |     tol = 1e-6
160 |     fitter = EMFitter(tol=tol)
161 |     result_jax = fitter.fit(x=x, gmm=gmm_jax)
162 | 
163 |     gmm_sklearn = gmm_jax.to_sklearn(tol=tol, random_state=random_state)
164 | 
165 |     # This brings the sklearn model into the same state as the jax model
166 |     gmm_sklearn.fit(x)
167 | 
168 |     covar_str = gmm_jax.covariances.type.value
169 | 
170 |     expected = {"full": 9, "diag": 9}
171 |     assert_allclose(gmm_sklearn.n_iter_, expected[covar_str])
172 |     assert_allclose(gmm_sklearn.weights_, [0.2, 0.8], rtol=0.06)
173 | 
174 |     expected = {"full": [9], "diag": [8, 11]}
175 |     assert result_jax.n_iter in expected[covar_str]
176 |     assert_allclose(result_jax.gmm.weights_numpy, [0.2, 0.8], rtol=0.06)
177 | 
178 |     assert_allclose(gmm_sklearn.covariances_, TEST_COVARIANCES[covar_str], rtol=0.1)
179 |     assert_allclose(
180 |         result_jax.gmm.covariances.values_numpy, TEST_COVARIANCES[covar_str], rtol=0.1
181 |     )
182 | 
183 |     log_likelihood_jax = result_jax.gmm.log_prob(x[:10]).sum(axis=1)[:, 0, 0]
184 |     log_likelihood_sklearn = gmm_sklearn._estimate_weighted_log_prob(x[:10]).sum(axis=1)
185 | 
186 |     # note this is agreement in log-likelihood, not likelihood!
187 |     assert_allclose(log_likelihood_jax, log_likelihood_sklearn, rtol=1e-2)
188 | 
189 | 
190 | def test_sklearn_api(gmm_jax):
191 |     random_state = np.random.RandomState(829282)
192 |     x = gmm_jax.sample(key=jax.random.PRNGKey(0), n_samples=16_000)
193 | 
194 |     covar_str = gmm_jax.covariances.type.value
195 | 
196 |     gmm = GaussianMixtureSKLearn(
197 |         n_components=2,
198 |         covariance_type=covar_str,
199 |         tol=1e-6,
200 |         random_state=random_state,
201 |         weights_init=WEIGHTS,
202 |         means_init=MEANS,
203 |         precisions_init=TEST_PRECISIONS[covar_str],
204 |         max_iter=0,  # we just want to test the API, so we don't want to fit
205 |     )
206 |     gmm.fit(x)
207 | 
208 |     assert not gmm.converged_
209 | 
210 |     assert_allclose(gmm.weights_, WEIGHTS, rtol=0.06)
211 |     assert_allclose(gmm.covariances_, TEST_COVARIANCES[covar_str], rtol=0.1)
212 |     assert_allclose(gmm.means_, MEANS, atol=0.05)
213 | 
214 |     value = gmm.score_samples(x[:2])
215 |     expected = {"full": [-4.435944, -5.810338], "diag": [-5.678393, -7.025789]}
216 |     assert_allclose(value, expected[covar_str], rtol=1e-4)
217 | 
218 |     value = gmm.score(x[:2])
219 |     expected = {"full": -5.123141, "diag": -6.352091}
220 |     assert_allclose(value, expected[covar_str], rtol=1e-4)
221 | 
222 |     value = gmm.predict(x[:2])
223 |     expected = {"full": [1, 0], "diag": [1, 0]}
224 |     assert_allclose(value, expected[covar_str], rtol=1e-4)
225 | 
226 |     value = gmm.predict_proba(x[:2])
227 | 
228 |     expected = {
229 |         "full": [[1.188522e-07, 1.000000e00], [9.999938e-01, 6.106255e-06]],
230 |         "diag": [[3.933927e-06, 9.999962e-01], [9.733525e-01, 2.664736e-02]],
231 |     }
232 |     assert_allclose(value, expected[covar_str], atol=1e-3)
233 | 
234 |     value = gmm.aic(x[:2])
235 |     expected = {"full": 58.492565, "diag": 51.408363}
236 |     assert_allclose(value, expected[covar_str], rtol=1e-4)
237 | 
238 |     value = gmm.bic(x[:2])
239 |     expected = {"full": 33.66236, "diag": 34.419277}
240 |     assert_allclose(value, expected[covar_str], rtol=1e-4)
241 | 
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | skipsdist = true
3 | envlist = py39, py310, py311, py312, py313
4 | 
5 | [gh-actions]
6 | python =
7 |     3.9: py39
8 |     3.10: py310
9 |     3.11: py311
10 |     3.12: py312
11 |     3.13: py313
12 | 
13 | [testenv:check-style]
14 | description = check code style with ruff
15 | change_dir = .
16 | skip_install = true
17 | deps =
18 |     ruff
19 | commands =
20 |     ruff check . {posargs}
21 | 
22 | [testenv]
23 | passenv = PYTHON_VERSION
24 | allowlist_externals = uv
25 | commands =
26 |     uv sync --python {envpython}
27 |     uv run python -m pytest --doctest-modules tests --cov --cov-config=pyproject.toml --cov-report=xml
28 |     mypy
29 | 
--------------------------------------------------------------------------------
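A typical end-to-end use of the scikit-learn style API defined in gmmx/gmm.py
could look roughly like this (a hypothetical sketch with made-up data, not a
file from the repository):

    import numpy as np
    from gmmx import GaussianMixtureSKLearn

    x = np.random.normal(size=(1_000, 3)).astype(np.float32)  # toy data
    gmm = GaussianMixtureSKLearn(n_components=2, covariance_type="full")
    gmm.fit(x)               # k-means init followed by EM via EMFitter
    labels = gmm.predict(x)  # component index per sample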