├── docs
│   ├── requirements.txt
│   └── source
│       ├── _static
│       │   └── bayex.png
│       ├── api
│       │   ├── generated
│       │   │   ├── bayex.gp.predict.rst
│       │   │   ├── bayex.gp.grad_fun.rst
│       │   │   ├── bayex.gp.posterior_fit.rst
│       │   │   ├── bayex.gp.gaussian_process.rst
│       │   │   ├── bayex.gp.marginal_likelihood.rst
│       │   │   ├── bayex.acq.expected_improvement.rst
│       │   │   ├── bayex.acq.lower_confidence_bounds.rst
│       │   │   ├── bayex.acq.probability_improvement.rst
│       │   │   ├── bayex.acq.upper_confidence_bounds.rst
│       │   │   ├── bayex.domain.Real.rst
│       │   │   ├── bayex.domain.Domain.rst
│       │   │   ├── bayex.domain.Integer.rst
│       │   │   ├── bayex.optimizer.Optimizer.rst
│       │   │   ├── bayex.gp.GPState.rst
│       │   │   └── bayex.gp.GPParams.rst
│       │   ├── index.rst
│       │   ├── bayex.domain.rst
│       │   ├── bayex.optimizer.rst
│       │   ├── bayex.gp.rst
│       │   └── bayex.acq.rst
│       ├── conf.py
│       └── index.rst
├── bayex
│   ├── __init__.py
│   ├── gp.py
│   ├── acq.py
│   ├── domain.py
│   └── optimizer.py
├── .readthedocs.yml
├── .github
│   └── workflows
│       ├── tests.yml
│       └── python-publish.yml
├── LICENSE
├── pyproject.toml
├── tests
│   ├── test_acq.py
│   ├── test_optimizer.py
│   ├── test_domain.py
│   └── test_gp.py
├── README.md
└── .gitignore
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx
2 | sphinx-book-theme
3 | myst-parser
4 | sphinx_autodoc_typehints
5 |
--------------------------------------------------------------------------------
/docs/source/_static/bayex.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alonfnt/bayex/HEAD/docs/source/_static/bayex.png
--------------------------------------------------------------------------------
/docs/source/api/generated/bayex.gp.predict.rst:
--------------------------------------------------------------------------------
1 | bayex.gp.predict
2 | ================
3 |
4 | .. currentmodule:: bayex.gp
5 |
6 | .. autofunction:: predict
--------------------------------------------------------------------------------
/docs/source/api/generated/bayex.gp.grad_fun.rst:
--------------------------------------------------------------------------------
1 | bayex.gp.grad\_fun
2 | ==================
3 |
4 | .. currentmodule:: bayex.gp
5 |
6 | .. autofunction:: grad_fun
--------------------------------------------------------------------------------
/docs/source/api/generated/bayex.gp.posterior_fit.rst:
--------------------------------------------------------------------------------
1 | bayex.gp.posterior\_fit
2 | =======================
3 |
4 | .. currentmodule:: bayex.gp
5 |
6 | .. autofunction:: posterior_fit
--------------------------------------------------------------------------------
/docs/source/api/generated/bayex.gp.gaussian_process.rst:
--------------------------------------------------------------------------------
1 | bayex.gp.gaussian\_process
2 | ==========================
3 |
4 | .. currentmodule:: bayex.gp
5 |
6 | .. autofunction:: gaussian_process
--------------------------------------------------------------------------------
/docs/source/api/generated/bayex.gp.marginal_likelihood.rst:
--------------------------------------------------------------------------------
1 | bayex.gp.marginal\_likelihood
2 | =============================
3 |
4 | .. currentmodule:: bayex.gp
5 |
6 | .. autofunction:: marginal_likelihood
--------------------------------------------------------------------------------
/docs/source/api/generated/bayex.acq.expected_improvement.rst:
--------------------------------------------------------------------------------
1 | bayex.acq.expected\_improvement
2 | ===============================
3 |
4 | .. currentmodule:: bayex.acq
5 |
6 | .. autofunction:: expected_improvement
--------------------------------------------------------------------------------
/bayex/__init__.py:
--------------------------------------------------------------------------------
1 | from .optimizer import Optimizer
2 | from . import domain, acq
3 |
4 |
5 | __version__ = "0.2.2"
6 |
7 | __all__ = [
8 | "Optimizer",
9 | "domain",
10 | "acq",
11 | ]
12 |
--------------------------------------------------------------------------------
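For orientation, the package's public surface is exactly what `__init__.py` exports: `Optimizer` plus the `domain` and `acq` modules. A minimal sketch of using that surface end to end (the cosine objective, bounds, and sample points here are illustrative, not from the repository):

```python
import jax.numpy as jnp

import bayex

# `Optimizer` is re-exported at the top level; domains live in bayex.domain.
space = {'x': bayex.domain.Real(-1.0, 1.0)}
opt = bayex.Optimizer(domain=space, maximize=True, acq='EI')

# Three illustrative prior evaluations of a toy objective.
xs = [-0.5, 0.0, 0.5]
ys = [float(jnp.cos(x)) for x in xs]
state = opt.init(ys, {'x': xs})
print(float(state.best_score))  # best of the three initial evaluations
```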
/docs/source/api/generated/bayex.acq.lower_confidence_bounds.rst:
--------------------------------------------------------------------------------
1 | bayex.acq.lower\_confidence\_bounds
2 | ===================================
3 |
4 | .. currentmodule:: bayex.acq
5 |
6 | .. autofunction:: lower_confidence_bounds
--------------------------------------------------------------------------------
/docs/source/api/generated/bayex.acq.probability_improvement.rst:
--------------------------------------------------------------------------------
1 | bayex.acq.probability\_improvement
2 | ==================================
3 |
4 | .. currentmodule:: bayex.acq
5 |
6 | .. autofunction:: probability_improvement
--------------------------------------------------------------------------------
/docs/source/api/generated/bayex.acq.upper_confidence_bounds.rst:
--------------------------------------------------------------------------------
1 | bayex.acq.upper\_confidence\_bounds
2 | ===================================
3 |
4 | .. currentmodule:: bayex.acq
5 |
6 | .. autofunction:: upper_confidence_bounds
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: ubuntu-24.04
5 | tools:
6 | python: "3.12"
7 |
8 | sphinx:
9 | configuration: docs/source/conf.py
10 |
11 | python:
12 | install:
13 | - requirements: docs/requirements.txt
14 | - method: pip
15 | path: .
16 |
--------------------------------------------------------------------------------
/docs/source/api/index.rst:
--------------------------------------------------------------------------------
1 | API Reference
2 | =============
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | bayex.acq
8 | bayex.domain
9 | bayex.optimizer
10 |
11 | Module contents
12 | ---------------
13 |
14 | .. automodule:: bayex
15 | :members:
16 | :show-inheritance:
17 | :undoc-members:
18 |
--------------------------------------------------------------------------------
/docs/source/api/bayex.domain.rst:
--------------------------------------------------------------------------------
1 | ``bayex.domain`` module
2 | -----------------------
3 |
4 | .. currentmodule:: bayex.domain
5 |
6 | .. autosummary::
7 | :toctree: generated/
8 | :nosignatures:
9 |
10 | Real
11 | Integer
12 |
13 | .. automodule:: bayex.domain
14 | :members:
15 | :undoc-members:
16 | :show-inheritance:
17 | :no-index:
18 |
19 |
--------------------------------------------------------------------------------
/docs/source/api/bayex.optimizer.rst:
--------------------------------------------------------------------------------
1 | ``bayex.optimizer`` module
2 | --------------------------
3 |
4 | .. currentmodule:: bayex.optimizer
5 |
6 | .. autosummary::
7 | :toctree: generated/
8 | :nosignatures:
9 |
10 | Optimizer
11 |
12 | .. automodule:: bayex.optimizer
13 | :members:
14 | :undoc-members:
15 | :show-inheritance:
16 | :no-index:
17 | :exclude-members: _fit, expand
18 |
--------------------------------------------------------------------------------
/docs/source/api/generated/bayex.domain.Real.rst:
--------------------------------------------------------------------------------
1 | bayex.domain.Real
2 | =================
3 |
4 | .. currentmodule:: bayex.domain
5 |
6 | .. autoclass:: Real
7 |
8 |
9 | .. automethod:: __init__
10 |
11 |
12 | .. rubric:: Methods
13 |
14 | .. autosummary::
15 |
16 | ~Real.__init__
17 | ~Real.sample
18 | ~Real.transform
19 |
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/docs/source/api/generated/bayex.domain.Domain.rst:
--------------------------------------------------------------------------------
1 | bayex.domain.Domain
2 | ===================
3 |
4 | .. currentmodule:: bayex.domain
5 |
6 | .. autoclass:: Domain
7 |
8 |
9 | .. automethod:: __init__
10 |
11 |
12 | .. rubric:: Methods
13 |
14 | .. autosummary::
15 |
16 | ~Domain.__init__
17 | ~Domain.sample
18 | ~Domain.transform
19 |
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/docs/source/api/generated/bayex.domain.Integer.rst:
--------------------------------------------------------------------------------
1 | bayex.domain.Integer
2 | ====================
3 |
4 | .. currentmodule:: bayex.domain
5 |
6 | .. autoclass:: Integer
7 |
8 |
9 | .. automethod:: __init__
10 |
11 |
12 | .. rubric:: Methods
13 |
14 | .. autosummary::
15 |
16 | ~Integer.__init__
17 | ~Integer.sample
18 | ~Integer.transform
19 |
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/docs/source/api/bayex.gp.rst:
--------------------------------------------------------------------------------
1 | ``bayex.gp`` module
2 | -------------------
3 |
4 | .. currentmodule:: bayex.gp
5 |
6 | .. autosummary::
7 | :toctree: generated/
8 | :nosignatures:
9 |
10 | GPParams
11 | GPState
12 | marginal_likelihood
13 | predict
14 | posterior_fit
15 | gaussian_process
16 | grad_fun
17 |
18 | .. automodule:: bayex.gp
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 | :no-index:
23 |
--------------------------------------------------------------------------------
/docs/source/api/generated/bayex.optimizer.Optimizer.rst:
--------------------------------------------------------------------------------
1 | bayex.optimizer.Optimizer
2 | =========================
3 |
4 | .. currentmodule:: bayex.optimizer
5 |
6 | .. autoclass:: Optimizer
7 |
8 |
9 | .. automethod:: __init__
10 |
11 |
12 | .. rubric:: Methods
13 |
14 | .. autosummary::
15 |
16 | ~Optimizer.__init__
17 | ~Optimizer.expand
18 | ~Optimizer.fit
19 | ~Optimizer.init
20 | ~Optimizer.sample
21 |
22 |
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/docs/source/api/generated/bayex.gp.GPState.rst:
--------------------------------------------------------------------------------
1 | bayex.gp.GPState
2 | ================
3 |
4 | .. currentmodule:: bayex.gp
5 |
6 | .. autoclass:: GPState
7 |
8 |
9 | .. automethod:: __init__
10 |
11 |
12 | .. rubric:: Methods
13 |
14 | .. autosummary::
15 |
16 | ~GPState.__init__
17 | ~GPState.count
18 | ~GPState.index
19 |
20 |
21 |
22 |
23 |
24 | .. rubric:: Attributes
25 |
26 | .. autosummary::
27 |
28 | ~GPState.momentums
29 | ~GPState.params
30 | ~GPState.scales
31 |
32 |
--------------------------------------------------------------------------------
/docs/source/api/generated/bayex.gp.GPParams.rst:
--------------------------------------------------------------------------------
1 | bayex.gp.GPParams
2 | =================
3 |
4 | .. currentmodule:: bayex.gp
5 |
6 | .. autoclass:: GPParams
7 |
8 |
9 | .. automethod:: __init__
10 |
11 |
12 | .. rubric:: Methods
13 |
14 | .. autosummary::
15 |
16 | ~GPParams.__init__
17 | ~GPParams.count
18 | ~GPParams.index
19 |
20 |
21 |
22 |
23 |
24 | .. rubric:: Attributes
25 |
26 | .. autosummary::
27 |
28 | ~GPParams.amplitude
29 | ~GPParams.lengthscale
30 | ~GPParams.noise
31 |
32 |
--------------------------------------------------------------------------------
/docs/source/api/bayex.acq.rst:
--------------------------------------------------------------------------------
1 | ``bayex.acq`` module
2 | --------------------
3 |
4 | This module contains the implementation of various acquisition functions used in Bayesian optimization.
5 | These functions help in selecting the next point to sample based on the current model of the objective function.
6 |
7 | .. currentmodule:: bayex.acq
8 |
9 | .. autosummary::
10 | :toctree: generated/
11 | :nosignatures:
12 |
13 | expected_improvement
14 | probability_improvement
15 | upper_confidence_bounds
16 | lower_confidence_bounds
17 |
18 | .. automodule:: bayex.acq
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 | :no-index:
23 |
24 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: tests
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | build:
11 | name: "Python ${{ matrix.python-version }}"
12 | runs-on: ubuntu-latest
13 |
14 | strategy:
15 | matrix:
16 | python-version: ["3.11", "3.12", "3.13"]
17 |
18 | steps:
19 | - uses: actions/checkout@v4
20 |
21 | - uses: actions/setup-python@v4
22 | with:
23 | python-version: "${{ matrix.python-version }}"
24 |
25 | - name: Install dependencies
26 | run: |
27 | python -m pip install --upgrade pip
28 | python -m pip install -e . pytest
29 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
30 |
31 | - name: Test with pytest
32 | run: |
33 | pytest
34 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # Workflow that publishes the last release to PyPI.
2 |
3 | name: Upload Python Package
4 |
5 | on:
6 | release:
7 | types: [published]
8 |
9 | workflow_dispatch:
10 |
11 | jobs:
12 | pypi-publish:
13 | name: Upload release to PyPI
14 | runs-on: ubuntu-latest
15 | environment:
16 | name: pypi
17 | url: https://pypi.org/p/bayex
18 | permissions:
19 | id-token: write
20 | steps:
21 |
22 | - uses: actions/checkout@v3
23 |
24 | - name: Set up Python
25 | uses: actions/setup-python@v3
26 | with:
27 | python-version: '3.x'
28 |
29 | - name: Install dependencies
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install build
33 |
34 | - name: Build package
35 | run: python -m build
36 |
37 | - name: Publish package distributions to PyPI
38 | uses: pypa/gh-action-pypi-publish@release/v1
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Albert Alonso
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "bayex"
7 | version = "0.2.2"
8 | description = "Minimal Bayesian Optimization in JAX using Gaussian Processes."
9 | readme = "README.md"
10 | authors = [{name = "Albert Alonso", email = "alonfnt@pm.me"}, ]
11 | license = {file = "LICENSE"}
12 | keywords = ["jax", "bayesian-optimization", "automatic-differentiation", "gaussian-process", "machine-learning"]
13 | classifiers = [
14 | "Development Status :: 3 - Alpha",
15 | "Intended Audience :: Science/Research",
16 | "Topic :: Scientific/Engineering",
17 | "License :: OSI Approved :: MIT License",
18 | "Programming Language :: Python :: 3",
19 | ]
20 | dependencies = ["jax", "jaxlib", "numpy", "optax"]
21 |
22 | [project.optional-dependencies]
23 | dev = ["pytest"]
24 |
25 | [tool.setuptools]
26 | packages = ["bayex"]
27 |
28 | [project.urls]
29 | "Homepage" = "https://github.com/alonfnt/bayex"
30 | "Documentation" = "https://github.com/alonfnt/bayex"
31 | "Source" = "https://github.com/alonfnt/bayex"
32 | "Bug Tracker" = "https://github.com/alonfnt/bayex/issues"
33 |
--------------------------------------------------------------------------------
/tests/test_acq.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import jax
3 | import jax.numpy as jnp
4 |
5 | from bayex.gp import GPParams
6 | from bayex.acq import (
7 | expected_improvement,
8 | probability_improvement,
9 | upper_confidence_bounds,
10 | lower_confidence_bounds,
11 | )
12 |
13 |
14 | @pytest.fixture
15 | def gp_inputs():
16 | x = jnp.linspace(-3, 3, 10)
17 | y = jnp.sin(x)
18 | mask = jnp.ones_like(x)
19 | xt = jnp.linspace(-4, 4, 5)
20 | params = GPParams(noise=0.1, amplitude=1.0, lengthscale=1.0)
21 | return xt, x, y, mask, params
22 |
23 |
24 | @pytest.mark.parametrize("padding", [0, 3])
25 | def test_acquisition_invariance_to_padding(gp_inputs, padding):
26 | xt, x, y, mask, params = gp_inputs
27 | ei_ref = expected_improvement(xt, x, y, mask, params)
28 |
29 | x_pad, y_pad, mask_pad = jax.tree.map(lambda t: jnp.pad(t, (0, padding)), (x, y, mask))
30 | ei_pad = expected_improvement(xt, x_pad, y_pad, mask_pad, params)
31 |
32 | assert ei_pad.shape == ei_ref.shape
33 | assert jnp.allclose(ei_ref, ei_pad, atol=1e-5)
34 |
35 |
36 | @pytest.mark.parametrize("acq_fn", [
37 | expected_improvement,
38 | probability_improvement,
39 | upper_confidence_bounds,
40 | lower_confidence_bounds,
41 | ])
42 | def test_acquisition_output_validity(acq_fn, gp_inputs):
43 | xt, x, y, mask, params = gp_inputs
44 | acq_vals = acq_fn(xt, x, y, mask, params)
45 |
46 | assert acq_vals.shape == xt.shape
47 | assert jnp.all(jnp.isfinite(acq_vals))
48 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 | import os
6 | import sys
7 |
8 | sys.path.insert(0, os.path.abspath('../..'))  # repository root, so that "import bayex" resolves
9 |
10 | # -- Project information -----------------------------------------------------
11 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
12 |
13 | project = 'bayex'
14 | copyright = '2025, Albert Alonso'
15 | author = 'Albert Alonso'
16 | release = '0.2.2'
17 |
18 | # -- General configuration ---------------------------------------------------
19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
20 |
21 | extensions = [
22 | 'myst_parser',
23 | 'sphinx.ext.autodoc',
24 | 'sphinx.ext.napoleon',
25 | 'sphinx.ext.autosummary',
26 | 'sphinx_autodoc_typehints',
27 | 'sphinx.ext.mathjax',
28 | ]
29 |
30 | templates_path = ['_templates']
31 | exclude_patterns = []
32 |
33 | autodoc_default_options = {
34 | 'members': True,
35 | 'undoc-members': False,
36 | 'private-members': False,
37 | 'special-members': False,
38 | 'inherited-members': True,
39 | }
40 |
41 |
42 |
43 | # -- Options for HTML output -------------------------------------------------
44 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
45 |
46 | html_theme = 'sphinx_book_theme'
47 | html_static_path = ["_static"]
48 | html_logo = "_static/bayex.png"
49 |
--------------------------------------------------------------------------------
/tests/test_optimizer.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import pytest
5 |
6 | import bayex
7 |
8 | KEY = jax.random.key(42)
9 | SEED = 42
10 |
11 |
12 | def test_1D_optim():
13 |
14 | def f(x):
15 | return jnp.sin(x) - 0.1 * x**2
16 |
17 | TARGET = np.max(f(np.linspace(-2, 2, 1000)))
18 |
19 | domain = {'x': bayex.domain.Real(-2, 2)}
20 |
21 | opt = bayex.Optimizer(domain=domain, maximize=True)
22 |
23 | # Evaluate the function 3 times
24 | params = {'x': [0.0, 1.0, 2.0]}
25 | ys = [f(x) for x in params['x']]
26 | opt_state = opt.init(ys, params)
27 |
28 | assert opt_state.best_score == np.max(ys)
29 | assert opt_state.best_params['x'] == params['x'][np.argmax(ys)]
30 |
31 | assert type(opt_state) == bayex.optimizer.OptimizerState
32 | assert type(opt_state.gp_params) == bayex.gp.GPParams # pyright: ignore
33 |
34 | assert opt_state.params['x'].shape == (10,)
35 | assert opt_state.ys.shape == (10,)
36 | assert opt_state.mask.shape == (10,)
37 |
38 | assert np.allclose(opt_state.params['x'][:3], params['x'])
39 | assert np.allclose(opt_state.ys[:3], ys)
40 | assert np.allclose(opt_state.mask[:3], [True, True, True])
41 | assert np.allclose(opt_state.mask[3:], [False] * 7)
42 |
43 | key = jax.random.key(SEED)
44 |
45 | sample_fn = jax.jit(opt.sample)
46 | for step in range(100):
47 | key = jax.random.fold_in(key, step)
48 | new_params = sample_fn(key, opt_state)
49 | y = f(**new_params)
50 | opt_state = opt.fit(opt_state, y, new_params)
51 | if jnp.allclose(opt_state.best_score, TARGET, atol=1e-03):
52 | break
53 | target = opt_state.best_score
54 | assert jnp.allclose(TARGET, target, atol=1e-02)
55 |
56 |
57 | def test_evaluate_raise_invalid_acq_fun():
58 | domain = {'x': bayex.domain.Real(-2, 2)}
59 |
60 | # This shouldn't raise an error
61 | for acq in ['EI',]:
62 | bayex.Optimizer(domain=domain, acq=acq)
63 |
64 | # But this should!
65 | for acq in ['random', 'magic']:
66 | with pytest.raises(ValueError, match=f"Acquisition function {acq} is not implemented"):
67 | bayex.Optimizer(domain=domain, acq=acq)
68 |
--------------------------------------------------------------------------------
/tests/test_domain.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import jax
3 | import jax.numpy as jnp
4 | from bayex.domain import Real, Integer
5 |
6 |
7 | @pytest.mark.parametrize("lower, upper", [(0.0, 1.0), (-5.5, 5.5)])
8 | def test_real_transform_clips_within_bounds(lower, upper):
9 | domain = Real(lower, upper)
10 | x = jnp.array([lower - 1, (lower + upper) / 2, upper + 1])
11 | out = domain.transform(x)
12 | assert jnp.all(out >= lower) and jnp.all(out <= upper)
13 |
14 |
15 | def test_real_sample_within_bounds():
16 | domain = Real(0.0, 1.0)
17 | key = jax.random.PRNGKey(0)
18 | shape = (100,)
19 | samples = domain.sample(key, shape)
20 | assert samples.shape == shape
21 | assert samples.dtype == jnp.float32
22 | assert jnp.all(samples >= 0.0) and jnp.all(samples <= 1.0)
23 |
24 |
25 | @pytest.mark.parametrize("lower, upper", [(0, 5), (-3, 3)])
26 | def test_integer_transform_and_type(lower, upper):
27 | domain = Integer(lower, upper)
28 | x = jnp.array([lower - 2.3, (lower + upper) / 2.0, upper + 2.7])
29 | out = domain.transform(x)
30 | assert jnp.issubdtype(out.dtype, jnp.floating)
31 | assert jnp.all(out >= lower) and jnp.all(out <= upper)
32 | assert jnp.all(out == jnp.round(out))
33 |
34 |
35 | def test_integer_sample_within_bounds():
36 | domain = Integer(1, 10)
37 | key = jax.random.PRNGKey(42)
38 | shape = (1000,)
39 | samples = domain.sample(key, shape)
40 | assert samples.shape == shape
41 | assert jnp.all(samples >= 1) and jnp.all(samples <= 10)
42 | assert jnp.all(jnp.equal(samples, jnp.round(samples)))
43 |
44 |
45 | def test_domain_equality_and_hash():
46 | a = Real(0.0, 1.0)
47 | b = Real(0.0, 1.0)
48 | c = Real(1.0, 2.0)
49 | assert a == b
50 | assert a != c
51 | assert hash(a) == hash(b)
52 | assert hash(a) != hash(c)
53 |
54 | i1 = Integer(1, 5)
55 | i2 = Integer(1, 5)
56 | i3 = Integer(0, 4)
57 | assert i1 == i2
58 | assert i1 != i3
59 | assert hash(i1) == hash(i2)
60 | assert hash(i1) != hash(i3)
61 |
62 |
63 | @pytest.mark.parametrize("lower, upper", [
64 | (1.0, 1.0), # equal
65 | ("a", 1.0), # wrong type
66 | (1.0, "b"), # wrong type
67 | (5.0, 3.0), # lower > upper
68 | ])
69 | def test_real_init_invalid(lower, upper):
70 | with pytest.raises(AssertionError):
71 | Real(lower, upper)
72 |
73 |
74 | @pytest.mark.parametrize("lower, upper", [
75 | (5, 5), # equal
76 | ("a", 5), # wrong type
77 | (0, "b"), # wrong type
78 | (10, 1), # lower > upper
79 | ])
80 | def test_integer_init_invalid(lower, upper):
81 | with pytest.raises(AssertionError):
82 | Integer(lower, upper)
83 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![Bayex logo](docs/source/_static/bayex.png)
2 |
3 |
4 |
5 | [](https://github.com/alonfnt/bayex/actions/workflows/tests.yml)
6 | [](https://bayex.readthedocs.io/en/latest/)
7 | [](https://pypi.org/project/bayex/)
8 |
9 | >[!NOTE]
10 | >Bayex is currently a minimal, personally developed implementation that requires further development for broader application. If you're interested in engaging with JAX and enhancing Bayex, your contributions would be highly welcomed and appreciated.
11 |
12 | [**Installation**](#installation)
13 | | [**Usage**](#usage)
14 | | [**Reference docs**](https://bayex.readthedocs.io/en/latest/)
15 |
16 | Bayex is a lightweight Bayesian optimization library designed for efficiency and flexibility, leveraging the power of JAX for high-performance numerical computations.
17 | This library aims to provide an easy-to-use interface for optimizing expensive-to-evaluate functions through Gaussian Process (GP) models and various acquisition functions. Whether you're maximizing or minimizing your objective function, Bayex offers a simple yet powerful set of tools to guide your search for optimal parameters.
18 |
19 |
20 |
21 |
22 |
23 |
24 | ## Installation
25 | Bayex can be installed using [PyPI](https://pypi.org/project/bayex/) via `pip`:
26 | ```
27 | pip install bayex
28 | ```
29 |
30 | ## Usage
31 | Using Bayex is quite simple despite its low-level approach:
32 | ```python
33 | import jax
34 | import numpy as np
35 | import bayex
36 |
37 | def f(x):
38 | return -(1.4 - 3 * x) * np.sin(18 * x)
39 |
40 | domain = {'x': bayex.domain.Real(0.0, 2.0)}
41 | optimizer = bayex.Optimizer(domain=domain, maximize=True, acq='PI')
42 |
43 | # Define some prior evaluations to initialise the GP.
44 | params = {'x': [0.0, 0.5, 1.0]}
45 | ys = [f(x) for x in params['x']]
46 | opt_state = optimizer.init(ys, params)
47 |
48 | # Sample new points using the JAX PRNG approach.
49 | ori_key = jax.random.key(42)
50 | for step in range(20):
51 | key = jax.random.fold_in(ori_key, step)
52 | new_params = optimizer.sample(key, opt_state)
53 | y_new = f(**new_params)
54 | opt_state = optimizer.fit(opt_state, y_new, new_params)
55 | ```
56 |
57 | with the results being saved at `opt_state.best_params`.
58 |
59 | ## Documentation
60 | Available at [https://bayex.readthedocs.io/en/latest](https://bayex.readthedocs.io/en/latest/).
61 |
62 | ## License
63 | Bayex is licensed under the MIT License. See the [LICENSE](LICENSE) file for more details.
64 |
--------------------------------------------------------------------------------
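The README loop above maximizes the objective. Since `Optimizer` defaults to minimization (`maximize=False`), a minimization variant looks the same apart from that flag; a minimal sketch, with an illustrative quadratic objective:

```python
import jax

import bayex

def g(x):
    # Illustrative objective to minimize; optimum at x = 0.3.
    return (x - 0.3) ** 2

domain = {'x': bayex.domain.Real(0.0, 2.0)}
optimizer = bayex.Optimizer(domain=domain, maximize=False, acq='EI')

# Prior evaluations to initialise the GP.
params = {'x': [0.0, 1.0, 2.0]}
ys = [g(x) for x in params['x']]
opt_state = optimizer.init(ys, params)

key = jax.random.key(0)
for step in range(10):
    subkey = jax.random.fold_in(key, step)
    new_params = optimizer.sample(subkey, opt_state)
    opt_state = optimizer.fit(opt_state, g(**new_params), new_params)

print(float(opt_state.best_score))  # smallest value observed so far
```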
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Visual studio folks
121 | .vscode/
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
137 | # pytype static type analyzer
138 | .pytype/
139 |
140 | # Cython debug symbols
141 | cython_debug/
142 |
143 | # Poetry
144 | .python-version
145 |
146 | # CTags
147 | /tags*
148 |
149 | # Vim swap files
150 | *.swp
151 | *.swo
152 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | Bayex
2 | ===========================================
3 |
4 | Minimal Bayesian Optimization in JAX
5 | ------------------------------------
6 |
7 | .. image:: https://github.com/alonfnt/bayex/actions/workflows/tests.yml/badge.svg
8 | :target: https://github.com/alonfnt/bayex/actions/workflows/tests.yml
9 | :alt: Test status
10 |
11 | .. note::
12 | Bayex is currently a minimal, personally developed implementation that requires further development for broader application.
13 | If you're interested in engaging with JAX and enhancing Bayex, your contributions would be highly welcomed and appreciated.
14 |
15 | .. image:: _static/bayex.png
16 |    :alt: Bayex logo
17 |    :align: center
18 |
19 |
20 |
21 | Bayex is a lightweight Bayesian optimization library designed for efficiency and flexibility, leveraging the power of JAX for high-performance numerical computations.
22 |
23 | This library aims to provide an easy-to-use interface for optimizing expensive-to-evaluate functions through Gaussian Process (GP) models and various acquisition functions. Whether you're maximizing or minimizing your objective function, Bayex offers a simple yet powerful set of tools to guide your search for optimal parameters.
24 |
25 | Installation
26 | ------------
27 |
28 | Bayex can be installed using `PyPI <https://pypi.org/project/bayex/>`_ via ``pip``:
29 |
30 | .. code-block:: bash
31 |
32 | pip install bayex
33 |
34 | Usage
35 | -----
36 |
37 | Using Bayex is quite simple despite its low-level approach:
38 |
39 | .. code-block:: python
40 |
41 | import jax
42 | import numpy as np
43 | import bayex
44 |
45 | def f(x):
46 | return -(1.4 - 3 * x) * np.sin(18 * x)
47 |
48 | domain = {'x': bayex.domain.Real(0.0, 2.0)}
49 | optimizer = bayex.Optimizer(domain=domain, maximize=True, acq='PI')
50 |
51 | # Define some prior evaluations to initialise the GP.
52 | params = {'x': [0.0, 0.5, 1.0]}
53 | ys = [f(x) for x in params['x']]
54 | opt_state = optimizer.init(ys, params)
55 |
56 | # Sample new points using the JAX PRNG approach.
57 | ori_key = jax.random.key(42)
58 | for step in range(20):
59 | key = jax.random.fold_in(ori_key, step)
60 | new_params = optimizer.sample(key, opt_state)
61 | y_new = f(**new_params)
62 | opt_state = optimizer.fit(opt_state, y_new, new_params)
63 |
64 | With the results being saved at ``opt_state.best_params``.
65 |
66 | Contributing
67 | ------------
68 |
69 | We welcome contributions to Bayex! Whether it's adding new features, improving documentation, or reporting issues, please feel free to make a pull request or open an issue.
70 |
71 | License
72 | -------
73 |
74 | Bayex is licensed under the MIT License. See the `LICENSE <https://github.com/alonfnt/bayex/blob/main/LICENSE>`_ file for more details.
75 |
76 | .. toctree::
77 | :maxdepth: 2
78 | :caption: API Reference
79 |
80 | api/index
81 |
--------------------------------------------------------------------------------
/tests/test_gp.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import pytest
4 | from bayex import gp
5 |
6 |
7 | @pytest.mark.parametrize(
8 | "x1, x2, mask, expected",
9 | [
10 | (jnp.array([0]), jnp.array([0]), 1, 1),
11 | (jnp.array([0]), jnp.array([1]), 1, jnp.exp(-1)),
12 | (jnp.array([1]), jnp.array([1]), 0, 0),
13 | ],
14 | )
15 | def test_exp_quadratic(x1, x2, mask, expected):
16 | result = gp.exp_quadratic(x1, x2, mask)
17 | assert jnp.isclose(result, expected), f"Expected {expected}, got {result}"
18 |
19 |
20 | @pytest.mark.parametrize(
21 | "x1, x2, mask1, mask2, expected_shape",
22 | [
23 | (jnp.linspace(-5, 5, 10), jnp.linspace(-5, 5, 10), jnp.ones(10), jnp.ones(10), (10, 10)),
24 | (jnp.linspace(-5, 5, 5), jnp.linspace(-5, 5, 10), jnp.ones(5), jnp.ones(10), (5, 10)),
25 | ],
26 | )
27 | def test_covariance_shape(x1, x2, mask1, mask2, expected_shape):
28 | cov_matrix = gp.cov(x1, x2, mask1, mask2)
29 | assert (
30 | cov_matrix.shape == expected_shape
31 | ), f"Expected shape {expected_shape}, got {cov_matrix.shape}"
32 |
33 |
34 | @pytest.mark.parametrize(
35 | "compute_ml, expected_output_type",
36 | [
37 | (False, tuple), # Expecting mean and std as output
38 | (True, jnp.ndarray), # Expecting marginal likelihood as output
39 | ],
40 | )
41 | def test_gaussian_process_output_type(compute_ml, expected_output_type):
42 | params = gp.GPParams(noise=0.1, amplitude=1.0, lengthscale=1.0)
43 | x = jnp.linspace(-5, 5, 10)
44 | y = jnp.sin(x)
45 | mask = jnp.ones_like(x, dtype=bool)
46 | xt = jnp.array([0.0]) if not compute_ml else None
47 |
48 | output = gp.gaussian_process(params, x, y, mask, xt, compute_ml=compute_ml)
49 | assert isinstance(
50 | output, expected_output_type
51 | ), f"Output type mismatch. Expected: {expected_output_type}, got: {type(output)}"
52 |
53 |
54 | @pytest.mark.parametrize("padding", [0, 1, 5, 10])
55 | def test_masking_in_gaussian_process(padding: int):
56 |
57 | params = gp.GPParams(noise=0.1, amplitude=1.0, lengthscale=1.0)
58 |
59 | x = jnp.linspace(-5, 5, 10)
60 | y = jnp.sin(x)
61 | mask = jnp.ones_like(x, dtype=float)
62 | reference = gp.marginal_likelihood(params, x, y, mask)
63 |
64 | x_pad, y_pad, mask_pad = jax.tree.map(
65 | lambda x: jnp.pad(x, (0, padding)), (x, y, mask)
66 | )
67 | assert len(x_pad) == len(x) + padding, "X should be padded to 10 + padding length"
68 | assert mask_pad[len(mask):].sum() == 0, "Mask should be zero for padded values"
69 |
70 | output = gp.marginal_likelihood(params, x_pad, y_pad, mask_pad)
71 |
72 | assert jnp.allclose(reference, output), f"Mismatch for padding={padding}"
73 |
74 |
75 | @pytest.mark.parametrize("padding", [0, 1, 5, 10])
76 | def test_masking_in_prediction(padding: int):
77 |
78 | params = gp.GPParams(noise=0.1, amplitude=1.0, lengthscale=1.0)
79 |
80 | x = jnp.linspace(-5, 5, 10)
81 | y = jnp.sin(x)
82 | mask = jnp.ones_like(x)
83 | xt = jnp.linspace(-6, 6, 20)
84 |
85 | # Reference prediction without padding
86 | mean_ref, std_ref = gp.predict(params, x, y, mask, xt)
87 |
88 | # Apply padding
89 | x_pad, y_pad, mask_pad = jax.tree.map(
90 | lambda a: jnp.pad(a, (0, padding)), (x, y, mask)
91 | )
92 |
93 | # Prediction with padded inputs
94 | mean_pad, std_pad = gp.predict(params, x_pad, y_pad, mask_pad, xt)
95 |
96 | assert mean_pad.shape == mean_ref.shape
97 | assert std_pad.shape == std_ref.shape
98 | assert jnp.allclose(mean_pad, mean_ref, atol=1e-5), f"Mean mismatch for padding={padding}"
99 | assert jnp.allclose(std_pad, std_ref, atol=1e-5), f"Std mismatch for padding={padding}"
100 |
--------------------------------------------------------------------------------
/bayex/gp.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from functools import partial
3 | from typing import Any, Optional
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | from jax.scipy.linalg import cholesky, solve_triangular
8 | import optax
9 |
10 | MASK_VARIANCE = 1e12  # High variance assigned to masked points so they do not affect the process.
11 |
12 | GPParams = namedtuple("GPParams", ["noise", "amplitude", "lengthscale"])
13 | GPState = namedtuple("GPState", ["params", "momentums", "scales"])
14 |
15 | def exp_quadratic(x1, x2, mask):
16 | distance = jnp.sum((x1 - x2) ** 2)
17 | return jnp.exp(-distance) * mask
18 |
19 |
20 | def cov(x1, x2, mask1, mask2):
21 | M = jnp.outer(mask1, mask2)
22 | k = exp_quadratic
23 | return jax.vmap(jax.vmap(k, in_axes=(None, 0, 0)), in_axes=(0, None, 0))(x1, x2, M)
24 |
25 |
26 | def softplus(x):
27 | return jnp.logaddexp(x, 0.0)
28 |
29 |
30 | def gaussian_process(
31 | params,
32 | x: jnp.ndarray,
33 | y: jnp.ndarray,
34 | mask,
35 | xt: Optional[jnp.ndarray] = None,
36 | compute_ml: bool = False,
37 | ) -> Any:
38 | # Number of points in the prior distribution
39 | n = x.shape[0]
40 |
41 | noise, amp, ls = jax.tree_util.tree_map(softplus, params)
42 |
43 | ymean = jnp.mean(y, where=mask.astype(bool))
44 | y = (y - ymean) * mask
45 | x = x / ls
46 | K = amp * cov(x, x, mask, mask) + (jnp.eye(n) * (noise + 1e-6))
47 | K += jnp.eye(n) * (1.0 - mask.astype(float)) * MASK_VARIANCE
48 | L = cholesky(K, lower=True)
49 | K_inv_y = solve_triangular(L.T, solve_triangular(L, y, lower=True), lower=False)
50 |
51 | if compute_ml:
52 | logp = 0.5 * jnp.dot(y.T, K_inv_y)
53 | logp += jnp.sum(jnp.log(jnp.diag(L)))
54 | logp -= jnp.sum(1.0 - mask) * 0.5 * jnp.log(MASK_VARIANCE)
55 | logp += (jnp.sum(mask) / 2) * jnp.log(2 * jnp.pi)
56 | logp += jnp.sum(-0.5 * jnp.log(2*jnp.pi) - jnp.log(amp) - jnp.log(amp)**2)
57 | return jnp.sum(logp)
58 |
59 | assert xt is not None, "xt can't be None during prediction."
60 | xt = xt / ls
61 |
62 | # Compute the covariance with the new point xt
63 | mask_t = jnp.ones(len(xt), dtype=bool)
64 | K_cross = amp * cov(x, xt, mask, mask_t)
65 |
66 | K_inv_y = K_inv_y * mask # masking
67 | pred_mean = jnp.dot(K_cross.T, K_inv_y) + ymean
68 | v = solve_triangular(L, K_cross, lower=True)
69 | pred_var = amp * cov(xt, xt, mask_t, mask_t) - v.T @ v
70 | pred_std = jnp.sqrt(jnp.maximum(jnp.diag(pred_var), 1e-10))
71 | return pred_mean, pred_std
72 |
73 |
74 | marginal_likelihood = partial(gaussian_process, compute_ml=True)
75 | grad_fun = jax.jit(jax.grad(marginal_likelihood))
76 | predict = jax.jit(partial(gaussian_process, compute_ml=False))
77 |
78 | def neg_log_likelihood(params, x, y, mask):
79 | ll = marginal_likelihood(params, x, y, mask)
80 |
81 | # Weak priors to keep things sane
82 | # params = jax.tree.map(softplus, params)
83 | priors = GPParams(-8.0, 1.0, 1.0)
84 | log_prior = jax.tree.map(lambda p, m: jnp.sum((p - m) ** 2), params, priors)
85 | log_prior = sum(jax.tree.leaves(log_prior))
86 | log_posterior = ll - 0.5 * log_prior
87 | return -log_posterior
88 |
89 |
90 | def posterior_fit(
91 | y: jax.Array,
92 | x: jax.Array,
93 | mask: jax.Array,
94 | params: GPParams,
95 | lr: float = 1e-3,
96 | trainsteps: int = 100,
97 | ) -> GPParams:
98 |
99 | optimizer = optax.chain(optax.clip_by_global_norm(10.0), optax.adamw(lr))
100 | opt_state = optimizer.init(params)
101 |
102 | def train_step(carry, _):
103 | params, opt_state = carry
104 | grads = jax.grad(neg_log_likelihood)(params, x, y, mask)
105 | updates, opt_state = optimizer.update(grads, opt_state, params=params)
106 | params = optax.apply_updates(params, updates)
107 | return (params, opt_state), None
108 |
109 | (params, _), __ = jax.lax.scan(train_step, (params, opt_state), None, length=trainsteps)
110 | return params
111 |
--------------------------------------------------------------------------------
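A minimal sketch of how the pieces in `gp.py` compose, mirroring `tests/test_gp.py`: hyperparameters are stored raw and passed through `softplus` inside `gaussian_process`, `posterior_fit` refines them with a few gradient steps, and `predict` returns the posterior mean and standard deviation. The data and initial values below are illustrative:

```python
import jax.numpy as jnp

from bayex import gp

# Ten observations of a toy function; the mask marks all entries as valid.
x = jnp.linspace(-3, 3, 10)
y = jnp.sin(x)
mask = jnp.ones_like(x)

# Raw (pre-softplus) scalar hyperparameters, as in the tests.
params = gp.GPParams(noise=jnp.asarray(0.1),
                     amplitude=jnp.asarray(1.0),
                     lengthscale=jnp.asarray(1.0))

# A few gradient steps on the penalized marginal likelihood.
params = gp.posterior_fit(y, x, mask=mask, params=params)

# Posterior mean and standard deviation at new query points.
xt = jnp.linspace(-4, 4, 5)
mean, std = gp.predict(params, x, y, mask, xt)
print(mean.shape, std.shape)  # (5,) (5,)
```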
/bayex/acq.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax.scipy.stats import norm
4 |
5 | from bayex.gp import GPParams, predict
6 |
7 |
8 | def expected_improvement(
9 | x_pred: jnp.ndarray,
10 | xs: jax.Array,
11 | ys: jax.Array,
12 | mask: jax.Array,
13 | gp_params: GPParams,
14 | xi: float = 0.01,
15 | ):
16 | r"""
17 | Expected Improvement (EI) acquisition function.
18 |
19 | Favors points with high improvement over the current best observed value,
20 | balancing exploitation and exploration.
21 |
22 | The formula is:
23 |
24 | .. math::
25 |
26 | EI(x) = (\mu(x) - y^* - \xi) \Phi(z) + \sigma(x) \phi(z)
27 |
28 | where:
29 |
30 | .. math::
31 |
32 | z = \frac{\mu(x) - y^* - \xi}{\sigma(x)}
33 |
34 | Args:
35 | x_pred: Candidate input locations to evaluate.
36 | xs: Observed inputs.
37 | ys: Observed function values.
38 | mask: Boolean mask indicating valid entries in `ys`.
39 | gp_params: Gaussian Process hyperparameters.
40 | xi: Exploration-exploitation tradeoff parameter.
41 |
42 | Returns:
43 | EI scores at `x_pred`.
44 | """
45 | ymax = jnp.max(ys, where=mask.astype(bool), initial=-jnp.inf)
46 | mu, std = predict(gp_params, xs, ys, mask, xt=x_pred)
47 | a = mu - ymax - xi
48 | z = a / (std + 1e-3)
49 | ei = a * norm.cdf(z) + std * norm.pdf(z)
50 | return ei
51 |
52 |
53 | def probability_improvement(
54 | x_pred: jnp.ndarray,
55 | xs: jax.Array,
56 | ys: jax.Array,
57 | mask: jax.Array,
58 | gp_params: GPParams,
59 | xi: float = 0.01,
60 | ):
61 | r"""
62 | Probability of Improvement (PI) acquisition function.
63 |
64 | Estimates the probability that a candidate point will improve
65 | over the current best observed value.
66 |
67 | The formula is:
68 |
69 | .. math::
70 |
71 | PI(x) = \Phi\left(\frac{\mu(x) - y^* - \xi}{\sigma(x)}\right)
72 |
73 | Args:
74 | x_pred: Candidate input locations to evaluate.
75 | xs: Observed inputs.
76 | ys: Observed function values.
77 | mask: Boolean mask indicating valid entries in `ys`.
78 | gp_params: Gaussian Process hyperparameters.
79 | xi: Improvement margin for sensitivity.
80 |
81 | Returns:
82 | PI scores at `x_pred`.
83 | """
84 | y_max = ys.max()
85 | mu, std = predict(gp_params, xs, ys, mask, xt=x_pred)
86 | z = (mu - y_max - xi) / std
87 | return norm.cdf(z)
88 |
89 |
90 | def upper_confidence_bounds(
91 | x_pred: jnp.ndarray,
92 | xs: jax.Array,
93 | ys: jax.Array,
94 | mask: jax.Array,
95 | gp_params: GPParams,
96 | kappa: float = 0.01,
97 | ):
98 | r"""
99 | Upper Confidence Bound (UCB) acquisition function.
100 |
101 | Promotes exploration by favoring points with high predictive uncertainty.
102 |
103 | The formula is:
104 |
105 | .. math::
106 |
107 | UCB(x) = \mu(x) + \kappa \cdot \sigma(x)
108 |
109 | Args:
110 | x_pred: Candidate input locations to evaluate.
111 | xs: Observed inputs.
112 | ys: Observed function values.
113 | mask: Boolean mask indicating valid entries in `ys`.
114 | gp_params: Gaussian Process hyperparameters.
115 | kappa: Weighting factor for uncertainty.
116 |
117 | Returns:
118 | UCB scores at `x_pred`.
119 | """
120 | mu, std = predict(gp_params, xs, ys, mask, xt=x_pred)
121 | return mu + kappa * std
122 |
123 |
124 | def lower_confidence_bounds(
125 | x_pred: jnp.ndarray,
126 | xs: jax.Array,
127 | ys: jax.Array,
128 | mask: jax.Array,
129 | gp_params: GPParams,
130 | kappa: float = 2.576,
131 | ):
132 | r"""
133 | Lower Confidence Bound (LCB) acquisition function.
134 |
135 | Useful for minimization tasks. Encourages sampling in uncertain regions
136 | with low predicted values.
137 |
138 | The formula is:
139 |
140 | .. math::
141 |
142 | LCB(x) = \mu(x) - \kappa \cdot \sigma(x)
143 |
144 | Args:
145 | x_pred: Candidate input locations to evaluate.
146 | xs: Observed inputs.
147 | ys: Observed function values.
148 | mask: Boolean mask indicating valid entries in `ys`.
149 | gp_params: Gaussian Process hyperparameters.
150 | kappa: Weighting factor for uncertainty.
151 |
152 | Returns:
153 | LCB scores at `x_pred`.
154 | """
155 | mu, std = predict(gp_params, xs, ys, mask, xt=x_pred)
156 | return mu - kappa * std
157 |
--------------------------------------------------------------------------------
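All four acquisition functions share the same signature, so they can be swapped freely. A small sketch mirroring `tests/test_acq.py`, with illustrative data and hyperparameters:

```python
import jax.numpy as jnp

from bayex.gp import GPParams
from bayex.acq import expected_improvement, upper_confidence_bounds

# Observed inputs/values and a mask marking every entry as valid.
x = jnp.linspace(-3, 3, 10)
y = jnp.sin(x)
mask = jnp.ones_like(x)

# Candidate locations to score, and raw GP hyperparameters.
xt = jnp.linspace(-4, 4, 5)
params = GPParams(noise=0.1, amplitude=1.0, lengthscale=1.0)

ei = expected_improvement(xt, x, y, mask, params)
ucb = upper_confidence_bounds(xt, x, y, mask, params, kappa=2.0)

# Higher scores mark more promising candidates under each criterion.
print(xt[jnp.argmax(ei)], xt[jnp.argmax(ucb)])
```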
/bayex/domain.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import jax
3 |
4 |
5 | class Domain:
6 | def __init__(self, dtype):
7 | self.dtype = dtype
8 |
9 | def __hash__(self):
10 | return hash(self.dtype)
11 |
12 | def __eq__(self, other):
13 | return self.dtype == other.dtype
14 |
15 | def transform(self, x: jax.Array):
16 | raise NotImplementedError
17 |
18 | def sample(self, key: jax.Array, shape: Tuple):
19 | raise NotImplementedError
20 |
21 |
22 | class Real(Domain):
23 | """
24 | Continuous real-valued domain with clipping.
25 |
26 | Represents a parameter that can take real values within [lower, upper].
27 | """
28 |
29 | def __init__(self, lower, upper):
30 | """
31 | Initializes a real domain with bounds.
32 |
33 | Args:
34 | lower: Lower bound (inclusive).
35 | upper: Upper bound (inclusive).
36 | """
37 | assert isinstance(lower, (float, int)), "Lower bound must be a float or an int"
38 | assert isinstance(upper, (float, int)), "Upper bound must be a float or an int"
39 | assert lower < upper, "Lower bound must be less than upper bound"
40 |
41 | self.lower = float(lower)
42 | self.upper = float(upper)
43 | super().__init__(dtype=jax.numpy.float32)
44 |
45 | def __hash__(self):
46 | return hash((self.lower, self.upper))
47 |
48 | def __eq__(self, other):
49 | return self.lower == other.lower and self.upper == other.upper
50 |
51 | def transform(self, x: jax.Array):
52 | """
53 | Clips values to the domain range [lower, upper].
54 |
55 | Args:
56 | x: Input values.
57 |
58 | Returns:
59 | Clipped values within bounds.
60 | """
61 | return jax.numpy.clip(x, self.lower, self.upper)
62 |
63 | def sample(self, key: jax.Array, shape: Tuple):
64 | """
65 | Samples uniformly from the domain.
66 |
67 | Args:
68 | key: JAX PRNGKey.
69 | shape: Desired output shape.
70 |
71 | Returns:
72 | Sampled values clipped to the domain.
73 | """
74 | samples = jax.random.uniform(key, shape, minval=self.lower, maxval=self.upper)
75 | return self.transform(samples)
76 |
77 |
78 | class Integer(Domain):
79 | """
80 | Discrete integer-valued domain with rounding and clipping.
81 |
82 | Represents a parameter that can take integer values within [lower, upper].
83 | """
84 |
85 | def __init__(self, lower, upper):
86 | """
87 | Initializes an integer domain with bounds.
88 |
89 | Args:
90 | lower: Lower integer bound (inclusive).
91 | upper: Upper integer bound (inclusive).
92 | """
93 | assert isinstance(lower, int), "Lower bound must be an integer"
94 | assert isinstance(upper, int), "Upper bound must be an integer"
95 | assert lower < upper, "Lower bound must be less than upper bound"
96 |
97 | self.lower = int(lower)
98 | self.upper = int(upper)
99 | super().__init__(dtype=jax.numpy.int32)
100 |
101 | def __hash__(self):
102 | return hash((self.lower, self.upper))
103 |
104 | def __eq__(self, other):
105 | return self.lower == other.lower and self.upper == other.upper
106 |
107 | def transform(self, x: jax.Array):
108 | """
109 | Rounds and clips values to the integer domain.
110 |
111 | Args:
112 | x: Input values.
113 |
114 | Returns:
115 | Rounded and clipped values as float32.
116 | """
117 | return jax.numpy.clip(jax.numpy.round(x), self.lower, self.upper).astype(jax.numpy.float32)
118 |
119 | def sample(self, key: jax.Array, shape: Tuple):
120 | """
121 | Samples integers uniformly from the domain.
122 |
123 | Args:
124 | key: JAX PRNGKey.
125 | shape: Desired output shape.
126 |
127 | Returns:
128 | Sampled values clipped to valid integer range.
129 | """
130 | samples = jax.random.randint(key, shape, minval=self.lower, maxval=self.upper + 1)
131 | return self.transform(samples)
132 |
133 |
134 | class ParamSpace:
135 | """
136 | Internal class that manages a collection of named parameter domains.
137 |
138 | This utility encapsulates logic for sampling, transforming, and handling
139 | structured parameter inputs defined by a mapping of variable names to Domain
140 | instances (e.g., Real, Integer).
141 |
142 | Example:
143 | >>> space = ParamSpace({
144 | ... "x1": Real(0.0, 1.0),
145 | ... "x2": Integer(1, 5)
146 | ... })
147 | >>> key = jax.random.PRNGKey(0)
148 | >>> samples = space.sample_params(key, (128,))
149 | >>> xs = space.to_array(samples)
150 |
151 | Notes:
152 | This class is intended for internal use by the optimizer and should not
153 | be exposed as part of the public API.
154 | """
155 |
156 | def __init__(self, space: dict):
157 | self.space = space
158 |
159 | def sample_params(self, key: jax.Array, shape: Tuple) -> dict:
160 | keys = jax.random.split(key, len(self.space))
161 | return {name: self.space[name].sample(k, shape) for name, k in zip(self.space, keys)}
162 |
163 | def to_array(self, tree: dict) -> jax.Array:
164 | """
165 | Transforms a batch of parameter values into a 2D array suitable for GP input.
166 |
167 | Applies each domain's `.transform()` to its corresponding parameter values.
168 |
169 | Args:
170 | tree: A dictionary of parameter name → array of raw values.
171 |
172 | Returns:
173 | A JAX array of shape (batch_size, num_params) with transformed values.
174 | """
175 | return jax.numpy.stack([self.space[k].transform(tree[k]) for k in self.space], axis=1)
176 |
177 |
178 | def to_dict(self, xs: jax.Array) -> dict:
179 | """
180 | Converts a stacked parameter matrix back into named parameter trees.
181 |
182 | Typically used after optimization in transformed space.
183 |
184 | Args:
185 | xs: A 2D JAX array of shape (batch_size, num_params), with each column
186 | corresponding to a parameter.
187 |
188 | Returns:
189 | A dictionary mapping parameter names to individual 1D arrays.
190 | """
191 | return {k: self.space[k].transform(xs[:, i]) for i, k in enumerate(self.space)}
--------------------------------------------------------------------------------
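A short sketch of the two public domains, mirroring `tests/test_domain.py` (bounds and shapes are illustrative). Note that `Integer.transform` intentionally returns float32 so integer parameters can flow through the GP:

```python
import jax
import jax.numpy as jnp

from bayex.domain import Real, Integer

key = jax.random.PRNGKey(0)

# Real: uniform samples clipped to [lower, upper].
xr = Real(0.0, 1.0).sample(key, (5,))

# Integer: uniform integers in [lower, upper], rounded/clipped and
# cast to float32 by transform().
xi = Integer(1, 5).sample(key, (5,))

# transform() clips out-of-range values back into the domain.
print(Real(0.0, 1.0).transform(jnp.array([-0.2, 0.5, 1.7])))  # [0.  0.5 1. ]
print(xr.dtype, xi.dtype)  # float32 float32
```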
/bayex/optimizer.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Union, NamedTuple, Callable
3 |
4 | import jax
5 | import jax.numpy as jnp
6 | import numpy as np
7 | import optax
8 |
9 | import bayex.acq as boacq
10 | from bayex.gp import GPParams, posterior_fit
11 | from bayex.domain import ParamSpace
12 |
13 |
14 | class OptimizerState(NamedTuple):
15 | """
16 | Container for the state of the Bayesian optimizer.
17 |
18 | Attributes:
19 | params (dict): Dictionary mapping parameter names to their
20 | corresponding padded JAX arrays of observed values.
21 | ys (jax.Array or np.ndarray): Array of objective values associated
22 | with the observed parameters. Includes padding.
23 | best_score (float): Best observed objective value so far.
24 | best_params (dict): Parameter configuration corresponding to
25 | the best_score.
26 | mask (jax.Array): Boolean array indicating which entries in
27 | `params` and `ys` are valid (i.e., not padding).
28 | gp_params (GPParams): Parameters of the Gaussian Process
29 | fitted to the observations.
30 | """
31 | params: dict
32 | ys: Union[jax.Array, np.ndarray]
33 | best_score: float
34 | best_params: dict
35 | mask: jax.Array
36 | gp_params: GPParams
37 |
38 |
39 | def _optimize_suggestion(params: jax.Array, fun: Callable, max_iter: int = 10):
40 | """
41 | Applies local optimization (L-BFGS) to a given starting point.
42 |
43 | This function refines candidate points proposed by the acquisition
44 | function by performing a fixed number of gradient-based optimization
45 | steps using Optax's L-BFGS optimizer.
46 |
47 | Args:
48 | params (jax.Array): Initial point in input space to optimize.
49 | fun (Callable): Objective function to maximize. It must return a
50 | scalar value and support automatic differentiation.
51 | max_iter (int): Maximum number of L-BFGS steps to apply.
52 |
53 | Returns:
54 | jax.Array: The optimized point after `max_iter` iterations.
55 | """
56 |
57 | # L-BFGS optimizer is used for minimization, so to maximize acquisition
58 | # we need to negate the function value.
59 | opt = optax.lbfgs()
60 | neg_fun = lambda x: -fun(x)
61 | value_and_grad_fun = optax.value_and_grad_from_state(neg_fun)
62 | def step(carry, _):
63 | params, state = carry
64 | value, grad = value_and_grad_fun(params, state=state)
65 | updates, state = opt.update(grad, state, params,
66 | value=value, grad=grad, value_fn=neg_fun)
67 | params = optax.apply_updates(params, updates)
68 | params = jnp.clip(params, -1e6, 1e6)
69 | return (params, state), None
70 |
71 | init_carry = (params, opt.init(params))
72 | (final_params, _), __ = jax.lax.scan(step, init_carry, None, length=max_iter)
73 | return final_params
74 |
75 |
76 | class Optimizer:
77 | """
78 | Bayesian optimizer using Gaussian Processes and acquisition functions.
79 |
80 | This class manages the optimization loop for expensive black-box functions
81 | by modeling them with a Gaussian Process and selecting samples via
82 | acquisition functions such as EI, PI, UCB, or LCB.
83 | """
84 |
85 | def __init__(self, domain: dict, acq: str = 'EI', maximize: bool = False):
86 | """
87 | Initializes the optimizer.
88 |
89 | Args:
90 | domain: A dict mapping parameter names to domain objects (e.g., Real, Integer).
91 | acq: Acquisition function ('EI', 'PI', 'UCB', 'LCB').
92 | maximize: Whether to maximize or minimize the objective.
93 | """
94 | self.domain = domain
95 | self.sign = 1 if maximize else -1
96 | self.param_space = ParamSpace(domain)
97 |
98 | if acq == 'EI':
99 | self.acq = jax.jit(boacq.expected_improvement)
100 | elif acq == 'PI':
101 | self.acq = jax.jit(boacq.probability_improvement)
102 | elif acq == 'UCB':
103 | self.acq = jax.jit(boacq.upper_confidence_bounds)
104 | elif acq == 'LCB':
105 | self.acq = jax.jit(boacq.lower_confidence_bounds)
106 | else:
107 | raise ValueError(f"Acquisition function {acq} is not implemented")
108 |
109 | def init(self, ys: jax.Array, params: dict, noise_scale: float = -8.0):
110 | """
111 | Initializes the optimizer state from initial data.
112 |
113 | Args:
114 | ys: Objective values for the initial parameters.
115 | params: Dict of parameter arrays (same keys as domain).
116 | noise_scale: Raw (pre-softplus) initial value of the GP noise parameter.
117 | Returns:
118 | Initialized OptimizerState.
119 | """
120 | # Create a padded jax array for each parameter and each score
121 | # in order to keep JAX recompilations at bay.
122 | num_entries = len(ys)
123 | pad_value = int(np.ceil(len(ys) / 10) * 10)
124 |
125 | # Convert to jax arrays if they are not already
126 | ys = jnp.asarray(ys)
127 | ys = self.sign * ys
128 | params = jax.tree.map(lambda x: jnp.asarray(x), params)
129 |
130 | # Define padded arrays for the inputs and the outputs
131 | mask = jnp.zeros(shape=(pad_value,), dtype=jnp.bool_).at[:num_entries].set(True)
132 | ys = jnp.zeros(shape=(pad_value,), dtype=ys.dtype).at[:num_entries].set(ys)
133 |
134 | _params = {}
135 | for key, entries in params.items():
136 | # Assert that the parameter is in the domain dictionary
137 | assert key in self.domain, f"Parameter {key} is not in the domain"
138 |
139 | # Get dtype from the domain and create a padded array
140 | dtype = self.domain[key].dtype
141 | values = jnp.zeros(shape=(pad_value,), dtype=dtype).at[:num_entries].set(entries)
142 | _params[key] = values
143 |
144 | # From the given observations, find the best one (maximum or minimum) and return the
145 | # initial optimizer state.
146 | best_score = float(jnp.max(ys[mask]))
147 | best_params_idx = jnp.argmax(ys[mask])
148 | best_params = jax.tree.map(lambda x: x[mask][best_params_idx], _params)
149 |
150 | # Initialize the gaussian processes state
151 | gp_params = GPParams(
152 | noise=jnp.full((1, 1), 1. * noise_scale),
153 | amplitude=jnp.zeros((1, 1)),
154 | lengthscale=jnp.zeros((1, len(_params)))
155 | )
156 |
157 | # Fit to the current observations
158 | xs = jnp.stack([self.domain[key].transform(_params[key]) for key in _params], axis=1)
159 | gp_params = posterior_fit(ys, xs, mask=mask, params=gp_params)
160 |
161 | ys = self.sign * ys
162 | best_score = self.sign * best_score
163 | opt_state = OptimizerState(params=_params, ys=ys, best_score=best_score,
164 | best_params=best_params, mask=mask, gp_params=gp_params)
165 |
166 | return opt_state
167 |
168 | @partial(jax.jit, static_argnames=('self', 'size'))
169 | def sample(self, key, state, size=10_000):
170 | """
171 | Samples new parameters using the acquisition function.
172 |
173 | Args:
174 | key: JAX PRNG key for random sampling.
175 | state: Current optimizer state.
176 | size: Number of candidate points to sample.
177 |
178 | Returns:
179 | Sampled parameters (dict) corresponding to the candidate with the
180 | highest acquisition value.
181 | """
182 | # Sample 'size' elements of each distribution.
183 | samples = self.param_space.sample_params(key, (size,))
184 | xs_samples = self.param_space.to_array(samples)
185 |
186 | # Prepare the data for the Gaussian process prediction.
187 | xs = self.param_space.to_array(state.params)
188 | ys = self.sign * state.ys
189 | mask = state.mask
190 | gp_params = state.gp_params
191 |
192 | # Compute the acquisition function values for the sampled points.
193 | acq_vals = self.acq(xs_samples, xs, ys, mask, gp_params)
194 |
195 | # Of those, find the best 50 points and optimize them using L-BFGS.
196 | top_idxs = jnp.argsort(acq_vals)[-50:]
197 | init_points = xs_samples[top_idxs]
198 | f = lambda x: jnp.squeeze(self.acq(x[None, :], xs, ys, mask, gp_params))
199 | optimized = jax.vmap(lambda x: _optimize_suggestion(x, f, max_iter=10))(init_points)
200 | opt_vals = self.acq(optimized, xs, ys, mask, gp_params)
201 |
202 | # Return the best point from the optimized points and the sampled points.
203 | all_points = jnp.concatenate((optimized, xs_samples), axis=0)
204 | all_vals = jnp.concatenate((opt_vals, acq_vals), axis=0)
205 | chosen_suggestion = jnp.argmax(all_vals)
206 |
207 | best_params = self.param_space.to_dict(all_points[chosen_suggestion][None])
208 | return best_params
209 |
210 |
211 | def expand(self, opt_state: OptimizerState):
212 | """
213 | Expands internal buffers if no space is available.
214 |
215 | Args:
216 | opt_state: Current optimizer state.
217 |
218 | Returns:
219 | OptimizerState with expanded storage.
220 | """
221 | current = jnp.sum(opt_state.mask)
222 |
223 | if current == len(opt_state.mask):
224 | pad_value = int(np.ceil(len(opt_state.mask)*2 / 10) * 10)
225 | diff = pad_value - len(opt_state.mask)
226 | mask = jnp.pad(opt_state.mask, (0, diff))
227 | ys = jnp.pad(opt_state.ys, (0, diff))
228 | params = {}
229 | for key in opt_state.params:
230 | params[key] = jnp.pad(opt_state.params[key], (0, diff))
231 | else:
232 | mask = opt_state.mask
233 | ys = opt_state.ys
234 | params = opt_state.params
235 |
236 | opt_state = OptimizerState(params=params, ys=ys, best_score=opt_state.best_score,
237 | best_params=opt_state.best_params, mask=mask,
238 | gp_params=opt_state.gp_params)
239 | return opt_state
240 |
241 |
242 | def fit(self, opt_state, y, new_params):
243 | """
244 | Updates optimizer state with a new observation.
245 |
246 | Args:
247 | opt_state: Current optimizer state.
248 | y: New objective value.
249 | new_params: Parameters that produced y.
250 |
251 | Returns:
252 | Updated OptimizerState.
253 | """
254 | opt_state = self.expand(opt_state)  # May trigger recompilation when buffers grow
255 | opt_state = self._fit(opt_state, y, new_params)
256 | return opt_state
257 |
258 |
259 | @partial(jax.jit, static_argnums=(0,))
260 | def _fit(self, opt_state, y, new_params):
261 | last_idx = jnp.arange(len(opt_state.mask)) == jnp.argmin(opt_state.mask)
262 | mask = jnp.asarray(jnp.where(last_idx, True, opt_state.mask))
263 | ys = jnp.where(last_idx, y, opt_state.ys)
264 | params = jax.tree_util.tree_map(lambda x, y: jnp.where(last_idx, y, x), opt_state.params, new_params)
265 |
266 | xs = jnp.stack([self.domain[key].transform(params[key])
267 | for key in params], axis=1)
268 | ys = self.sign * ys
269 | gp_params = posterior_fit(ys, xs, mask=mask, params=opt_state.gp_params)
270 |
271 | best_score = jnp.max(ys, where=mask, initial=-jnp.inf)
272 | best_params_idx = jnp.argmax(jnp.where(mask, ys, -jnp.inf))
273 | best_params = jax.tree_util.tree_map(lambda x: x[best_params_idx], params)
274 |
275 | ys = self.sign * ys
276 | best_score = self.sign * best_score
277 | opt_state = OptimizerState(params=params, ys=ys, best_score=best_score,
278 | best_params=best_params, mask=mask,
279 | gp_params=gp_params)
280 | return opt_state
281 |
282 |
--------------------------------------------------------------------------------
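The padding scheme described in `Optimizer.init` and `expand` can be observed directly: buffers are allocated in blocks of ten and grow (to the next multiple of ten at or above twice the current length) only once full, which bounds how often the jitted `_fit` and `sample` recompile. A small sketch with illustrative values:

```python
import bayex

domain = {'x': bayex.domain.Real(-2.0, 2.0)}
opt = bayex.Optimizer(domain=domain, maximize=True)

# Three observations are padded into a block of ten slots.
opt_state = opt.init([0.1, 0.2, 0.3], {'x': [0.0, 0.5, 1.0]})
print(opt_state.ys.shape)         # (10,)
print(int(opt_state.mask.sum()))  # 3 valid entries

# expand() is a no-op until the buffer is full.
expanded = opt.expand(opt_state)
print(expanded.ys.shape)          # still (10,)
```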