├── .github └── workflows │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── HISTORY.rst ├── LICENSE ├── README.md ├── mypy.ini ├── noxfile.py ├── pyproject.toml ├── setup.py ├── src └── emcee_jax │ ├── __init__.py │ ├── _src │ ├── __init__.py │ ├── ensemble.py │ ├── host_callback.py │ ├── log_prob_fn.py │ ├── moves │ │ ├── __init__.py │ │ ├── core.py │ │ ├── slice.py │ │ └── util.py │ ├── ravel_util.py │ ├── sampler.py │ └── types.py │ ├── experimental │ ├── __init__.py │ └── moves │ │ ├── __init__.py │ │ └── hmc.py │ ├── host_callback.py │ └── moves.py └── tests ├── test_host_callback.py ├── test_moves.py ├── test_ravel_util.py └── test_sampler.py /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - "*" 9 | pull_request: 10 | 11 | jobs: 12 | tests: 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | fail-fast: false 16 | matrix: 17 | os: ["ubuntu-latest", "macos-latest"] 18 | python-version: ["3.9", "3.10"] 19 | session: 20 | - "test" 21 | include: 22 | - os: "ubuntu-latest" 23 | python-version: "3.10" 24 | session: "extras" 25 | - os: "ubuntu-latest" 26 | python-version: "3.10" 27 | session: "doctest" 28 | 29 | steps: 30 | - name: Checkout 31 | uses: actions/checkout@v2 32 | with: 33 | fetch-depth: 0 34 | submodules: true 35 | 36 | - name: Setup Python 37 | uses: actions/setup-python@v2 38 | with: 39 | python-version: ${{ matrix.python-version }} 40 | 41 | - name: Install dependencies 42 | run: | 43 | python -m pip install -U pip 44 | python -m pip install -U nox 45 | 46 | - name: Run tests 47 | run: | 48 | python -m nox --non-interactive --error-on-missing-interpreter \ 49 | --session ${{ matrix.session }}-${{ matrix.python-version }} 50 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | *~ 2 | *.swp 3 | .DS_Store 4 | 5 | *.pyc 6 | *.so 7 | venv* 8 | build/* 9 | 10 | dist 11 | emcee.egg-info 12 | MANIFEST 13 | docs.tar 14 | 15 | *.pdf 16 | 17 | .coverage 18 | .pytest_cache 19 | htmlcov 20 | **/*_version.py 21 | 22 | .tox 23 | env 24 | .eggs 25 | .coverage.* 26 | 27 | /*.ipynb 28 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.2.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | exclude_types: [json, binary] 8 | - repo: https://github.com/PyCQA/isort 9 | rev: "5.10.1" 10 | hooks: 11 | - id: isort 12 | additional_dependencies: [toml] 13 | exclude: docs/tutorials 14 | - repo: https://github.com/hadialqattan/pycln 15 | rev: "v1.3.1" 16 | hooks: 17 | - id: pycln 18 | additional_dependencies: ["click<8.1.0"] 19 | - repo: https://github.com/psf/black 20 | rev: "22.3.0" 21 | hooks: 22 | - id: black-jupyter 23 | - repo: https://github.com/kynan/nbstripout 24 | rev: "0.5.0" 25 | hooks: 26 | - id: nbstripout 27 | exclude: docs/benchmarks.ipynb 28 | - repo: https://github.com/pre-commit/mirrors-mypy 29 | rev: "v0.942" 30 | hooks: 31 | - id: mypy 32 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal 
appearance, race, religion, or sexual identity and orientation. 6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to creating a positive environment include: 10 | 11 | * Using welcoming and inclusive language 12 | * Being respectful of differing viewpoints and experiences 13 | * Gracefully accepting constructive criticism 14 | * Focusing on what is best for the community 15 | * Showing empathy towards other community members 16 | 17 | Examples of unacceptable behavior by participants include: 18 | 19 | * The use of sexualized language or imagery and unwelcome sexual attention or advances 20 | * Trolling, insulting/derogatory comments, and personal or political attacks 21 | * Public or private harassment 22 | * Publishing others' private information, such as a physical or electronic address, without explicit permission 23 | * Other conduct which could reasonably be considered inappropriate in a professional setting 24 | 25 | ## Our Responsibilities 26 | 27 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 28 | 29 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 30 | 31 | ## Scope 32 | 33 | This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 
Representation of a project may be further defined and clarified by project maintainers. 34 | 35 | ## Enforcement 36 | 37 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at foreman.mackey@gmail.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 38 | 39 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. 40 | 41 | ## Attribution 42 | 43 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version] 44 | 45 | [homepage]: http://contributor-covenant.org 46 | [version]: http://contributor-covenant.org/version/1/4/ 47 | -------------------------------------------------------------------------------- /HISTORY.rst: -------------------------------------------------------------------------------- 1 | .. 
:changelog: 2 | 3 | 3.1.2 (2022-05-10) 4 | ++++++++++++++++++ 5 | 6 | - Removed ``numpy`` from ``setup_requires`` `#427 `_ 7 | - Made the sampler state indexable `#425 `_ 8 | 9 | 10 | 3.1.1 (2021-08-23) 11 | ++++++++++++++++++ 12 | 13 | - Added support for a progress bar description `#401 `_ 14 | 15 | 16 | 3.1.0 (2021-06-25) 17 | ++++++++++++++++++ 18 | 19 | - Added preliminary support for named parameters `#386 `_ 20 | - Improved handling of blob dtypes `#363 `_ 21 | - Fixed various small bugs and documentation issues 22 | 23 | 24 | 3.0.2 (2019-11-15) 25 | ++++++++++++++++++ 26 | 27 | - Added tutorial for moves interface 28 | - Added information about contributions to documentation 29 | - Improved documentation for installation and testing 30 | - Fixed dtype issues and instability in linear dependence test 31 | - Final release for `JOSS `_ submission 32 | 33 | 34 | 3.0.1 (2019-10-28) 35 | ++++++++++++++++++ 36 | 37 | - Added support for long double dtypes 38 | - Prepared manuscript to submit to `JOSS `_ 39 | - Improved packaging and release infrastructure 40 | - Fixed bug in initial linear dependence test 41 | 42 | 43 | 3.0.0 (2019-09-30) 44 | ++++++++++++++++++ 45 | 46 | - Added progress bars using `tqdm `_. 47 | - Added HDF5 backend using `h5py `_. 48 | - Added new ``Move`` interface for more flexible specification of proposals. 49 | - Improved autocorrelation time estimation algorithm. 50 | - Switched documentation to using Jupyter notebooks for tutorials. 51 | - More details can be found `on the docs `_. 52 | 53 | 2.2.0 (2016-07-12) 54 | ++++++++++++++++++ 55 | 56 | - Improved autocorrelation time computation. 57 | - Numpy compatibility issues. 58 | - Fixed deprecated integer division behavior in PTSampler. 59 | 60 | 61 | 2.1.0 (2014-05-22) 62 | ++++++++++++++++++ 63 | 64 | - Removing dependence on ``acor`` extension. 65 | - Added arguments to ``PTSampler`` function. 66 | - Added automatic load-balancing for MPI runs. 
67 | - Added custom load-balancing for MPI and multiprocessing. 68 | - New default multiprocessing pool that supports ``^C``. 69 | 70 | 71 | 2.0.0 (2013-11-17) 72 | ++++++++++++++++++ 73 | 74 | - **Re-licensed under the MIT license!** 75 | - Clearer less verbose documentation. 76 | - Added checks for parameters becoming infinite or NaN. 77 | - Added checks for log-probability becoming NaN. 78 | - Improved parallelization and various other tweaks in ``PTSampler``. 79 | 80 | 81 | 1.2.0 (2013-01-30) 82 | ++++++++++++++++++ 83 | 84 | - Added a parallel tempering sampler ``PTSampler``. 85 | - Added instructions and utilities for using ``emcee`` with ``MPI``. 86 | - Added ``flatlnprobability`` property to the ``EnsembleSampler`` object 87 | to be consistent with the ``flatchain`` property. 88 | - Updated document for publication in PASP. 89 | - Various bug fixes. 90 | 91 | 92 | 1.1.3 (2012-11-22) 93 | ++++++++++++++++++ 94 | 95 | - Made the packaging system more robust even when numpy is not installed. 96 | 97 | 98 | 1.1.2 (2012-08-06) 99 | ++++++++++++++++++ 100 | 101 | - Another bug fix related to metadata blobs: the shape of the final ``blobs`` 102 | object was incorrect and all of the entries would generally be identical 103 | because we needed to copy the list that was appended at each step. Thanks 104 | goes to Jacqueline Chen (MIT) for catching this problem. 105 | 106 | 107 | 1.1.1 (2012-07-30) 108 | ++++++++++++++++++ 109 | 110 | - Fixed bug related to metadata blobs. The sample function was yielding 111 | the ``blobs`` object even when it wasn't expected. 112 | 113 | 114 | 1.1.0 (2012-07-28) 115 | ++++++++++++++++++ 116 | 117 | - Allow the ``lnprobfn`` to return arbitrary "blobs" of data as well as the 118 | log-probability. 119 | - Python 3 compatible (thanks Alex Conley)! 120 | - Various speed ups and clean ups in the core code base. 121 | - New documentation with better examples and more discussion. 
122 | 123 | 124 | 1.0.1 (2012-03-31) 125 | ++++++++++++++++++ 126 | 127 | - Fixed transpose bug in the usage of ``acor`` in ``EnsembleSampler``. 128 | 129 | 130 | 1.0.0 (2012-02-15) 131 | ++++++++++++++++++ 132 | 133 | - Initial release. 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2022 Daniel Foreman-Mackey & contributors. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # emcee-jax 2 | 3 | An experiment. 4 | 5 | A simple example: 6 | 7 | ```python 8 | >>> import jax 9 | >>> import emcee_jax 10 | >>> 11 | >>> def log_prob(theta, a1=100.0, a2=20.0): 12 | ... 
x1, x2 = theta 13 | ... return -(a1 * (x2 - x1**2)**2 + (1 - x1)**2) / a2 14 | ... 15 | >>> num_walkers, num_steps = 100, 1000 16 | >>> key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3) 17 | >>> coords = jax.random.normal(key1, shape=(num_walkers, 2)) 18 | >>> sampler = emcee_jax.EnsembleSampler(log_prob) 19 | >>> state = sampler.init(key2, coords) 20 | >>> trace = sampler.sample(key3, state, num_steps) 21 | 22 | ``` 23 | 24 | An example using PyTrees as input coordinates: 25 | 26 | ```python 27 | >>> import jax 28 | >>> import emcee_jax 29 | >>> 30 | >>> def log_prob(theta, a1=100.0, a2=20.0): 31 | ... x1, x2 = theta["x"], theta["y"] 32 | ... return -(a1 * (x2 - x1**2)**2 + (1 - x1)**2) / a2 33 | ... 34 | >>> num_walkers, num_steps = 100, 1000 35 | >>> key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(0), 4) 36 | >>> coords = { 37 | ... "x": jax.random.normal(key1, shape=(num_walkers,)), 38 | ... "y": jax.random.normal(key2, shape=(num_walkers,)), 39 | ... } 40 | >>> sampler = emcee_jax.EnsembleSampler(log_prob) 41 | >>> state = sampler.init(key3, coords) 42 | >>> trace = sampler.sample(key4, state, num_steps) 43 | 44 | ``` 45 | 46 | An example that includes deterministics: 47 | 48 | ```python 49 | >>> import jax 50 | >>> import emcee_jax 51 | >>> 52 | >>> def log_prob(theta, a1=100.0, a2=20.0): 53 | ... x1, x2 = theta 54 | ... some_number = x1 + jax.numpy.sin(x2) 55 | ... log_prob_value = -(a1 * (x2 - x1**2)**2 + (1 - x1)**2) / a2 56 | ... 57 | ... # This second argument can be any PyTree 58 | ... return log_prob_value, {"some_number": some_number} 59 | ... 
60 | >>> num_walkers, num_steps = 100, 1000 61 | >>> key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3) 62 | >>> coords = jax.random.normal(key1, shape=(num_walkers, 2)) 63 | >>> sampler = emcee_jax.EnsembleSampler(log_prob) 64 | >>> state = sampler.init(key2, coords) 65 | >>> trace = sampler.sample(key3, state, num_steps) 66 | 67 | ``` 68 | 69 | You can even use pure-Python log probability functions: 70 | 71 | ```python 72 | >>> import jax 73 | >>> import numpy as np 74 | >>> import emcee_jax 75 | >>> from emcee_jax.host_callback import wrap_python_log_prob_fn 76 | >>> 77 | >>> # A log prob function that uses numpy, not jax.numpy inside 78 | >>> @wrap_python_log_prob_fn 79 | ... def log_prob(theta, a1=100.0, a2=20.0): 80 | ... x1, x2 = theta 81 | ... return -(a1 * np.square(x2 - x1**2) + np.square(1 - x1)) / a2 82 | ... 83 | >>> num_walkers, num_steps = 100, 1000 84 | >>> key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3) 85 | >>> coords = jax.random.normal(key1, shape=(num_walkers, 2)) 86 | >>> sampler = emcee_jax.EnsembleSampler(log_prob) 87 | >>> state = sampler.init(key2, coords) 88 | >>> trace = sampler.sample(key3, state, num_steps) 89 | 90 | ``` 91 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | follow_imports = normal 3 | ignore_missing_imports = True 4 | check_untyped_defs = True 5 | disallow_any_generics = True 6 | disallow_incomplete_defs = True 7 | disallow_untyped_defs = True 8 | no_implicit_optional = True 9 | strict_optional = True 10 | warn_no_return = True 11 | warn_redundant_casts = True 12 | warn_unreachable = True 13 | warn_unused_ignores = True 14 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | # mypy: ignore-errors 2 | 3 | import nox 4 | 5 | 
ALL_PYTHONS = ["3.8", "3.9", "3.10"] 6 | TEST_CMD = ["pytest", "-v"] 7 | 8 | 9 | def _session_run(session, path): 10 | if len(session.posargs): 11 | session.run(*TEST_CMD, *session.posargs) 12 | else: 13 | session.run(*TEST_CMD, path, *session.posargs) 14 | 15 | 16 | @nox.session(python=ALL_PYTHONS) 17 | def test(session): 18 | session.install(".[test]") 19 | _session_run(session, "tests") 20 | 21 | 22 | @nox.session(python=ALL_PYTHONS) 23 | def extras(session): 24 | session.install(".[test,extras]") 25 | _session_run(session, "tests") 26 | 27 | 28 | @nox.session 29 | def lint(session): 30 | session.install("pre-commit") 31 | session.run("pre-commit", "run", "--all-files", *session.posargs) 32 | 33 | 34 | @nox.session(python=ALL_PYTHONS) 35 | def doctest(session): 36 | session.install(".") 37 | session.run("python", "-m", "doctest", "README.md", *session.posargs) 38 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=40.6.0", "wheel", "setuptools_scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.black] 6 | line-length = 79 7 | 8 | [tool.isort] 9 | skip_glob = [] 10 | line_length = 79 11 | multi_line_output = 3 12 | include_trailing_comma = true 13 | force_grid_wrap = 0 14 | use_parentheses = true 15 | known_first_party = ["emcee_jax"] 16 | combine_as_imports = true 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Inspired by: 4 | # https://hynek.me/articles/sharing-your-labor-of-love-pypi-quick-and-dirty/ 5 | 6 | import codecs 7 | import os 8 | import re 9 | 10 | from setuptools import find_packages, setup 11 | 12 | # PROJECT SPECIFIC 13 | 14 | NAME = "emcee_jax" 15 | PACKAGES = find_packages(where="src") 16 | 
META_PATH = os.path.join("src", "emcee_jax", "__init__.py") 17 | CLASSIFIERS = [ 18 | "Development Status :: 5 - Production/Stable", 19 | "Intended Audience :: Developers", 20 | "Intended Audience :: Science/Research", 21 | "License :: OSI Approved :: MIT License", 22 | "Operating System :: OS Independent", 23 | "Programming Language :: Python", 24 | ] 25 | INSTALL_REQUIRES = ["jax", "jaxlib", "jax_dataclasses"] 26 | EXTRA_REQUIRE = { 27 | "test": ["pytest"], 28 | "extras": ["arviz"], 29 | } 30 | 31 | # END PROJECT SPECIFIC 32 | 33 | 34 | HERE = os.path.dirname(os.path.realpath(__file__)) 35 | 36 | 37 | def read(*parts: str) -> str: 38 | with codecs.open(os.path.join(HERE, *parts), "rb", "utf-8") as f: 39 | return f.read() 40 | 41 | 42 | def find_meta(meta: str, meta_file: str = read(META_PATH)) -> str: 43 | meta_match = re.search( 44 | r"^__{meta}__ = ['\"]([^'\"]*)['\"]".format(meta=meta), meta_file, re.M 45 | ) 46 | if meta_match: 47 | return meta_match.group(1) 48 | raise RuntimeError("Unable to find __{meta}__ string.".format(meta=meta)) 49 | 50 | 51 | if __name__ == "__main__": 52 | setup( 53 | name=NAME, 54 | use_scm_version={ 55 | "write_to": os.path.join( 56 | "src", NAME, "{0}_version.py".format(NAME) 57 | ), 58 | "write_to_template": '__version__ = "{version}"\n', 59 | }, 60 | author=find_meta("author"), 61 | author_email=find_meta("email"), 62 | maintainer=find_meta("author"), 63 | maintainer_email=find_meta("email"), 64 | url=find_meta("uri"), 65 | project_urls={ 66 | "Source": "https://github.com/dfm/emcee-jax", 67 | }, 68 | license=find_meta("license"), 69 | description=find_meta("description"), 70 | long_description=read("README.md"), 71 | long_description_content_type="text/x-rst", 72 | packages=PACKAGES, 73 | package_dir={"": "src"}, 74 | include_package_data=True, 75 | install_requires=INSTALL_REQUIRES, 76 | extras_require=EXTRA_REQUIRE, 77 | classifiers=CLASSIFIERS, 78 | zip_safe=False, 79 | options={"bdist_wheel": {"universal": "1"}}, 80 | ) 81 
| -------------------------------------------------------------------------------- /src/emcee_jax/__init__.py: -------------------------------------------------------------------------------- 1 | # nopycln: file 2 | 3 | __bibtex__ = """ 4 | @article{emcee, 5 | author = {{Foreman-Mackey}, D. and {Hogg}, D.~W. and {Lang}, D. and {Goodman}, J.}, 6 | title = {emcee: The MCMC Hammer}, 7 | journal = {PASP}, 8 | year = 2013, 9 | volume = 125, 10 | pages = {306-312}, 11 | eprint = {1202.3665}, 12 | doi = {10.1086/670067} 13 | } 14 | """ 15 | __uri__ = "https://emcee.readthedocs.io" 16 | __author__ = "Daniel Foreman-Mackey" 17 | __email__ = "foreman.mackey@gmail.com" 18 | __license__ = "MIT" 19 | __description__ = "The Python ensemble sampling toolkit for MCMC" 20 | 21 | from emcee_jax import host_callback as host_callback, moves as moves 22 | from emcee_jax._src.sampler import EnsembleSampler as EnsembleSampler 23 | from emcee_jax.emcee_jax_version import __version__ as __version__ 24 | -------------------------------------------------------------------------------- /src/emcee_jax/_src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dfm/emcee-jax/c52f42711b351a69dbb187f99691fe0c8ad1bd6c/src/emcee_jax/_src/__init__.py -------------------------------------------------------------------------------- /src/emcee_jax/_src/ensemble.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, Tuple, Union 2 | 3 | import jax 4 | import jax.linear_util as lu 5 | import numpy as np 6 | from jax.tree_util import tree_leaves 7 | 8 | from emcee_jax._src.types import Array, PyTree 9 | 10 | 11 | class Ensemble(NamedTuple): 12 | coordinates: PyTree 13 | deterministics: PyTree 14 | log_probability: Array 15 | 16 | @classmethod 17 | def init( 18 | cls, log_prob_fn: lu.WrappedFun, ensemble: Union["Ensemble", PyTree] 19 | ) -> "Ensemble": 20 | if 
isinstance(ensemble, cls): 21 | return ensemble 22 | fn = jax.vmap(log_prob_fn.call_wrapped) 23 | log_probability, deterministics = fn(ensemble) 24 | return cls( 25 | coordinates=ensemble, 26 | deterministics=deterministics, 27 | log_probability=log_probability, 28 | ) 29 | 30 | 31 | def get_ensemble_shape(ensemble: PyTree) -> Tuple[int, int]: 32 | leaves = tree_leaves(ensemble) 33 | if not len(leaves): 34 | raise ValueError("The ensemble is empty") 35 | if len(leaves) == 1 and leaves[0].ndim <= 1: 36 | raise ValueError( 37 | "An ensemble must have at least 2 dimensions; " 38 | "did you provide just a single walker coordinate?" 39 | ) 40 | leading, rest = zip( 41 | *((x.shape[0], int(np.prod(x.shape[1:]))) for x in leaves) 42 | ) 43 | if any(s != leading[0] for s in leading): 44 | raise ValueError( 45 | f"All leaves must have the same leading dimension; got {leading}" 46 | ) 47 | return leading[0], sum(rest) 48 | -------------------------------------------------------------------------------- /src/emcee_jax/_src/host_callback.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from typing import Any, Callable, List, Tuple 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | from jax._src import dtypes 8 | from jax.custom_batching import custom_vmap 9 | from jax.experimental import host_callback 10 | from jax.tree_util import tree_flatten 11 | 12 | from emcee_jax._src.log_prob_fn import LogProbFn 13 | from emcee_jax._src.ravel_util import ravel_ensemble 14 | from emcee_jax._src.types import Array, PyTree 15 | 16 | 17 | def wrap_python_log_prob_fn( 18 | python_log_prob_fn: Callable[..., Array] 19 | ) -> LogProbFn: 20 | @custom_vmap 21 | @wraps(python_log_prob_fn) 22 | def log_prob_fn(params: Array) -> Array: 23 | dtype = _tree_dtype(params) 24 | return host_callback.call( 25 | python_log_prob_fn, 26 | params, 27 | result_shape=jax.ShapeDtypeStruct((), dtype), 28 | ) 29 | 30 | 
@log_prob_fn.def_vmap 31 | def _( 32 | axis_size: int, in_batched: List[bool], params: Array 33 | ) -> Tuple[Array, bool]: 34 | del axis_size, in_batched 35 | 36 | if _arraylike(params): 37 | flat_params = params 38 | eval_one = python_log_prob_fn 39 | else: 40 | flat_params, unravel = ravel_ensemble(params) 41 | eval_one = lambda x: python_log_prob_fn(unravel(x)) 42 | 43 | result_shape = jax.ShapeDtypeStruct( 44 | (flat_params.shape[0],), flat_params.dtype 45 | ) 46 | return ( 47 | host_callback.call( 48 | lambda y: np.stack([eval_one(x) for x in y]), 49 | flat_params, 50 | result_shape=result_shape, 51 | ), 52 | True, 53 | ) 54 | 55 | return log_prob_fn 56 | 57 | 58 | def _tree_dtype(tree: PyTree) -> Any: 59 | leaves, _ = tree_flatten(tree) 60 | from_dtypes = [dtypes.dtype(l) for l in leaves] 61 | return dtypes.result_type(*from_dtypes) 62 | 63 | 64 | def _arraylike(x: Array) -> bool: 65 | return ( 66 | isinstance(x, np.ndarray) 67 | or isinstance(x, jnp.ndarray) 68 | or hasattr(x, "__jax_array__") 69 | or np.isscalar(x) 70 | ) 71 | -------------------------------------------------------------------------------- /src/emcee_jax/_src/log_prob_fn.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Generator, Tuple, Union 2 | 3 | import jax 4 | import jax.linear_util as lu 5 | import jax.numpy as jnp 6 | 7 | from emcee_jax._src.types import Array, PyTree 8 | 9 | LogProbFn = Callable[..., Union[Array, Tuple[Array, PyTree]]] 10 | 11 | 12 | def wrap_log_prob_fn( 13 | log_prob_fn: LogProbFn, *log_prob_args: Any, **log_prob_kwargs: Any 14 | ) -> lu.WrappedFun: 15 | wrapped_log_prob_fn = lu.wrap_init(log_prob_fn) 16 | return handle_deterministics_and_nans( 17 | wrapped_log_prob_fn, *log_prob_args, **log_prob_kwargs 18 | ) 19 | 20 | 21 | @lu.transformation 22 | def handle_deterministics_and_nans( 23 | *args: Any, **kwargs: Any 24 | ) -> Generator[Tuple[Any, Any], Union[Any, Tuple[Any, Any]], None]: 25 | 
result = yield args, kwargs 26 | 27 | # Unwrap deterministics if they are provided or default to None 28 | if isinstance(result, tuple): 29 | log_prob, *deterministics = result 30 | if len(deterministics) == 1: 31 | deterministics = deterministics[0] 32 | else: 33 | log_prob = result 34 | deterministics = None 35 | 36 | if log_prob is None: 37 | raise ValueError( 38 | "A log probability function must return a scalar value, got None" 39 | ) 40 | 41 | try: 42 | log_prob = jnp.reshape(log_prob, ()) 43 | except TypeError: 44 | raise ValueError( 45 | "A log probability function must return a scalar; " 46 | f"computed shape is '{log_prob.shape}', expected '()'" 47 | ) 48 | 49 | # Handle the case where the computed log probability is NaN by replacing it 50 | # with negative infinity so that it gets rejected 51 | log_prob = jax.lax.cond( 52 | jnp.isnan(log_prob), lambda: -jnp.inf, lambda: log_prob 53 | ) 54 | 55 | yield log_prob, deterministics 56 | -------------------------------------------------------------------------------- /src/emcee_jax/_src/moves/__init__.py: -------------------------------------------------------------------------------- 1 | from emcee_jax._src.moves.core import ( 2 | DiffEvol as DiffEvol, 3 | Move as Move, 4 | Stretch as Stretch, 5 | compose as compose, 6 | ) 7 | from emcee_jax._src.moves.slice import ( 8 | DiffEvolSlice as DiffEvolSlice, 9 | Slice as Slice, 10 | ) 11 | from emcee_jax._src.moves.util import apply_accept as apply_accept 12 | -------------------------------------------------------------------------------- /src/emcee_jax/_src/moves/core.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from functools import partial 3 | from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple 4 | 5 | import jax 6 | import jax.linear_util as lu 7 | import jax.numpy as jnp 8 | import numpy as np 9 | from jax import random 10 | from jax.tree_util import tree_map 11 | 
@pytree_dataclass
class Move:
    """Abstract base class for ensemble moves.

    A move knows how to (optionally) set up per-move state via ``init`` and
    how to advance a whole ensemble by one step via ``step``. Registration
    as a ``pytree_dataclass`` makes every move (and subclass) a JAX pytree,
    so moves can be passed through ``jit``/``scan`` boundaries.
    """

    if TYPE_CHECKING:

        # Only seen by the type checker: gives subclasses a permissive
        # constructor signature so mypy does not complain about dataclass
        # fields added in subclasses.
        def __init__(self, *args: Any, **kwargs: Any):
            super().__init__(*args, **kwargs)

    @classmethod
    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Automatically register every subclass as a pytree dataclass so
        # authors of new moves don't need to remember the decorator.
        super().__init_subclass__(**kwargs)
        pytree_dataclass(cls)

    def init(
        self,
        random_key: random.KeyArray,
        ensemble: Ensemble,
    ) -> Tuple[MoveState, Extras]:
        """Build this move's initial state and per-walker extras.

        The default move is stateless: both the move state and the extras
        are ``None``.
        """
        del random_key, ensemble
        return None, None

    def step(
        self,
        log_prob_fn: lu.WrappedFun,
        random_key: random.KeyArray,
        state: MoveState,
        ensemble: Ensemble,
        extras: Extras,
        *,
        tune: bool = False,
    ) -> Tuple[Tuple[MoveState, Ensemble, Extras], SampleStats]:
        """Advance the ensemble by one step; must be overridden.

        Returns the updated ``(state, ensemble, extras)`` triple plus a
        dictionary of sample statistics.
        """
        del log_prob_fn, random_key, state, ensemble, extras, tune
        raise NotImplementedError
def compose(*moves: Move, **named_moves: Move) -> Composed:
    """Combine several moves into a single sequential ``Composed`` move.

    Positional moves are labeled ``"<ClassName>_<index>"``; keyword moves
    keep their given names. The resulting move applies each sub-move in
    order on every step.
    """
    labeled = [
        (f"{move.__class__.__name__}_{index}", move)
        for index, move in enumerate(moves)
    ]
    labeled.extend(named_moves.items())
    return Composed(moves=labeled)
move_state, ensemble, extras = state 151 | key1, key2 = random.split(random_key) 152 | nwalkers, _ = get_ensemble_shape(ensemble) 153 | mid = nwalkers // 2 154 | 155 | ens1 = tree_map(lambda x: x[:mid], ensemble) 156 | ext1 = tree_map(lambda x: x[:mid], extras) 157 | ens2 = tree_map(lambda x: x[mid:], ensemble) 158 | ext2 = tree_map(lambda x: x[mid:], extras) 159 | 160 | half_step = partial(self.propose, log_prob_fn, tune=tune) 161 | state, ens1, ext1, stats1 = half_step( 162 | state, key1, ens1, ext1, ens2, ext2 163 | ) 164 | state, ens2, ext2, stats2 = half_step( 165 | state, key2, ens2, ext2, ens1, ext1 166 | ) 167 | stats = tree_map(lambda *x: jnp.concatenate(x, axis=0), stats1, stats2) 168 | ensemble = tree_map(lambda *x: jnp.concatenate(x, axis=0), ens1, ens2) 169 | extras = tree_map(lambda *x: jnp.concatenate(x, axis=0), ext1, ext2) 170 | return (state, ensemble, extras), stats 171 | 172 | 173 | class SimpleRedBlue(RedBlue): 174 | def propose_simple( 175 | self, key: random.KeyArray, s: PyTree, c: PyTree 176 | ) -> Tuple[PyTree, Array]: 177 | del key, s, c 178 | raise NotImplementedError 179 | 180 | def propose( 181 | self, 182 | log_prob_fn: lu.WrappedFun, 183 | state: MoveState, 184 | key: random.KeyArray, 185 | target_walkers: Ensemble, 186 | target_extras: Extras, 187 | compl_walkers: Ensemble, 188 | compl_extras: Extras, 189 | *, 190 | tune: bool, 191 | ) -> Tuple[MoveState, Ensemble, PyTree, SampleStats]: 192 | del compl_extras, tune 193 | key1, key2 = random.split(key) 194 | q, f = self.propose_simple( 195 | key1, target_walkers.coordinates, compl_walkers.coordinates 196 | ) 197 | nlp, ndet = jax.vmap(log_prob_fn.call_wrapped)(q) 198 | updated = target_walkers._replace( 199 | coordinates=q, 200 | deterministics=ndet, 201 | log_probability=nlp, 202 | ) 203 | diff = nlp - target_walkers.log_probability + f 204 | accept_prob = jnp.minimum(jnp.exp(diff), 1) 205 | accept = accept_prob > random.uniform(key2, shape=diff.shape) 206 | updated = 
class Stretch(SimpleRedBlue):
    """The affine-invariant "stretch" move of Goodman & Weare (2010).

    Each target walker is moved along the line joining it to a randomly
    chosen complementary walker, scaled by a random factor ``z``. The
    tuning parameter ``a`` controls the width of the proposal; ``a = 2``
    is the usual default.
    """

    a: Array = 2.0

    def propose_simple(
        self, key: random.KeyArray, s: PyTree, c: PyTree
    ) -> Tuple[PyTree, Array]:
        num_target, ndim = get_ensemble_shape(s)
        num_compl, _ = get_ensemble_shape(c)
        stretch_key, choice_key = random.split(key)

        # Inverse-transform sample z from g(z) on [1/a, a]
        u = random.uniform(stretch_key, shape=(num_target,))
        z = jnp.square(u * (self.a - 1) + 1) / self.a

        # One complementary helper walker (with replacement) per target
        helper = random.choice(choice_key, num_compl, shape=(num_target,))
        stretcher = jax.vmap(lambda xs, xc, zz: xc + zz * (xs - xc))
        q = tree_map(lambda xs, xc: stretcher(xs, xc[helper], z), s, c)

        # (ndim - 1) * log(z) is the log-Hastings factor for this proposal
        return q, (ndim - 1) * jnp.log(z)
249 | choose2 = partial(random.choice, a=nc, replace=False, shape=(2,)) 250 | inds = jax.vmap(choose2)(random.split(key1, ns)) 251 | norm = random.normal(key2, shape=(ns,)) 252 | 253 | @jax.vmap 254 | def update(s: Array, c: Array, norm: Array) -> Array: 255 | delta = c[1] - c[0] 256 | delta = (gamma0 + self.sigma * norm) * delta 257 | return s + delta 258 | 259 | return tree_map( 260 | lambda s, c: update(s, c[inds], norm), s, c 261 | ), jnp.zeros(ns) 262 | -------------------------------------------------------------------------------- /src/emcee_jax/_src/moves/slice.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple 3 | 4 | import jax 5 | import jax.linear_util as lu 6 | import jax.numpy as jnp 7 | from jax import random 8 | from jax.tree_util import tree_map 9 | 10 | from emcee_jax._src.ensemble import Ensemble, get_ensemble_shape 11 | from emcee_jax._src.moves.core import MoveState, RedBlue 12 | from emcee_jax._src.moves.util import apply_accept 13 | from emcee_jax._src.types import Array, Extras, PyTree, SampleStats 14 | 15 | 16 | class Slice(RedBlue): 17 | max_doubles: int = 10_000 18 | max_shrinks: int = 100 19 | initial_step_size: float = 1.0 20 | tune_max_doubles: Optional[int] = None 21 | tune_max_shrinks: Optional[int] = None 22 | 23 | def init( 24 | self, 25 | random_key: random.KeyArray, 26 | ensemble: Ensemble, 27 | ) -> Tuple[MoveState, Extras]: 28 | del random_key, ensemble 29 | return {"step_size": self.initial_step_size}, None 30 | 31 | def get_directions( 32 | self, 33 | step_size: Array, 34 | count: int, 35 | key: random.KeyArray, 36 | complementary: PyTree, 37 | ) -> PyTree: 38 | del step_size, count, key, complementary 39 | raise NotImplementedError 40 | 41 | def propose( 42 | self, 43 | log_prob_fn: lu.WrappedFun, 44 | state: MoveState, 45 | key: random.KeyArray, 46 | target_walkers: Ensemble, 47 | 
target_extras: Extras, 48 | compl_walkers: Ensemble, 49 | compl_extras: Extras, 50 | *, 51 | tune: bool, 52 | ) -> Tuple[MoveState, Ensemble, PyTree, SampleStats]: 53 | del compl_extras 54 | if TYPE_CHECKING: 55 | assert isinstance(state, dict) 56 | num_target, _ = get_ensemble_shape(target_walkers.coordinates) 57 | key0, *keys = random.split(key, num_target + 1) 58 | directions = self.get_directions( 59 | state["step_size"], num_target, key0, compl_walkers.coordinates 60 | ) 61 | 62 | if tune and self.tune_max_doubles is not None: 63 | max_doubles = self.tune_max_doubles 64 | else: 65 | max_doubles = self.max_doubles 66 | 67 | if tune and self.tune_max_shrinks is not None: 68 | max_shrinks = self.tune_max_shrinks 69 | else: 70 | max_shrinks = self.max_shrinks 71 | 72 | sample_func = partial( 73 | slice_sample, max_doubles, max_shrinks, log_prob_fn 74 | ) 75 | updated, stats = jax.vmap(sample_func)( 76 | jnp.asarray(keys), target_walkers, directions 77 | ) 78 | accept = jnp.logical_and(stats["bounds_ok"], stats["sample_ok"]) 79 | stats["accept"] = accept 80 | stats["accept_prob"] = jnp.ones_like(updated.log_probability) 81 | stats["step_size"] = jnp.full_like( 82 | updated.log_probability, state["step_size"] 83 | ) 84 | updated = apply_accept(accept, target_walkers, updated) 85 | 86 | if tune: 87 | num_doubles = 0.5 * jnp.mean( 88 | stats["num_doubles_left"] + stats["num_doubles_right"] 89 | ) 90 | num_shrinks = jnp.mean(stats["num_shrinks"]) 91 | factor = 2 * num_doubles / (num_doubles + num_shrinks) 92 | next_state = dict(state, step_size=state["step_size"] * factor) 93 | else: 94 | next_state = state 95 | 96 | return next_state, updated, target_extras, stats 97 | 98 | 99 | class DiffEvolSlice(Slice): 100 | def get_directions( 101 | self, 102 | step_size: Array, 103 | count: int, 104 | key: random.KeyArray, 105 | complementary: PyTree, 106 | ) -> PyTree: 107 | # See the ``DiffEvol`` move for an explanation of the following 108 | nc, _ = 
def slice_sample(
    max_doubles: int,
    max_shrinks: int,
    log_prob_fn: lu.WrappedFun,
    random_key: random.KeyArray,
    initial: Ensemble,
    dx: Array,
) -> Tuple[Ensemble, Dict[str, Any]]:
    """Draw one slice-sampling update for a single walker.

    A slice level is sampled below the current log probability, a bracket
    containing the slice is found by stepping out (doubling) along ``dx``,
    and a new coordinate is then drawn by shrinking that bracket until a
    point inside the slice is found.
    """
    key_level, key_bounds, key_shrink = random.split(random_key, 3)

    # level = log p(x0) + log u with u ~ Uniform(0, 1); an exponential
    # draw is exactly -log u
    level = initial.log_probability - random.exponential(key_level)

    # Bracket the slice by doubling in both directions from the start point
    bounds = _find_bounds_by_doubling_while_loop(
        max_doubles, log_prob_fn, level, key_bounds, initial.coordinates, dx
    )
    left, right, num_doubles_left, num_doubles_right, bounds_ok = bounds

    # Sample uniformly inside the bracket, shrinking toward x0 on rejection
    final, num_shrinks, sample_ok = _sample_by_shrinking_while_loop(
        max_shrinks, log_prob_fn, level, key_shrink, initial, left, right
    )

    stats = {
        "level": level,
        "num_doubles_left": num_doubles_left,
        "num_doubles_right": num_doubles_right,
        "bounds_ok": bounds_ok,
        "num_shrinks": num_shrinks,
        "sample_ok": sample_ok,
    }
    return final, stats
jnp.any(jnp.logical_not(args[1])) 170 | ) 171 | r = random.uniform(key) 172 | init = tree_map(lambda x0, dx: x0 - r * dx, x0, dx) 173 | num_left, left_ok, left = jax.lax.while_loop( 174 | cond, partial(doubling, -1), (0, False, init) 175 | ) 176 | init = tree_map(lambda x0, dx: x0 + (1 - r) * dx, x0, dx) 177 | num_right, right_ok, right = jax.lax.while_loop( 178 | cond, partial(doubling, 1), (0, False, init) 179 | ) 180 | 181 | return left, right, num_left, num_right, jnp.logical_and(left_ok, right_ok) 182 | 183 | 184 | def _sample_by_shrinking_while_loop( 185 | max_shrinks: int, 186 | log_prob_fn: lu.WrappedFun, 187 | level: Array, 188 | key: random.KeyArray, 189 | initial: Ensemble, 190 | left: PyTree, 191 | right: PyTree, 192 | ) -> Tuple[Ensemble, Array, Array]: 193 | def shrinking( 194 | args: Tuple[Array, Array, Array, Array, random.KeyArray, Ensemble] 195 | ) -> Tuple[Array, Array, Array, Array, random.KeyArray, Ensemble]: 196 | count, found, left, right, key, state = args 197 | key, next_key = random.split(key) 198 | u = random.uniform(key) 199 | x = tree_map( 200 | lambda left, right: (1 - u) * left + u * right, left, right 201 | ) 202 | log_prob, deterministics = log_prob_fn.call_wrapped(x) 203 | next_state = state._replace( 204 | coordinates=x, 205 | deterministics=deterministics, 206 | log_probability=log_prob, 207 | ) 208 | next_left = tree_map( 209 | lambda x, x0, left: jnp.where(jnp.less(x, x0), x, left), 210 | x, 211 | x0, 212 | left, 213 | ) 214 | next_right = tree_map( 215 | lambda x, x0, right: jnp.where(jnp.greater_equal(x, x0), x, right), 216 | x, 217 | x0, 218 | right, 219 | ) 220 | accept = jnp.greater_equal(log_prob, level) 221 | return ( 222 | count + 1, 223 | found | accept, 224 | next_left, 225 | next_right, 226 | next_key, 227 | next_state, 228 | ) 229 | 230 | x0 = initial.coordinates 231 | cond = lambda args: jnp.logical_and( 232 | jnp.less(args[0], max_shrinks), jnp.any(jnp.logical_not(args[1])) 233 | ) 234 | count, ok, *_, state = 
def apply_accept(accept: "Array", target: "T", other: "T") -> "T":
    """Walker-wise select between two pytrees with matching structure.

    For every leaf, row ``i`` of the result is taken from ``other`` where
    ``accept[i]`` is true and from ``target`` otherwise.
    """

    def _choose(current: "Array", proposed: "Array") -> "Array":
        picker = jax.vmap(lambda flag, old, new: jnp.where(flag, new, old))
        return picker(accept, current, proposed)

    return tree_map(_choose, target, other)
5 | """ 6 | 7 | import warnings 8 | from typing import Callable, List, Tuple 9 | 10 | import jax 11 | import jax.numpy as jnp 12 | import numpy as np 13 | from jax import lax 14 | from jax._src import dtypes 15 | from jax._src.util import safe_zip 16 | from jax.tree_util import tree_flatten, tree_unflatten 17 | 18 | from emcee_jax._src.types import Array, PyTree 19 | 20 | UnravelFn = Callable[[Array], PyTree] 21 | 22 | zip = safe_zip 23 | 24 | 25 | def ravel_ensemble(coords: PyTree) -> Tuple[Array, UnravelFn]: 26 | leaves, treedef = tree_flatten(coords) 27 | flat, unravel_inner = _ravel_inner(leaves) 28 | unravel_one = lambda flat: tree_unflatten(treedef, unravel_inner(flat)) 29 | return flat, unravel_one 30 | 31 | 32 | def _ravel_inner(lst: List[Array]) -> Tuple[Array, UnravelFn]: 33 | if not lst: 34 | return jnp.array([], jnp.float32), lambda _: [] 35 | from_dtypes = [dtypes.dtype(l) for l in lst] 36 | to_dtype = dtypes.result_type(*from_dtypes) 37 | shapes = [jnp.shape(x)[1:] for x in lst] 38 | indices = np.cumsum([int(np.prod(s)) for s in shapes]) 39 | 40 | if all(dt == to_dtype for dt in from_dtypes): 41 | del from_dtypes, to_dtype 42 | 43 | def unravel(arr: Array) -> PyTree: 44 | chunks = jnp.split(arr, indices[:-1]) 45 | return [ 46 | chunk.reshape(shape) for chunk, shape in zip(chunks, shapes) 47 | ] 48 | 49 | ravel = lambda arg: jnp.concatenate([jnp.ravel(e) for e in arg]) 50 | raveled = jax.vmap(ravel)(lst) 51 | return raveled, unravel 52 | 53 | else: 54 | 55 | def unravel(arr: Array) -> PyTree: 56 | arr_dtype = dtypes.dtype(arr) 57 | if arr_dtype != to_dtype: 58 | raise TypeError( 59 | f"unravel function given array of dtype {arr_dtype}, " 60 | f"but expected dtype {to_dtype}" 61 | ) 62 | chunks = jnp.split(arr, indices[:-1]) 63 | with warnings.catch_warnings(): 64 | warnings.simplefilter("ignore") 65 | return [ 66 | lax.convert_element_type(chunk.reshape(shape), dtype) 67 | for chunk, shape, dtype in zip(chunks, shapes, from_dtypes) 68 | ] 69 | 70 | 
ravel = lambda arg: jnp.concatenate( 71 | [jnp.ravel(lax.convert_element_type(e, to_dtype)) for e in arg] 72 | ) 73 | raveled = jax.vmap(ravel)(lst) 74 | return raveled, unravel 75 | -------------------------------------------------------------------------------- /src/emcee_jax/_src/sampler.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import TYPE_CHECKING, Any, Dict, NamedTuple, Optional, Tuple, Union 3 | 4 | import jax 5 | import jax.linear_util as lu 6 | import jax.numpy as jnp 7 | from jax import device_get, random 8 | from jax.tree_util import tree_flatten, tree_map 9 | 10 | from emcee_jax._src.ensemble import Ensemble 11 | from emcee_jax._src.log_prob_fn import LogProbFn, wrap_log_prob_fn 12 | from emcee_jax._src.moves.core import Extras, Move, MoveState, Stretch 13 | from emcee_jax._src.types import Array, SampleStats 14 | 15 | if TYPE_CHECKING: 16 | from arviz import InferenceData 17 | 18 | 19 | class SamplerState(NamedTuple): 20 | move_state: MoveState 21 | ensemble: Ensemble 22 | extras: Extras 23 | 24 | 25 | class Trace(NamedTuple): 26 | final_state: SamplerState 27 | samples: Ensemble 28 | extras: Extras 29 | move_state: MoveState 30 | sample_stats: SampleStats 31 | 32 | def to_inference_data(self, **kwargs: Any) -> "InferenceData": 33 | from arviz import InferenceData, dict_to_dataset 34 | 35 | import emcee_jax 36 | 37 | # Deal with different possible PyTree shapes 38 | samples = self.samples.coordinates 39 | if not isinstance(samples, dict): 40 | flat, _ = tree_flatten(samples) 41 | samples = {f"param_{n}": x for n, x in enumerate(flat)} 42 | 43 | # Deterministics also live in samples 44 | deterministics = self.samples.deterministics 45 | if deterministics is not None: 46 | if not isinstance(deterministics, dict): 47 | flat, _ = tree_flatten(deterministics) 48 | deterministics = {f"det_{n}": x for n, x in enumerate(flat)} 49 | for k in 
list(deterministics.keys()): 50 | if k in samples: 51 | assert f"{k}_det" not in samples 52 | deterministics[f"{k}_det"] = deterministics.pop(k) 53 | samples = dict(samples, **deterministics) 54 | 55 | # ArviZ has a different convention about axis locations. It wants (chain, 56 | # draw, ...) whereas we produce (draw, chain, ...). 57 | samples = tree_map(lambda x: jnp.swapaxes(x, 0, 1), samples) 58 | 59 | # Convert sample stats to ArviZ's preferred style 60 | sample_stats = dict( 61 | _flatten_dict(self.sample_stats), lp=self.samples.log_probability 62 | ) 63 | renames = [("accept_prob", "acceptance_rate")] 64 | for old, new in renames: 65 | if old in sample_stats: 66 | sample_stats[new] = sample_stats.pop(old) 67 | sample_stats = tree_map(lambda x: jnp.swapaxes(x, 0, 1), sample_stats) 68 | 69 | return InferenceData( 70 | posterior=dict_to_dataset(device_get(samples), library=emcee_jax), 71 | sample_stats=dict_to_dataset( 72 | device_get(sample_stats), library=emcee_jax 73 | ), 74 | **kwargs, 75 | ) 76 | 77 | 78 | def _flatten_dict( 79 | obj: Union[Dict[str, Any], Any] 80 | ) -> Union[Dict[str, Any], Any]: 81 | if not isinstance(obj, dict): 82 | return obj 83 | result = {} 84 | for k, v in obj.items(): 85 | if isinstance(v, dict): 86 | for k1, v1 in _flatten_dict(v).items(): 87 | result[f"{k}:{k1}"] = v1 88 | else: 89 | result[k] = v 90 | return result 91 | 92 | 93 | @dataclass(frozen=True, init=False) 94 | class EnsembleSampler: 95 | wrapped_log_prob_fn: lu.WrappedFun 96 | move: Move 97 | 98 | def __init__( 99 | self, 100 | log_prob_fn: LogProbFn, 101 | *, 102 | move: Optional[Move] = None, 103 | log_prob_args: Tuple[Any, ...] 
= (), 104 | log_prob_kwargs: Optional[Dict[str, Any]] = None, 105 | ): 106 | log_prob_kwargs = {} if log_prob_kwargs is None else log_prob_kwargs 107 | wrapped_log_prob_fn = wrap_log_prob_fn( 108 | log_prob_fn, *log_prob_args, **log_prob_kwargs 109 | ) 110 | object.__setattr__(self, "wrapped_log_prob_fn", wrapped_log_prob_fn) 111 | 112 | move = Stretch() if move is None else move 113 | object.__setattr__(self, "move", move) 114 | 115 | def init( 116 | self, 117 | random_key: random.KeyArray, 118 | ensemble: Union[Ensemble, Array], 119 | ) -> SamplerState: 120 | initial_ensemble = Ensemble.init(self.wrapped_log_prob_fn, ensemble) 121 | move_state, extras = self.move.init(random_key, initial_ensemble) 122 | return SamplerState(move_state, initial_ensemble, extras) 123 | 124 | def step( 125 | self, 126 | random_key: random.KeyArray, 127 | state: SamplerState, 128 | *, 129 | tune: bool = False, 130 | ) -> Tuple[SamplerState, SampleStats]: 131 | if not isinstance(state, SamplerState): 132 | raise ValueError( 133 | "Invalid input state; you must call sampler.init(...) 
" 134 | "to initialize the state first" 135 | ) 136 | new_state, stats = self.move.step( 137 | self.wrapped_log_prob_fn, random_key, *state, tune=tune 138 | ) 139 | return SamplerState(*new_state), stats 140 | 141 | def sample( 142 | self, 143 | random_key: random.KeyArray, 144 | state: SamplerState, 145 | num_steps: int, 146 | *, 147 | tune: bool = False, 148 | ) -> Trace: 149 | def one_step( 150 | state: SamplerState, key: random.KeyArray 151 | ) -> Tuple[SamplerState, Tuple[SamplerState, SampleStats]]: 152 | state, stats = self.step(key, state, tune=tune) 153 | return state, (state, stats) 154 | 155 | keys = random.split(random_key, num_steps) 156 | final, (trace, stats) = jax.lax.scan(one_step, state, keys) 157 | return Trace( 158 | final_state=final, 159 | samples=trace.ensemble, 160 | extras=trace.extras, 161 | move_state=trace.move_state, 162 | sample_stats=stats, 163 | ) 164 | -------------------------------------------------------------------------------- /src/emcee_jax/_src/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Iterable, NamedTuple, Union 2 | 3 | Array = Any 4 | PyTree = Union[Array, Iterable[Array], Dict[Any, Array], NamedTuple] 5 | SampleStats = Dict[str, Array] 6 | Extras = PyTree 7 | -------------------------------------------------------------------------------- /src/emcee_jax/experimental/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dfm/emcee-jax/c52f42711b351a69dbb187f99691fe0c8ad1bd6c/src/emcee_jax/experimental/__init__.py -------------------------------------------------------------------------------- /src/emcee_jax/experimental/moves/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dfm/emcee-jax/c52f42711b351a69dbb187f99691fe0c8ad1bd6c/src/emcee_jax/experimental/moves/__init__.py 
class HMCState(NamedTuple):
    """Per-walker phase-space state carried through the HMC integrator."""

    coordinates: "Array"  # position q
    momenta: "Array"  # momentum p
    log_probability: "Array"  # log p(q)
    grad_log_probability: "Array"  # d log p / dq evaluated at q
    deterministics: "Array"  # auxiliary outputs of the log-prob function


def leapfrog(
    log_prob_and_grad_fn: "WrappedLogProbFn",
    state: HMCState,
    *,
    step_size: "Array",
) -> HMCState:
    """Advance ``state`` by one leapfrog (Stormer-Verlet) step.

    Half-kick the momentum using the stored gradient, full drift of the
    position, then half-kick again with the gradient at the new position.

    BUG FIX: the original implementation swapped ``coordinates`` and
    ``momenta`` (it kicked from the position and drifted from the
    momentum), which is not a valid leapfrog integration; compare the
    reference ``_leapfrog`` sketched at the bottom of this module.
    """
    p = state.momenta + 0.5 * step_size * state.grad_log_probability
    q = state.coordinates + step_size * p
    (log_prob, det), dlogp = log_prob_and_grad_fn(q)
    p = p + 0.5 * step_size * dlogp
    return state._replace(
        coordinates=q,
        momenta=p,
        log_probability=log_prob,
        grad_log_probability=dlogp,
        deterministics=det,
    )
init) 69 | proposed = proposed._replace(momenta=-proposed.momenta) 70 | 71 | return mh_accept(init, proposed, key=accept_key) 72 | 73 | 74 | def mh_accept( 75 | init: HMCState, 76 | prop: HMCState, 77 | *, 78 | key: Optional[random.KeyArray] = None, 79 | level: Optional[Array] = None, 80 | ) -> Tuple[HMCState, SampleStats]: 81 | prop_lp = prop.log_probability - 0.5 * jnp.sum(jnp.square(prop.momenta)) 82 | init_lp = init.log_probability - 0.5 * jnp.sum(jnp.square(init.momenta)) 83 | log_accept_prob = prop_lp - init_lp 84 | 85 | if level is None: 86 | assert key is not None 87 | u = random.uniform(key) 88 | accept = jnp.log(u) < log_accept_prob 89 | return tree_map(lambda x, y: jnp.where(accept, y, x), init, prop), { 90 | "accept": accept, 91 | "log_accept_prob": log_accept_prob, 92 | } 93 | 94 | raise NotImplementedError 95 | 96 | 97 | class HMC(RedBlue): 98 | step_size: Array = 0.1 99 | num_steps: Array = 50 100 | max_step_size: Optional[Array] = None 101 | max_num_steps: Optional[Array] = None 102 | precondition: bool = False 103 | 104 | def init( 105 | self, 106 | log_prob_fn: WrappedLogProbFn, 107 | key: random.KeyArray, 108 | ensemble: FlatWalkerState, 109 | ) -> StepState: 110 | extras = {} if ensemble.extras is None else ensemble.extras 111 | extras["momenta"] = random.normal(key, ensemble.coordinates.shape) 112 | dlogp, _ = jax.vmap(jax.grad(log_prob_fn, has_aux=True))( 113 | ensemble.coordinates 114 | ) 115 | extras["grad_log_probability"] = dlogp 116 | ensemble = ensemble._replace(extras=extras) 117 | return StepState(move_state={"iteration": 0}, walker_state=ensemble) 118 | 119 | def get_step_size(self, key: random.KeyArray) -> Array: 120 | if self.max_step_size is None: 121 | return self.step_size 122 | return jnp.exp( 123 | random.uniform( 124 | key, 125 | minval=jnp.log(self.step_size), 126 | maxval=jnp.log(self.max_step_size), 127 | ) 128 | ) 129 | 130 | def get_num_steps(self, key: random.KeyArray) -> Array: 131 | if self.max_num_steps is None: 132 | 
return jnp.asarray(self.num_steps, dtype=int) 133 | return jnp.floor( 134 | random.uniform( 135 | key, 136 | minval=self.num_steps, 137 | maxval=self.max_num_steps + 1, 138 | ) 139 | ).astype(int) 140 | 141 | def propose( 142 | self, 143 | log_prob_fn: WrappedLogProbFn, 144 | move_state: MoveState, 145 | key: random.KeyArray, 146 | target: FlatWalkerState, 147 | complementary: FlatWalkerState, 148 | *, 149 | tune: bool, 150 | ) -> Tuple[MoveState, FlatWalkerState, SampleStats]: 151 | assert move_state is not None 152 | assert target.extras is not None 153 | assert complementary.extras is not None 154 | 155 | if self.precondition: 156 | cov = jnp.cov(complementary.coordinates, rowvar=0) 157 | ell = jnp.linalg.cholesky(cov) 158 | condition = partial(jsp.linalg.solve_triangular, ell, lower=True) 159 | coords = jax.vmap(condition)(target.coordinates) 160 | uncondition = lambda x: ell @ x 161 | wrapped_log_prob_fn = precondition_log_prob_fn( 162 | lu.wrap_init(log_prob_fn), uncondition 163 | ).call_wrapped 164 | else: 165 | coords = target.coordinates 166 | wrapped_log_prob_fn = log_prob_fn 167 | 168 | num_target = target.coordinates.shape[0] 169 | step_size_key, num_steps_key, step_key = random.split(key, 3) 170 | 171 | init = HMCState( 172 | coordinates=coords, 173 | momenta=target.extras["momenta"], 174 | log_probability=target.log_probability, 175 | grad_log_probability=target.extras["grad_log_probability"], 176 | deterministics=target.deterministics, 177 | ) 178 | 179 | step_size = self.get_step_size(step_size_key) 180 | num_steps = self.get_num_steps(num_steps_key) 181 | if jnp.ndim(step_size) >= 1 or jnp.ndim(num_steps) >= 1: 182 | step_size = jnp.broadcast_to(step_size, num_target) 183 | num_steps = jnp.broadcast_to(step_size, num_target) 184 | step = jax.vmap( 185 | partial( 186 | hmc, jax.value_and_grad(wrapped_log_prob_fn, has_aux=True) 187 | ) 188 | ) 189 | result, stats = step( 190 | random.split(step_key, num_target), init, step_size, num_steps 191 | ) 192 
| 193 | else: 194 | step = jax.vmap( 195 | partial( 196 | hmc, 197 | jax.value_and_grad(wrapped_log_prob_fn, has_aux=True), 198 | step_size=step_size, 199 | num_steps=num_steps, 200 | ) 201 | ) 202 | result, stats = step(random.split(step_key, num_target), init) 203 | 204 | step_size = jnp.broadcast_to(step_size, num_target) 205 | num_steps = jnp.broadcast_to(step_size, num_target) 206 | 207 | if self.precondition: 208 | coords = jax.vmap(uncondition)(result.coordinates) 209 | else: 210 | coords = result.coordinates 211 | 212 | updated = target._replace( 213 | coordinates=coords, 214 | deterministics=result.deterministics, 215 | log_probability=result.log_probability, 216 | extras={ 217 | "momenta": result.momenta, 218 | "grad_log_probability": result.grad_log_probability, 219 | }, 220 | ) 221 | return ( 222 | dict(move_state, iteration=move_state.pop("iteration") + 1), 223 | updated, 224 | dict(stats, step_size=step_size, num_steps=num_steps), 225 | ) 226 | 227 | 228 | @lu.transformation 229 | def precondition_log_prob_fn( 230 | uncondition: Callable[[Array], Array], x: Array 231 | ) -> Generator[ 232 | Tuple[Array, Union[PyTree, Dict[str, Any]]], 233 | Tuple[Array, Union[PyTree, Dict[str, Any]]], 234 | None, 235 | ]: 236 | result = yield (uncondition(x),), {} 237 | yield result 238 | 239 | 240 | # def persistent_ghmc( 241 | # random_key: random.KeyArray, 242 | # state: HMCState, 243 | # u: Array, 244 | # *, 245 | # value_and_grad: Callable[[Array], Tuple[Array, Array]], 246 | # eps: Array, 247 | # alpha: Array, 248 | # delta: Array, 249 | # ) -> Tuple[HMCState, Array, SampleStats]: 250 | # u = (u + 1.0 + delta) % 2.0 - 1.0 251 | 252 | # # Jitter momentum 253 | # n = random.normal(random_key, state.p.shape) 254 | # p = state.p * jnp.sqrt(1 - alpha) + jnp.sqrt(alpha) * n 255 | # init = state._replace(p=p) 256 | 257 | # # Run integrator 258 | # state_ = _leapfrog(value_and_grad, state, eps=eps) 259 | 260 | # # Accept/reject 261 | # diff = init.log_prob - 
state_.log_prob 262 | # diff += -0.5 * jnp.sum(jnp.square(init.p)) 263 | # diff -= -0.5 * jnp.sum(jnp.square(state_.p)) 264 | 265 | # accept_prob = jnp.exp(diff) 266 | # accept = jnp.log(jnp.abs(u)) < diff 267 | 268 | # # Negate the initial momentum 269 | # init = init._replace(p=-init.p) 270 | # state = tree_map(lambda x, y: jnp.where(accept, x, y), state_, init) 271 | # new_u = u * (~accept + accept / accept_prob) 272 | 273 | # return ( 274 | # state, 275 | # new_u, 276 | # { 277 | # "accept": accept, 278 | # "accept_prob": accept_prob, 279 | # "eps": eps, 280 | # "alpha": alpha, 281 | # "delta": delta, 282 | # "u": new_u, 283 | # }, 284 | # ) 285 | 286 | 287 | # MEADS: https://proceedings.mlr.press/v151/hoffman22a.html 288 | 289 | 290 | # class HMCState(NamedTuple): 291 | # q: Array 292 | # p: Array 293 | # log_prob: Array 294 | # d_log_prob: Array 295 | # deterministics: Array 296 | 297 | 298 | # def _leapfrog( 299 | # value_and_grad: Callable[[Array], Tuple[Array, Array]], 300 | # state: HMCState, 301 | # *, 302 | # eps: Array, 303 | # ) -> HMCState: 304 | # p = state.p + 0.5 * eps * state.d_log_prob 305 | # q = state.q + eps * p 306 | # (lp, det), dlogp = value_and_grad(q) 307 | # p = p + 0.5 * eps * dlogp 308 | # return HMCState( 309 | # q=q, p=p, log_prob=lp, d_log_prob=dlogp, deterministics=det 310 | # ) 311 | 312 | 313 | # def persistent_ghmc( 314 | # random_key: random.KeyArray, 315 | # state: HMCState, 316 | # u: Array, 317 | # *, 318 | # value_and_grad: Callable[[Array], Tuple[Array, Array]], 319 | # eps: Array, 320 | # alpha: Array, 321 | # delta: Array, 322 | # ) -> Tuple[HMCState, Array, SampleStats]: 323 | # u = (u + 1.0 + delta) % 2.0 - 1.0 324 | 325 | # # Jitter momentum 326 | # n = random.normal(random_key, state.p.shape) 327 | # p = state.p * jnp.sqrt(1 - alpha) + jnp.sqrt(alpha) * n 328 | # init = state._replace(p=p) 329 | 330 | # # Run integrator 331 | # state_ = _leapfrog(value_and_grad, state, eps=eps) 332 | 333 | # # Accept/reject 334 | # 
diff = init.log_prob - state_.log_prob 335 | # diff += -0.5 * jnp.sum(jnp.square(init.p)) 336 | # diff -= -0.5 * jnp.sum(jnp.square(state_.p)) 337 | 338 | # accept_prob = jnp.exp(diff) 339 | # accept = jnp.log(jnp.abs(u)) < diff 340 | 341 | # # Negate the initial momentum 342 | # init = init._replace(p=-init.p) 343 | # state = tree_map(lambda x, y: jnp.where(accept, x, y), state_, init) 344 | # new_u = u * (~accept + accept / accept_prob) 345 | 346 | # return ( 347 | # state, 348 | # new_u, 349 | # { 350 | # "accept": accept, 351 | # "accept_prob": accept_prob, 352 | # "eps": eps, 353 | # "alpha": alpha, 354 | # "delta": delta, 355 | # "u": new_u, 356 | # }, 357 | # ) 358 | 359 | 360 | # def _larget_eigenvalue_of_cov(x: Array, remove_mean: bool = True) -> Array: 361 | # if remove_mean: 362 | # x = x - jnp.mean(x, axis=0) 363 | # trace_est = jnp.sum(jnp.square(x)) / x.shape[0] 364 | # trace_sq_est = jnp.sum(jnp.square(x @ x.T)) / x.shape[0] ** 2 365 | # return trace_sq_est / trace_est 366 | 367 | 368 | # @dataclass(frozen=True) 369 | # class MEADS(RedBlue): 370 | # step_size_multiplier: Array = 0.5 371 | # damping_slowdown: Array = 1.0 372 | # diagonal_preconditioning: bool = True 373 | 374 | # def init( 375 | # self, 376 | # log_prob_fn: WrappedLogProbFn, 377 | # random_key: random.KeyArray, 378 | # ensemble: FlatWalkerState, 379 | # ) -> StepState: 380 | # key1, key2 = random.split(random_key) 381 | 382 | # augments = {} if ensemble.augments is None else ensemble.augments 383 | # augments["u"] = random.uniform( 384 | # key1, (ensemble.coordinates.shape[0],), minval=-1, maxval=1 385 | # ) 386 | # augments["p"] = random.normal(key2, ensemble.coordinates.shape) 387 | # dlogp, _ = jax.vmap(jax.grad(log_prob_fn, has_aux=True))( 388 | # ensemble.coordinates 389 | # ) 390 | # augments["d_log_prob"] = dlogp 391 | 392 | # updated = ensemble._replace(augments=augments) 393 | # return StepState(move_state={"iteration": 0}, walker_state=updated) 394 | 395 | # def propose( 396 
| # self, 397 | # log_prob_fn: WrappedLogProbFn, 398 | # move_state: MoveState, 399 | # key: random.KeyArray, 400 | # target: FlatWalkerState, 401 | # complementary: FlatWalkerState, 402 | # ) -> Tuple[MoveState, FlatWalkerState, SampleStats]: 403 | # assert move_state is not None 404 | # assert target.augments is not None 405 | # assert complementary.augments is not None 406 | 407 | # # Apply preconditioning 408 | # if self.diagonal_preconditioning: 409 | # sigma = jnp.std(complementary.coordinates, axis=0) 410 | # else: 411 | # sigma = 1.0 412 | # scaled_coords = complementary.coordinates / sigma 413 | # scaled_grads = complementary.augments["d_log_prob"] * sigma 414 | 415 | # # Step size 416 | # max_eig_step = _larget_eigenvalue_of_cov( 417 | # scaled_grads, remove_mean=False 418 | # ) 419 | # eps = self.step_size_multiplier / jnp.sqrt(max_eig_step) 420 | # eps = jnp.minimum(1.0, eps) 421 | 422 | # # Damping 423 | # max_eig_damp = _larget_eigenvalue_of_cov(scaled_coords) 424 | # gamma = eps / jnp.sqrt(max_eig_damp) 425 | # gamma = jnp.maximum( 426 | # self.damping_slowdown / move_state["iteration"], gamma 427 | # ) 428 | # alpha = 1 - jnp.exp(-2 * gamma) 429 | # delta = 0.5 * alpha 430 | 431 | # init = HMCState( 432 | # q=target.coordinates, 433 | # p=target.augments["p"], 434 | # log_prob=target.log_probability, 435 | # d_log_prob=target.augments["d_log_prob"], 436 | # deterministics=target.deterministics, 437 | # ) 438 | # step = jax.vmap( 439 | # partial( 440 | # persistent_ghmc, 441 | # value_and_grad=jax.value_and_grad(log_prob_fn, has_aux=True), 442 | # eps=eps * sigma, 443 | # alpha=alpha, 444 | # delta=delta, 445 | # ) 446 | # ) 447 | # new_state, new_u, stats = step( 448 | # random.split(key, target.coordinates.shape[0]), 449 | # init, 450 | # target.augments["u"], 451 | # ) 452 | 453 | # updated = FlatWalkerState( 454 | # coordinates=new_state.q, 455 | # deterministics=new_state.deterministics, 456 | # log_probability=new_state.log_prob, 457 | # 
def test_host_callback_vmap(seed=0):
    """vmap over a host-callback-wrapped log prob matches a per-row loop."""
    log_prob = wrap_python_log_prob_fn(lambda x: -0.5 * np.sum(np.square(x)))

    points = jax.random.normal(jax.random.PRNGKey(seed), (10, 3))
    reference = jnp.stack([log_prob(row) for row in points])
    batched = jax.vmap(log_prob)(points)

    np.testing.assert_allclose(batched, reference, rtol=1e-6)


def test_host_callback_vmap_pytree(seed=0):
    """Host-callback wrapping also supports pytree (dict) inputs under vmap."""
    wrapped = wrap_python_log_prob_fn(
        lambda x: np.sum(np.square(x["x"])) + x["y"]
    )
    reference_fn = lambda x: jnp.sum(jnp.square(x["x"])) + x["y"]

    k1, k2 = jax.random.split(jax.random.PRNGKey(seed))
    batch = {
        "x": jax.random.normal(k1, (10, 3)),
        "y": jax.random.normal(k2, (10,)),
    }
    np.testing.assert_allclose(
        jax.vmap(wrapped)(batch), jax.vmap(reference_fn)(batch), rtol=1e-6
    )


@pytest.mark.parametrize(
    "ndim,move",
    product(
        [1, 2],
        [
            moves.Stretch(),
            moves.DiffEvol(),
            moves.DiffEvolSlice(),
            moves.compose(moves.Stretch(), moves.DiffEvolSlice()),
        ],
    ),
)
def test_uniform(ndim, move, seed=1, num_walkers=32, num_steps=2_000):
    """Each move leaves a uniform target distribution invariant (KS test)."""
    coords_key, init_key, sample_key = random.split(random.PRNGKey(seed), 3)
    start = random.uniform(coords_key, shape=(num_walkers, ndim))
    sampler = EnsembleSampler(
        lambda x: jnp.sum(jnp.where((0 < x) & (x < 1), 0.0, -jnp.inf)),
        move=move,
    )
    trace = sampler.sample(
        sample_key, sampler.init(init_key, start), num_steps
    )
    flat = vmap(vmap(lambda x: ravel_pytree(x)[0]))(
        trace.samples.coordinates
    )
    assert flat.shape == (num_steps, num_walkers, ndim)
    # Thin the chain to reduce autocorrelation before the KS comparison.
    for dim in range(ndim):
        _, pvalue = stats.kstest(flat[::100, :, dim].flatten(), "uniform")
        assert pvalue > 0.01, dim


@pytest.mark.parametrize(
    "pytree,move",
    product(
        [True, False],
        [moves.Stretch(), moves.DiffEvol(), moves.DiffEvolSlice()],
    ),
)
def test_normal(pytree, move, seed=1, num_walkers=32, num_steps=2_000):
    """Each move leaves a standard normal target invariant (KS test)."""

    def log_prob(params):
        a, b = (params["x"], params["y"]) if pytree else params
        return -0.5 * (a**2 + b**2)

    coords_key, init_key, sample_key = random.split(random.PRNGKey(seed), 3)
    start = random.normal(coords_key, shape=(num_walkers, 2))
    if pytree:
        start = {"x": start[:, 0], "y": start[:, 1]}
    sampler = EnsembleSampler(log_prob, move=move)
    trace = sampler.sample(
        sample_key, sampler.init(init_key, start), num_steps
    )
    flat = vmap(vmap(lambda x: ravel_pytree(x)[0]))(
        trace.samples.coordinates
    )
    assert flat.shape == (num_steps, num_walkers, 2)
    for dim in range(2):
        _, pvalue = stats.kstest(flat[::100, :, dim].flatten(), "norm")
        assert pvalue > 0.01, dim


# (ensemble, expected raveled shape) fixtures shared by the ravel tests.
ensembles_and_shapes = [
    (jnp.ones((5, 3)), (5, 3)),
    (
        {"x": jnp.zeros((3, 2)), "y": (jnp.ones(3), 2 + jnp.zeros((3, 2, 4)))},
        (3, 11),
    ),
    ((jnp.zeros((3, 2)), jnp.ones((3, 4), dtype=int)), (3, 6)),
]


@pytest.mark.parametrize("ensemble,shape", ensembles_and_shapes)
def test_shape(ensemble, shape):
    """Raveling an ensemble yields a (num_walkers, num_params) array."""
    raveled, _ = ravel_ensemble(ensemble)
    assert raveled.shape == shape


@pytest.mark.parametrize("ensemble", [e for e, _ in ensembles_and_shapes])
def test_round_trip(ensemble):
    """Unraveling inverts raveling, preserving structure, dtype, and values."""
    raveled, unravel = ravel_ensemble(ensemble)
    rebuilt = jax.vmap(unravel)(raveled)
    assert tree_structure(rebuilt) == tree_structure(ensemble)
    for got, want in zip(tree_flatten(rebuilt)[0], tree_flatten(ensemble)[0]):
        assert got.dtype == want.dtype
        np.testing.assert_allclose(got, want)
a.dtype == b.dtype 34 | np.testing.assert_allclose(a, b) 35 | -------------------------------------------------------------------------------- /tests/test_sampler.py: -------------------------------------------------------------------------------- 1 | # mypy: ignore-errors 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import pytest 7 | 8 | import emcee_jax 9 | from emcee_jax import moves 10 | from emcee_jax.host_callback import wrap_python_log_prob_fn 11 | 12 | 13 | def build_rosenbrock(pytree_input=False, deterministics=False): 14 | def rosenbrock(theta, a1=100.0, a2=20.0): 15 | if pytree_input: 16 | x1, x2 = theta["x"], theta["y"] 17 | else: 18 | x1, x2 = theta 19 | log_prob = -(a1 * (x2 - x1**2) ** 2 + (1 - x1) ** 2) / a2 20 | 21 | if deterministics: 22 | return log_prob, {"some_number": x1 + jnp.sin(x2)} 23 | return log_prob 24 | 25 | return rosenbrock 26 | 27 | 28 | def test_basic(seed=0, num_walkers=5, num_steps=21): 29 | log_prob = build_rosenbrock() 30 | key1, key2, key3 = jax.random.split(jax.random.PRNGKey(seed), 3) 31 | coords = jax.random.normal(key1, shape=(num_walkers, 2)) 32 | sampler = emcee_jax.EnsembleSampler(log_prob) 33 | state = sampler.init(key2, coords) 34 | trace = sampler.sample(key3, state, num_steps) 35 | samples = trace.samples 36 | assert samples.deterministics is None 37 | assert samples.coordinates.shape == (num_steps, num_walkers, 2) 38 | assert samples.log_probability.shape == (num_steps, num_walkers) 39 | assert trace.sample_stats["accept"].shape == (num_steps, num_walkers) 40 | 41 | 42 | def test_pytree_input(seed=0, num_walkers=5, num_steps=21): 43 | log_prob = build_rosenbrock(pytree_input=True) 44 | key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(seed), 4) 45 | coords = { 46 | "x": jax.random.normal(key1, shape=(num_walkers,)), 47 | "y": jax.random.normal(key2, shape=(num_walkers,)), 48 | } 49 | sampler = emcee_jax.EnsembleSampler(log_prob) 50 | state = sampler.init(key3, coords) 51 | 
trace = sampler.sample(key4, state, num_steps) 52 | samples = trace.samples 53 | shape = (num_steps, num_walkers) 54 | assert samples.deterministics is None 55 | assert samples.coordinates["x"].shape == shape 56 | assert samples.log_probability.shape == shape 57 | assert trace.sample_stats["accept"].shape == shape 58 | 59 | 60 | def test_deterministics(seed=0, num_walkers=5, num_steps=21): 61 | log_prob = build_rosenbrock(deterministics=True) 62 | key1, key2, key3 = jax.random.split(jax.random.PRNGKey(seed), 3) 63 | coords = jax.random.normal(key1, shape=(num_walkers, 2)) 64 | sampler = emcee_jax.EnsembleSampler(log_prob) 65 | state = sampler.init(key2, coords) 66 | trace = sampler.sample(key3, state, num_steps) 67 | samples = trace.samples 68 | shape = (num_steps, num_walkers) 69 | assert samples.deterministics["some_number"].shape == shape 70 | assert samples.coordinates.shape == (num_steps, num_walkers, 2) 71 | assert samples.log_probability.shape == shape 72 | assert trace.sample_stats["accept"].shape == shape 73 | 74 | 75 | def test_host_callback(seed=0, num_walkers=5, num_steps=21): 76 | import numpy as np 77 | 78 | @wrap_python_log_prob_fn 79 | def log_prob(theta, a1=100.0, a2=20.0): 80 | x1, x2 = theta 81 | return -(a1 * np.square(x2 - x1**2) + np.square(1 - x1)) / a2 82 | 83 | num_walkers, num_steps = 100, 1000 84 | key1, key2, key3 = jax.random.split(jax.random.PRNGKey(seed), 3) 85 | coords = jax.random.normal(key1, shape=(num_walkers, 2)) 86 | sampler = emcee_jax.EnsembleSampler(log_prob) 87 | state = sampler.init(key2, coords) 88 | trace = sampler.sample(key3, state, num_steps) 89 | samples = trace.samples 90 | assert samples.deterministics is None 91 | assert samples.coordinates.shape == (num_steps, num_walkers, 2) 92 | assert samples.log_probability.shape == (num_steps, num_walkers) 93 | assert trace.sample_stats["accept"].shape == (num_steps, num_walkers) 94 | 95 | 96 | def test_init_errors(seed=0, num_walkers=5, num_steps=21): 97 | def 
def test_init_errors(seed=0, num_walkers=5, num_steps=21):
    """``init`` raises ValueError when the log prob returns invalid values."""

    def check_raises(log_prob):
        # Build a fresh sampler and confirm that init rejects the function.
        k_coords, k_init = jax.random.split(jax.random.PRNGKey(seed))
        start = jax.random.normal(k_coords, shape=(num_walkers, 2))
        sampler = emcee_jax.EnsembleSampler(log_prob)
        with pytest.raises(ValueError):
            sampler.init(k_init, start)

    for bad_log_prob in (
        lambda *_: None,
        lambda *_: jnp.ones(2),
        lambda *_: jnp.ones(4),
        lambda *_: jnp.ones(5),
    ):
        check_raises(bad_log_prob)


def test_to_inference_data_basic(seed=0, num_walkers=5, num_steps=21):
    """Traces convert to ArviZ InferenceData with matching dims and values."""
    pytest.importorskip("arviz")
    k_coords, k_init, k_sample = jax.random.split(jax.random.PRNGKey(seed), 3)
    start = jax.random.normal(k_coords, shape=(num_walkers, 2))
    sampler = emcee_jax.EnsembleSampler(build_rosenbrock())
    trace = sampler.sample(k_sample, sampler.init(k_init, start), num_steps)
    data = trace.to_inference_data()

    assert data.posterior.dims["chain"] == num_walkers
    assert data.posterior.dims["draw"] == num_steps
    # InferenceData stores (chain, draw); the trace stores (draw, chain).
    np.testing.assert_allclose(
        np.swapaxes(data.posterior.param_0.values, 0, 1),
        trace.samples.coordinates,
    )

    assert data.sample_stats.dims["chain"] == num_walkers
    assert data.sample_stats.dims["draw"] == num_steps
    assert data.sample_stats.lp.values.shape == (num_walkers, num_steps)
    np.testing.assert_allclose(
        data.sample_stats.lp.values.T, trace.samples.log_probability
    )


def test_to_inference_data_full(seed=0, num_walkers=5, num_steps=21):
    """InferenceData conversion with pytree coords, deterministics, and a
    composed move preserves every tracked quantity."""
    pytest.importorskip("arviz")
    kx, ky, k_init, k_sample = jax.random.split(jax.random.PRNGKey(seed), 4)
    start = {
        "x": jax.random.normal(kx, shape=(num_walkers,)),
        "y": jax.random.normal(ky, shape=(num_walkers,)),
    }
    sampler = emcee_jax.EnsembleSampler(
        build_rosenbrock(pytree_input=True, deterministics=True),
        move=moves.compose(moves.Stretch(), moves.DiffEvol()),
    )
    trace = sampler.sample(k_sample, sampler.init(k_init, start), num_steps)
    data = trace.to_inference_data()

    assert data.posterior.dims["chain"] == num_walkers
    assert data.posterior.dims["draw"] == num_steps
    np.testing.assert_allclose(
        data.posterior.x.values.T, trace.samples.coordinates["x"]
    )
    np.testing.assert_allclose(
        data.posterior.y.values.T, trace.samples.coordinates["y"]
    )
    np.testing.assert_allclose(
        data.posterior.some_number.values.T,
        trace.samples.deterministics["some_number"],
    )

    assert data.sample_stats.dims["chain"] == num_walkers
    assert data.sample_stats.dims["draw"] == num_steps
    assert data.sample_stats.lp.values.shape == (num_walkers, num_steps)
    np.testing.assert_allclose(
        data.sample_stats.lp.values.T, trace.samples.log_probability
    )