├── docs ├── requirements.txt └── source │ ├── _static │ └── bayex.png │ ├── api │ ├── generated │ │ ├── bayex.gp.predict.rst │ │ ├── bayex.gp.grad_fun.rst │ │ ├── bayex.gp.posterior_fit.rst │ │ ├── bayex.gp.gaussian_process.rst │ │ ├── bayex.gp.marginal_likelihood.rst │ │ ├── bayex.acq.expected_improvement.rst │ │ ├── bayex.acq.lower_confidence_bounds.rst │ │ ├── bayex.acq.probability_improvement.rst │ │ ├── bayex.acq.upper_confidence_bounds.rst │ │ ├── bayex.domain.Real.rst │ │ ├── bayex.domain.Domain.rst │ │ ├── bayex.domain.Integer.rst │ │ ├── bayex.optimizer.Optimizer.rst │ │ ├── bayex.gp.GPState.rst │ │ └── bayex.gp.GPParams.rst │ ├── index.rst │ ├── bayex.domain.rst │ ├── bayex.optimizer.rst │ ├── bayex.gp.rst │ └── bayex.acq.rst │ ├── conf.py │ └── index.rst ├── bayex ├── __init__.py ├── gp.py ├── acq.py ├── domain.py └── optimizer.py ├── .readthedocs.yml ├── .github └── workflows │ ├── tests.yml │ └── python-publish.yml ├── LICENSE ├── pyproject.toml ├── tests ├── test_acq.py ├── test_optimizer.py ├── test_domain.py └── test_gp.py ├── README.md └── .gitignore /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx-book-theme 3 | myst-parser 4 | sphinx_autodoc_typehints 5 | -------------------------------------------------------------------------------- /docs/source/_static/bayex.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alonfnt/bayex/HEAD/docs/source/_static/bayex.png -------------------------------------------------------------------------------- /docs/source/api/generated/bayex.gp.predict.rst: -------------------------------------------------------------------------------- 1 | bayex.gp.predict 2 | ================ 3 | 4 | .. currentmodule:: bayex.gp 5 | 6 | .. autofunction:: predict -------------------------------------------------------------------------------- /docs/source/api/generated/bayex.gp.grad_fun.rst: -------------------------------------------------------------------------------- 1 | bayex.gp.grad\_fun 2 | ================== 3 | 4 | .. currentmodule:: bayex.gp 5 | 6 | .. autofunction:: grad_fun -------------------------------------------------------------------------------- /docs/source/api/generated/bayex.gp.posterior_fit.rst: -------------------------------------------------------------------------------- 1 | bayex.gp.posterior\_fit 2 | ======================= 3 | 4 | .. currentmodule:: bayex.gp 5 | 6 | .. autofunction:: posterior_fit -------------------------------------------------------------------------------- /docs/source/api/generated/bayex.gp.gaussian_process.rst: -------------------------------------------------------------------------------- 1 | bayex.gp.gaussian\_process 2 | ========================== 3 | 4 | .. currentmodule:: bayex.gp 5 | 6 | .. autofunction:: gaussian_process -------------------------------------------------------------------------------- /docs/source/api/generated/bayex.gp.marginal_likelihood.rst: -------------------------------------------------------------------------------- 1 | bayex.gp.marginal\_likelihood 2 | ============================= 3 | 4 | .. currentmodule:: bayex.gp 5 | 6 | .. 
autofunction:: marginal_likelihood -------------------------------------------------------------------------------- /docs/source/api/generated/bayex.acq.expected_improvement.rst: -------------------------------------------------------------------------------- 1 | bayex.acq.expected\_improvement 2 | =============================== 3 | 4 | .. currentmodule:: bayex.acq 5 | 6 | .. autofunction:: expected_improvement -------------------------------------------------------------------------------- /bayex/__init__.py: -------------------------------------------------------------------------------- 1 | from .optimizer import Optimizer 2 | from . import domain, acq 3 | 4 | 5 | __version__ = "0.2.2" 6 | 7 | __all__ = [ 8 | "Optimizer", 9 | "domain", 10 | "acq", 11 | ] 12 | -------------------------------------------------------------------------------- /docs/source/api/generated/bayex.acq.lower_confidence_bounds.rst: -------------------------------------------------------------------------------- 1 | bayex.acq.lower\_confidence\_bounds 2 | =================================== 3 | 4 | .. currentmodule:: bayex.acq 5 | 6 | .. autofunction:: lower_confidence_bounds -------------------------------------------------------------------------------- /docs/source/api/generated/bayex.acq.probability_improvement.rst: -------------------------------------------------------------------------------- 1 | bayex.acq.probability\_improvement 2 | ================================== 3 | 4 | .. currentmodule:: bayex.acq 5 | 6 | .. autofunction:: probability_improvement -------------------------------------------------------------------------------- /docs/source/api/generated/bayex.acq.upper_confidence_bounds.rst: -------------------------------------------------------------------------------- 1 | bayex.acq.upper\_confidence\_bounds 2 | =================================== 3 | 4 | .. currentmodule:: bayex.acq 5 | 6 | .. autofunction:: upper_confidence_bounds -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-24.04 5 | tools: 6 | python: "3.12" 7 | 8 | sphinx: 9 | configuration: docs/source/conf.py 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt 14 | - method: pip 15 | path: . 16 | -------------------------------------------------------------------------------- /docs/source/api/index.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | bayex.acq 8 | bayex.domain 9 | bayex.optimizer 10 | 11 | Module contents 12 | --------------- 13 | 14 | .. automodule:: bayex 15 | :members: 16 | :show-inheritance: 17 | :undoc-members: 18 | -------------------------------------------------------------------------------- /docs/source/api/bayex.domain.rst: -------------------------------------------------------------------------------- 1 | ``bayex.domain`` module 2 | ----------------------- 3 | 4 | .. currentmodule:: bayex.domain 5 | 6 | .. autosummary:: 7 | :toctree: generated/ 8 | :nosignatures: 9 | 10 | Real 11 | Integer 12 | 13 | .. 
automodule:: bayex.domain 14 | :members: 15 | :undoc-members: 16 | :show-inheritance: 17 | :no-index: 18 | 19 | -------------------------------------------------------------------------------- /docs/source/api/bayex.optimizer.rst: -------------------------------------------------------------------------------- 1 | ``bayex.optimizer`` module 2 | -------------------------- 3 | 4 | .. currentmodule:: bayex.optimizer 5 | 6 | .. autosummary:: 7 | :toctree: generated/ 8 | :nosignatures: 9 | 10 | Optimizer 11 | 12 | .. automodule:: bayex.optimizer 13 | :members: 14 | :undoc-members: 15 | :show-inheritance: 16 | :no-index: 17 | :exclude-members: _fit, expand 18 | -------------------------------------------------------------------------------- /docs/source/api/generated/bayex.domain.Real.rst: -------------------------------------------------------------------------------- 1 | bayex.domain.Real 2 | ================= 3 | 4 | .. currentmodule:: bayex.domain 5 | 6 | .. autoclass:: Real 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~Real.__init__ 17 | ~Real.sample 18 | ~Real.transform 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /docs/source/api/generated/bayex.domain.Domain.rst: -------------------------------------------------------------------------------- 1 | bayex.domain.Domain 2 | =================== 3 | 4 | .. currentmodule:: bayex.domain 5 | 6 | .. autoclass:: Domain 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~Domain.__init__ 17 | ~Domain.sample 18 | ~Domain.transform 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /docs/source/api/generated/bayex.domain.Integer.rst: -------------------------------------------------------------------------------- 1 | bayex.domain.Integer 2 | ==================== 3 | 4 | .. currentmodule:: bayex.domain 5 | 6 | .. autoclass:: Integer 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~Integer.__init__ 17 | ~Integer.sample 18 | ~Integer.transform 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /docs/source/api/bayex.gp.rst: -------------------------------------------------------------------------------- 1 | ``bayex.gp`` module 2 | ------------------- 3 | 4 | .. currentmodule:: bayex.gp 5 | 6 | .. autosummary:: 7 | :toctree: generated/ 8 | :nosignatures: 9 | 10 | GPParams 11 | GPState 12 | marginal_likelihood 13 | predict 14 | posterior_fit 15 | gaussian_process 16 | grad_fun 17 | 18 | .. automodule:: bayex.gp 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | :no-index: 23 | -------------------------------------------------------------------------------- /docs/source/api/generated/bayex.optimizer.Optimizer.rst: -------------------------------------------------------------------------------- 1 | bayex.optimizer.Optimizer 2 | ========================= 3 | 4 | .. currentmodule:: bayex.optimizer 5 | 6 | .. autoclass:: Optimizer 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. 
autosummary:: 15 | 16 | ~Optimizer.__init__ 17 | ~Optimizer.expand 18 | ~Optimizer.fit 19 | ~Optimizer.init 20 | ~Optimizer.sample 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/source/api/generated/bayex.gp.GPState.rst: -------------------------------------------------------------------------------- 1 | bayex.gp.GPState 2 | ================ 3 | 4 | .. currentmodule:: bayex.gp 5 | 6 | .. autoclass:: GPState 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~GPState.__init__ 17 | ~GPState.count 18 | ~GPState.index 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Attributes 25 | 26 | .. autosummary:: 27 | 28 | ~GPState.momentums 29 | ~GPState.params 30 | ~GPState.scales 31 | 32 | -------------------------------------------------------------------------------- /docs/source/api/generated/bayex.gp.GPParams.rst: -------------------------------------------------------------------------------- 1 | bayex.gp.GPParams 2 | ================= 3 | 4 | .. currentmodule:: bayex.gp 5 | 6 | .. autoclass:: GPParams 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~GPParams.__init__ 17 | ~GPParams.count 18 | ~GPParams.index 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Attributes 25 | 26 | .. autosummary:: 27 | 28 | ~GPParams.amplitude 29 | ~GPParams.lengthscale 30 | ~GPParams.noise 31 | 32 | -------------------------------------------------------------------------------- /docs/source/api/bayex.acq.rst: -------------------------------------------------------------------------------- 1 | ``bayex.acq`` module 2 | -------------------- 3 | 4 | This module contains the implementation of various acquisition functions used in Bayesian optimization. 5 | These functions help in selecting the next point to sample based on the current model of the objective function. 6 | 7 | .. currentmodule:: bayex.acq 8 | 9 | .. autosummary:: 10 | :toctree: generated/ 11 | :nosignatures: 12 | 13 | expected_improvement 14 | probability_improvement 15 | upper_confidence_bounds 16 | lower_confidence_bounds 17 | 18 | .. automodule:: bayex.acq 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | :no-index: 23 | 24 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | name: "Python ${{ matrix.python-version }}" 12 | runs-on: ubuntu-latest 13 | 14 | strategy: 15 | matrix: 16 | python-version: ["3.11", "3.12", "3.13"] 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | 21 | - uses: actions/setup-python@v4 22 | with: 23 | python-version: "${{ matrix.python-version }}" 24 | 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | python -m pip install -e . pytest 29 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 30 | 31 | - name: Test with pytest 32 | run: | 33 | pytest 34 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # Workflow that publishes the last release to PyPI. 
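# Publishing relies on PyPI "trusted publishing": the `id-token: write`
# permission below lets pypa/gh-action-pypi-publish authenticate via OIDC,
# so no long-lived API-token secret is needed in the repository.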
2 | 3 | name: Upload Python Package 4 | 5 | on: 6 | release: 7 | types: [published] 8 | 9 | workflow_dispatch: 10 | 11 | jobs: 12 | pypi-publish: 13 | name: Upload release to PyPI 14 | runs-on: ubuntu-latest 15 | environment: 16 | name: pypi 17 | url: https://pypi.org/p/bayex 18 | permissions: 19 | id-token: write 20 | steps: 21 | 22 | - uses: actions/checkout@v3 23 | 24 | - name: Set up Python 25 | uses: actions/setup-python@v3 26 | with: 27 | python-version: '3.x' 28 | 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | 34 | - name: Build package 35 | run: python -m build 36 | 37 | - name: Publish package distributions to PyPI 38 | uses: pypa/gh-action-pypi-publish@release/v1 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Albert Alonso 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "bayex" 7 | version = "0.2.2" 8 | description = "Minimal Bayesian Optimization in JAX using Gaussian Processes." 
9 | readme = "README.md" 10 | authors = [{name = "Albert Alonso", email = "alonfnt@pm.me"}, ] 11 | license = {file = "LICENSE"} 12 | keywords = ["jax", "bayesian-optimization", "automatic-differentiation", "gaussian-process", "machine-learning"] 13 | classifiers = [ 14 | "Development Status :: 3 - Alpha", 15 | "Intended Audience :: Science/Research", 16 | "Topic :: Scientific/Engineering", 17 | "License :: OSI Approved :: MIT License", 18 | "Programming Language :: Python :: 3", 19 | ] 20 | dependencies = ["jax", "jaxlib", "numpy", "optax"] 21 | 22 | [project.optional-dependencies] 23 | dev = ["pytest"] 24 | 25 | [tool.setuptools] 26 | packages = ["bayex"] 27 | 28 | [project.urls] 29 | "Homepage" = "https://github.com/alonfnt/bayex" 30 | "Documentation" = "https://github.com/alonfnt/bayex" 31 | "Source" = "https://github.com/alonfnt/bayex" 32 | "Bug Tracker" = "https://github.com/alonfnt/bayex/issues" 33 | -------------------------------------------------------------------------------- /tests/test_acq.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import jax 3 | import jax.numpy as jnp 4 | 5 | from bayex.gp import GPParams 6 | from bayex.acq import ( 7 | expected_improvement, 8 | probability_improvement, 9 | upper_confidence_bounds, 10 | lower_confidence_bounds, 11 | ) 12 | 13 | 14 | @pytest.fixture 15 | def gp_inputs(): 16 | x = jnp.linspace(-3, 3, 10) 17 | y = jnp.sin(x) 18 | mask = jnp.ones_like(x) 19 | xt = jnp.linspace(-4, 4, 5) 20 | params = GPParams(noise=0.1, amplitude=1.0, lengthscale=1.0) 21 | return xt, x, y, mask, params 22 | 23 | 24 | @pytest.mark.parametrize("padding", [0, 3]) 25 | def test_acquisition_invariance_to_padding(gp_inputs, padding): 26 | xt, x, y, mask, params = gp_inputs 27 | ei_ref = expected_improvement(xt, x, y, mask, params) 28 | 29 | x_pad, y_pad, mask_pad = jax.tree.map(lambda t: jnp.pad(t, (0, padding)), (x, y, mask)) 30 | ei_pad = expected_improvement(xt, x_pad, y_pad, mask_pad, params) 31 | 32 | assert ei_pad.shape == ei_ref.shape 33 | assert jnp.allclose(ei_ref, ei_pad, atol=1e-5) 34 | 35 | 36 | @pytest.mark.parametrize("acq_fn", [ 37 | expected_improvement, 38 | probability_improvement, 39 | upper_confidence_bounds, 40 | lower_confidence_bounds, 41 | ]) 42 | def test_acquisition_output_validity(acq_fn, gp_inputs): 43 | xt, x, y, mask, params = gp_inputs 44 | acq_vals = acq_fn(xt, x, y, mask, params) 45 | 46 | assert acq_vals.shape == xt.shape 47 | assert jnp.all(jnp.isfinite(acq_vals)) 48 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | import os 6 | import sys 7 | 8 | sys.path.insert(0, os.path.abspath('../../bayex')) 9 | 10 | # -- Project information ----------------------------------------------------- 11 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 12 | 13 | project = 'bayex' 14 | copyright = '2025, Albert Alonso' 15 | author = 'Albert Alonso' 16 | release = '0.2.2' 17 | 18 | # -- General configuration --------------------------------------------------- 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 20 | 21 | extensions = [ 22 | 'myst_parser', 23 | 'sphinx.ext.autodoc', 24 | 'sphinx.ext.napoleon', 25 | 'sphinx.ext.autosummary', 26 | 'sphinx_autodoc_typehints', 27 | 'sphinx.ext.mathjax', 28 | ] 29 | 30 | templates_path = ['_templates'] 31 | exclude_patterns = [] 32 | 33 | autodoc_default_options = { 34 | 'members': True, 35 | 'undoc-members': False, 36 | 'private-members': False, 37 | 'special-members': False, 38 | 'inherited-members': True, 39 | } 40 | 41 | 42 | 43 | # -- Options for HTML output ------------------------------------------------- 44 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 45 | 46 | html_theme = 'sphinx_book_theme' 47 | html_static_path = ["_static"] 48 | html_logo = "_static/bayex.png" 49 | -------------------------------------------------------------------------------- /tests/test_optimizer.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import pytest 5 | 6 | import bayex 7 | 8 | KEY = jax.random.key(42) 9 | SEED = 42 10 | 11 | 12 | def test_1D_optim(): 13 | 14 | def f(x): 15 | return jnp.sin(x) - 0.1 * x**2 16 | 17 | TARGET = np.max(f(np.linspace(-2, 2, 1000))) 18 | 19 | domain = {'x': bayex.domain.Real(-2, 2)} 20 | 21 | opt = bayex.Optimizer(domain=domain, maximize=True) 22 | 23 | # Evaluate 3 times the function 24 | params = {'x': [0.0, 1.0, 2.0]} 25 | ys = [f(x) for x in params['x']] 26 | opt_state = opt.init(ys, params) 27 | 28 | assert opt_state.best_score == np.max(ys) 29 | assert opt_state.best_params['x'] == params['x'][np.argmax(ys)] 30 | 31 | assert type(opt_state) == bayex.optimizer.OptimizerState 32 | assert type(opt_state.gp_params) == bayex.gp.GPParams # pyright: ignore 33 | 34 | assert opt_state.params['x'].shape == (10,) 35 | assert opt_state.ys.shape == (10,) 36 | assert opt_state.mask.shape == (10,) 37 | 38 | assert np.allclose(opt_state.params['x'][:3], params['x']) 39 | assert np.allclose(opt_state.ys[:3], ys) 40 | assert np.allclose(opt_state.mask[:3], [True, True, True]) 41 | assert np.allclose(opt_state.mask[3:], [False] * 7) 42 | 43 | key = jax.random.key(SEED) 44 | 45 | sample_fn = jax.jit(opt.sample) 46 | for step in range(100): 47 | key = jax.random.fold_in(key, step) 48 | new_params = sample_fn(key, opt_state) 49 | y = f(**new_params) 50 | opt_state = opt.fit(opt_state, y, new_params) 51 | if jnp.allclose(opt_state.best_score, TARGET, atol=1e-03): 52 | break 53 | target = opt_state.best_score 54 | assert jnp.allclose(TARGET, target, atol=1e-02) 55 | 56 | 57 | def test_evaluate_raise_invalid_acq_fun(): 58 | domain = {'x': bayex.domain.Real(-2, 2)} 59 | 60 | # This shouldn't raise an error 61 | for acq in ['EI',]: 62 | bayex.Optimizer(domain=domain, acq=acq) 63 | 64 | # But this 
should! 65 | for acq in ['random', 'magic']: 66 | with pytest.raises(ValueError, match=f"Acquisition function {acq} is not implemented"): 67 | bayex.Optimizer(domain=domain, acq=acq) 68 | -------------------------------------------------------------------------------- /tests/test_domain.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import jax 3 | import jax.numpy as jnp 4 | from bayex.domain import Real, Integer 5 | 6 | 7 | @pytest.mark.parametrize("lower, upper", [(0.0, 1.0), (-5.5, 5.5)]) 8 | def test_real_transform_clips_within_bounds(lower, upper): 9 | domain = Real(lower, upper) 10 | x = jnp.array([lower - 1, (lower + upper) / 2, upper + 1]) 11 | out = domain.transform(x) 12 | assert jnp.all(out >= lower) and jnp.all(out <= upper) 13 | 14 | 15 | def test_real_sample_within_bounds(): 16 | domain = Real(0.0, 1.0) 17 | key = jax.random.PRNGKey(0) 18 | shape = (100,) 19 | samples = domain.sample(key, shape) 20 | assert samples.shape == shape 21 | assert samples.dtype == jnp.float32 22 | assert jnp.all(samples >= 0.0) and jnp.all(samples <= 1.0) 23 | 24 | 25 | @pytest.mark.parametrize("lower, upper", [(0, 5), (-3, 3)]) 26 | def test_integer_transform_and_type(lower, upper): 27 | domain = Integer(lower, upper) 28 | x = jnp.array([lower - 2.3, (lower + upper) / 2.0, upper + 2.7]) 29 | out = domain.transform(x) 30 | assert jnp.issubdtype(out.dtype, jnp.floating) 31 | assert jnp.all(out >= lower) and jnp.all(out <= upper) 32 | assert jnp.all(out == jnp.round(out)) 33 | 34 | 35 | def test_integer_sample_within_bounds(): 36 | domain = Integer(1, 10) 37 | key = jax.random.PRNGKey(42) 38 | shape = (1000,) 39 | samples = domain.sample(key, shape) 40 | assert samples.shape == shape 41 | assert jnp.all(samples >= 1) and jnp.all(samples <= 10) 42 | assert jnp.all(jnp.equal(samples, jnp.round(samples))) 43 | 44 | 45 | def test_domain_equality_and_hash(): 46 | a = Real(0.0, 1.0) 47 | b = Real(0.0, 1.0) 48 | c = Real(1.0, 2.0) 49 | assert a == b 50 | assert a != c 51 | assert hash(a) == hash(b) 52 | assert hash(a) != hash(c) 53 | 54 | i1 = Integer(1, 5) 55 | i2 = Integer(1, 5) 56 | i3 = Integer(0, 4) 57 | assert i1 == i2 58 | assert i1 != i3 59 | assert hash(i1) == hash(i2) 60 | assert hash(i1) != hash(i3) 61 | 62 | 63 | @pytest.mark.parametrize("lower, upper", [ 64 | (1.0, 1.0), # equal 65 | ("a", 1.0), # wrong type 66 | (1.0, "b"), # wrong type 67 | (5.0, 3.0), # lower > upper 68 | ]) 69 | def test_real_init_invalid(lower, upper): 70 | with pytest.raises(AssertionError): 71 | Real(lower, upper) 72 | 73 | 74 | @pytest.mark.parametrize("lower, upper", [ 75 | (5, 5), # equal 76 | ("a", 5), # wrong type 77 | (0, "b"), # wrong type 78 | (10, 1), # lower > upper 79 | ]) 80 | def test_integer_init_invalid(lower, upper): 81 | with pytest.raises(AssertionError): 82 | Integer(lower, upper) 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | [![tests](https://github.com/alonfnt/bayex/actions/workflows/tests.yml/badge.svg)](https://github.com/alonfnt/bayex/actions/workflows/tests.yml) 6 | [![Docs](https://readthedocs.org/projects/bayex/badge/?version=latest)](https://bayex.readthedocs.io/en/latest/) 7 | [![PyPI](https://img.shields.io/pypi/v/bayex.svg)](https://pypi.org/project/bayex/) 8 | 9 | >[!NOTE] 10 | >Bayex is currently a minimal, personally developed implementation that requires further development for broader application. If you're interested in engaging with Jax and enhancing Bayex, your contributions would be highly welcomed and appreciated. 11 | 12 | [**Installation**](#installation) 13 | | [**Usage**](#usage) 14 | | [**Reference docs**](https://bayex.readthedocs.io/en/latest/) 15 | 16 | Bayex is a lightweight Bayesian optimization library designed for efficiency and flexibility, leveraging the power of JAX for high-performance numerical computations. 17 | This library aims to provide an easy-to-use interface for optimizing expensive-to-evaluate functions through Gaussian Process (GP) models and various acquisition functions. Whether you're maximizing or minimizing your objective function, Bayex offers a simple yet powerful set of tools to guide your search for optimal parameters. 18 | 19 |

20 | 21 |

22 | 23 | 24 | ## Installation 25 | Bayex can be installed using [PyPI](https://pypi.org/project/bayex/) via `pip`: 26 | ``` 27 | pip install bayex 28 | ``` 29 | 30 | ## Usage 31 | Using Bayex is quite simple despite its low level approach: 32 | ```python 33 | import jax 34 | import numpy as np 35 | import bayex 36 | 37 | def f(x): 38 | return -(1.4 - 3 * x) * np.sin(18 * x) 39 | 40 | domain = {'x': bayex.domain.Real(0.0, 2.0)} 41 | optimizer = bayex.Optimizer(domain=domain, maximize=True, acq='PI') 42 | 43 | # Define some prior evaluations to initialise the GP. 44 | params = {'x': [0.0, 0.5, 1.0]} 45 | ys = [f(x) for x in params['x']] 46 | opt_state = optimizer.init(ys, params) 47 | 48 | # Sample new points using Jax PRNG approach. 49 | ori_key = jax.random.key(42) 50 | for step in range(20): 51 | key = jax.random.fold_in(ori_key, step) 52 | new_params = optimizer.sample(key, opt_state) 53 | y_new = f(**new_params) 54 | opt_state = optimizer.fit(opt_state, y_new, new_params) 55 | ``` 56 | 57 | with the results being saved at `opt_state.best_params`. 58 | 59 | ## Documentation 60 | Available at [https://bayex.readthedocs.io/en/latest](https://bayex.readthedocs.io/en/latest/). 61 | 62 | ## License 63 | Bayex is licensed under the MIT License. See the [LICENSE](LICENSE) file for more details. 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Visual studio folks 121 | .vscode/ 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # pytype static type analyzer 138 | .pytype/ 139 | 140 | # Cython debug symbols 141 | cython_debug/ 142 | 143 | # Poetry 144 | .python-version 145 | 146 | # CTags 147 | /tags* 148 | 149 | # Vim swap files 150 | *.swp 151 | *.swo 152 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Bayex 2 | =========================================== 3 | 4 | Minimal Bayesian Optimization in JAX 5 | ------------------------------------ 6 | 7 | .. image:: https://github.com/alonfnt/bayex/actions/workflows/tests.yml/badge.svg 8 | :target: https://github.com/alonfnt/bayex/actions/workflows/tests.yml 9 | :alt: Test status 10 | 11 | .. note:: 12 | Bayex is currently a minimal, personally developed implementation that requires further development for broader application. 13 | If you're interested in engaging with JAX and enhancing Bayex, your contributions would be highly welcomed and appreciated. 14 | 15 | .. raw:: html 16 | 17 |

18 | 19 |

20 | 21 | Bayex is a lightweight Bayesian optimization library designed for efficiency and flexibility, leveraging the power of JAX for high-performance numerical computations. 22 | 23 | This library aims to provide an easy-to-use interface for optimizing expensive-to-evaluate functions through Gaussian Process (GP) models and various acquisition functions. Whether you're maximizing or minimizing your objective function, Bayex offers a simple yet powerful set of tools to guide your search for optimal parameters. 24 | 25 | Installation 26 | ------------ 27 | 28 | Bayex can be installed using `PyPI `_ via ``pip``: 29 | 30 | .. code-block:: bash 31 | 32 | pip install bayex 33 | 34 | Usage 35 | ----- 36 | 37 | Using Bayex is quite simple despite its low-level approach: 38 | 39 | .. code-block:: python 40 | 41 | import jax 42 | import numpy as np 43 | import bayex 44 | 45 | def f(x): 46 | return -(1.4 - 3 * x) * np.sin(18 * x) 47 | 48 | domain = {'x': bayex.domain.Real(0.0, 2.0)} 49 | optimizer = bayex.Optimizer(domain=domain, maximize=True, acq='PI') 50 | 51 | # Define some prior evaluations to initialise the GP. 52 | params = {'x': [0.0, 0.5, 1.0]} 53 | ys = [f(x) for x in params['x']] 54 | opt_state = optimizer.init(ys, params) 55 | 56 | # Sample new points using Jax PRNG approach. 57 | ori_key = jax.random.key(42) 58 | for step in range(20): 59 | key = jax.random.fold_in(ori_key, step) 60 | new_params = optimizer.sample(key, opt_state) 61 | y_new = f(**new_params) 62 | opt_state = optimizer.fit(opt_state, y_new, new_params) 63 | 64 | With the results being saved at ``opt_state``. 65 | 66 | Contributing 67 | ------------ 68 | 69 | We welcome contributions to Bayex! Whether it's adding new features, improving documentation, or reporting issues, please feel free to make a pull request or open an issue. 70 | 71 | License 72 | ------- 73 | 74 | Bayex is licensed under the MIT License. See the `LICENSE `_ file for more details. 75 | 76 | .. 
toctree:: 77 | :maxdepth: 2 78 | :caption: API Reference 79 | 80 | api/index 81 | -------------------------------------------------------------------------------- /tests/test_gp.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import pytest 4 | from bayex import gp 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "x1, x2, mask, expected", 9 | [ 10 | (jnp.array([0]), jnp.array([0]), 1, 1), 11 | (jnp.array([0]), jnp.array([1]), 1, jnp.exp(-1)), 12 | (jnp.array([1]), jnp.array([1]), 0, 0), 13 | ], 14 | ) 15 | def test_exp_quadratic(x1, x2, mask, expected): 16 | result = gp.exp_quadratic(x1, x2, mask) 17 | assert jnp.isclose(result, expected), f"Expected {expected}, got {result}" 18 | 19 | 20 | @pytest.mark.parametrize( 21 | "x1, x2, mask1, mask2, expected_shape", 22 | [ 23 | (jnp.linspace(-5, 5, 10), jnp.linspace(-5, 5, 10), jnp.ones(10), jnp.ones(10), (10, 10)), 24 | (jnp.linspace(-5, 5, 5), jnp.linspace(-5, 5, 10), jnp.ones(5), jnp.ones(10), (5, 10)), 25 | ], 26 | ) 27 | def test_covariance_shape(x1, x2, mask1, mask2, expected_shape): 28 | cov_matrix = gp.cov(x1, x2, mask1, mask2) 29 | assert ( 30 | cov_matrix.shape == expected_shape 31 | ), f"Expected shape {expected_shape}, got {cov_matrix.shape}" 32 | 33 | 34 | @pytest.mark.parametrize( 35 | "compute_ml, expected_output_type", 36 | [ 37 | (False, tuple), # Expecting mean and std as output 38 | (True, jnp.ndarray), # Expecting marginal likelihood as output 39 | ], 40 | ) 41 | def test_gaussian_process_output_type(compute_ml, expected_output_type): 42 | params = gp.GPParams(noise=0.1, amplitude=1.0, lengthscale=1.0) 43 | x = jnp.linspace(-5, 5, 10) 44 | y = jnp.sin(x) 45 | mask = jnp.ones_like(x, dtype=bool) 46 | xt = jnp.array([0.0]) if not compute_ml else None 47 | 48 | output = gp.gaussian_process(params, x, y, mask, xt, compute_ml=compute_ml) 49 | assert isinstance( 50 | output, expected_output_type 51 | ), f"Output type mismatch. 
Expected: {expected_output_type}, got: {type(output)}" 52 | 53 | 54 | @pytest.mark.parametrize("padding", [0, 1, 5, 10]) 55 | def test_masking_in_gaussian_process(padding: int): 56 | 57 | params = gp.GPParams(noise=0.1, amplitude=1.0, lengthscale=1.0) 58 | 59 | x = jnp.linspace(-5, 5, 10) 60 | y = jnp.sin(x) 61 | mask = jnp.ones_like(x, dtype=float) 62 | reference = gp.marginal_likelihood(params, x, y, mask) 63 | 64 | x_pad, y_pad, mask_pad = jax.tree.map( 65 | lambda x: jnp.pad(x, (0, padding)), (x, y, mask) 66 | ) 67 | assert len(x_pad) == len(x) + padding, "X should be padded to 10 + padding length" 68 | assert mask_pad[len(mask):].sum() == 0, "Mask should be zero for padded values" 69 | 70 | output = gp.marginal_likelihood(params, x_pad, y_pad, mask_pad) 71 | 72 | assert jnp.allclose(reference, output), f"Mismatch for padding={padding}" 73 | 74 | 75 | @pytest.mark.parametrize("padding", [0, 1, 5, 10]) 76 | def test_masking_in_prediction(padding: int): 77 | 78 | params = gp.GPParams(noise=0.1, amplitude=1.0, lengthscale=1.0) 79 | 80 | x = jnp.linspace(-5, 5, 10) 81 | y = jnp.sin(x) 82 | mask = jnp.ones_like(x) 83 | xt = jnp.linspace(-6, 6, 20) 84 | 85 | # Reference prediction without padding 86 | mean_ref, std_ref = gp.predict(params, x, y, mask, xt) 87 | 88 | # Apply padding 89 | x_pad, y_pad, mask_pad = jax.tree.map( 90 | lambda a: jnp.pad(a, (0, padding)), (x, y, mask) 91 | ) 92 | 93 | # Prediction with padded inputs 94 | mean_pad, std_pad = gp.predict(params, x_pad, y_pad, mask_pad, xt) 95 | 96 | assert mean_pad.shape == mean_ref.shape 97 | assert std_pad.shape == std_ref.shape 98 | assert jnp.allclose(mean_pad, mean_ref, atol=1e-5), f"Mean mismatch for padding={padding}" 99 | assert jnp.allclose(std_pad, std_ref, atol=1e-5), f"Std mismatch for padding={padding}" 100 | -------------------------------------------------------------------------------- /bayex/gp.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from functools import partial 3 | from typing import Any, Optional 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | from jax.scipy.linalg import cholesky, solve_triangular 8 | import optax 9 | 10 | MASK_VARIANCE = 1e12 # High variance for masked points to not affect the process. 
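# Padded (mask == 0) entries get this variance added to the kernel diagonal in
# `gaussian_process`, so they act as pure noise: they carry essentially no
# weight in the posterior, and their log-determinant contribution is
# subtracted back out of the marginal likelihood.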
11 | 
12 | GPParams = namedtuple("GPParams", ["noise", "amplitude", "lengthscale"])
13 | GPState = namedtuple("GPState", ["params", "momentums", "scales"])
14 | 
15 | def exp_quadratic(x1, x2, mask):
16 |     distance = jnp.sum((x1 - x2) ** 2)
17 |     return jnp.exp(-distance) * mask
18 | 
19 | 
20 | def cov(x1, x2, mask1, mask2):
21 |     M = jnp.outer(mask1, mask2)
22 |     k = exp_quadratic
23 |     return jax.vmap(jax.vmap(k, in_axes=(None, 0, 0)), in_axes=(0, None, 0))(x1, x2, M)
24 | 
25 | 
26 | def softplus(x):
27 |     return jnp.logaddexp(x, 0.0)
28 | 
29 | 
30 | def gaussian_process(
31 |     params,
32 |     x: jnp.ndarray,
33 |     y: jnp.ndarray,
34 |     mask,
35 |     xt: Optional[jnp.ndarray] = None,
36 |     compute_ml: bool = False,
37 | ) -> Any:
38 |     # Number of points in the prior distribution
39 |     n = x.shape[0]
40 | 
41 |     noise, amp, ls = jax.tree_util.tree_map(softplus, params)
42 | 
43 |     ymean = jnp.mean(y, where=mask.astype(bool))
44 |     y = (y - ymean) * mask
45 |     x = x / ls
46 |     K = amp * cov(x, x, mask, mask) + (jnp.eye(n) * (noise + 1e-6))
47 |     K += jnp.eye(n) * (1.0 - mask.astype(float)) * MASK_VARIANCE
48 |     L = cholesky(K, lower=True)
49 |     K_inv_y = solve_triangular(L.T, solve_triangular(L, y, lower=True), lower=False)
50 | 
51 |     if compute_ml:
52 |         logp = 0.5 * jnp.dot(y.T, K_inv_y)
53 |         logp += jnp.sum(jnp.log(jnp.diag(L)))
54 |         logp -= jnp.sum(1.0 - mask) * 0.5 * jnp.log(MASK_VARIANCE)
55 |         logp += (jnp.sum(mask) / 2) * jnp.log(2 * jnp.pi)
56 |         logp += jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - jnp.log(amp) - jnp.log(amp) ** 2)
57 |         return jnp.sum(logp)
58 | 
59 |     assert xt is not None, "xt can't be None during prediction."
60 |     xt = xt / ls
61 | 
62 |     # Compute the covariance with the new points xt
63 |     mask_t = jnp.ones(len(xt)) == 1
64 |     K_cross = amp * cov(x, xt, mask, mask_t)
65 | 
66 |     K_inv_y = K_inv_y * mask  # masking
67 |     pred_mean = jnp.dot(K_cross.T, K_inv_y) + ymean
68 |     v = solve_triangular(L, K_cross, lower=True)
69 |     pred_var = amp * cov(xt, xt, mask_t, mask_t) - v.T @ v
70 |     pred_std = jnp.sqrt(jnp.maximum(jnp.diag(pred_var), 1e-10))
71 |     return pred_mean, pred_std
72 | 
73 | 
74 | marginal_likelihood = partial(gaussian_process, compute_ml=True)
75 | grad_fun = jax.jit(jax.grad(marginal_likelihood))
76 | predict = jax.jit(partial(gaussian_process, compute_ml=False))
77 | 
78 | def neg_log_likelihood(params, x, y, mask):
79 |     ll = marginal_likelihood(params, x, y, mask)
80 | 
81 |     # Weak priors to keep things sane
82 |     # params = jax.tree.map(softplus, params)
83 |     priors = GPParams(-8.0, 1.0, 1.0)
84 |     log_prior = jax.tree.map(lambda p, m: jnp.sum((p - m) ** 2), params, priors)
85 |     log_prior = sum(jax.tree.leaves(log_prior))
86 |     log_posterior = ll - 0.5 * log_prior
87 |     return -log_posterior
88 | 
89 | 
90 | def posterior_fit(
91 |     y: jax.Array,
92 |     x: jax.Array,
93 |     mask: jax.Array,
94 |     params: GPParams,
95 |     lr: float = 1e-3,
96 |     trainsteps: int = 100,
97 | ) -> GPParams:
98 |     # Maximize the log posterior over the GP hyperparameters with clipped AdamW.
99 |     optimizer = optax.chain(optax.clip_by_global_norm(10.0), optax.adamw(lr))
100 |     opt_state = optimizer.init(params)
101 | 
102 |     def train_step(carry, _):
103 |         params, opt_state = carry
104 |         grads = jax.grad(neg_log_likelihood)(params, x, y, mask)
105 |         updates, opt_state = optimizer.update(grads, opt_state, params=params)
106 |         params = optax.apply_updates(params, updates)
107 |         return (params, opt_state), None
108 | 
109 |     (params, _), __ = jax.lax.scan(train_step, (params, opt_state), None, length=trainsteps)
110 |     return params
111 | 
--------------------------------------------------------------------------------
/bayex/acq.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax.scipy.stats import norm
4 | 
5 | from bayex.gp import GPParams, predict
6 | 
7 | 
8 | def expected_improvement(
9 |     x_pred: jnp.ndarray,
10 |     xs: jax.Array,
11 |     ys: jax.Array,
12 |     mask: jax.Array,
13 |     gp_params: GPParams,
14 |     xi: float = 0.01,
15 | ):
16 |     r"""
17 |     Expected Improvement (EI) acquisition function.
18 | 
19 |     Favors points with high improvement over the current best observed value,
20 |     balancing exploitation and exploration.
21 | 
22 |     The formula is:
23 | 
24 |     .. math::
25 | 
26 |         EI(x) = (\mu(x) - y^* - \xi) \Phi(z) + \sigma(x) \phi(z)
27 | 
28 |     where:
29 | 
30 |     .. math::
31 | 
32 |         z = \frac{\mu(x) - y^* - \xi}{\sigma(x)}
33 | 
34 |     Args:
35 |         x_pred: Candidate input locations to evaluate.
36 |         xs: Observed inputs.
37 |         ys: Observed function values.
38 |         mask: Boolean mask indicating valid entries in `ys`.
39 |         gp_params: Gaussian Process hyperparameters.
40 |         xi: Exploration-exploitation tradeoff parameter.
41 | 
42 |     Returns:
43 |         EI scores at `x_pred`.
44 |     """
45 |     ymax = jnp.max(ys, where=mask.astype(bool), initial=-jnp.inf)
46 |     mu, std = predict(gp_params, xs, ys, mask, xt=x_pred)
47 |     a = mu - ymax - xi
48 |     z = a / (std + 1e-3)
49 |     ei = a * norm.cdf(z) + std * norm.pdf(z)
50 |     return ei
51 | 
52 | 
53 | def probability_improvement(
54 |     x_pred: jnp.ndarray,
55 |     xs: jax.Array,
56 |     ys: jax.Array,
57 |     mask: jax.Array,
58 |     gp_params: GPParams,
59 |     xi: float = 0.01,
60 | ):
61 |     r"""
62 |     Probability of Improvement (PI) acquisition function.
63 | 
64 |     Estimates the probability that a candidate point will improve
65 |     over the current best observed value.
66 | 
67 |     The formula is:
68 | 
69 |     .. math::
70 | 
71 |         PI(x) = \Phi\left(\frac{\mu(x) - y^* - \xi}{\sigma(x)}\right)
72 | 
73 |     Args:
74 |         x_pred: Candidate input locations to evaluate.
75 |         xs: Observed inputs.
76 |         ys: Observed function values.
77 |         mask: Boolean mask indicating valid entries in `ys`.
78 |         gp_params: Gaussian Process hyperparameters.
79 |         xi: Improvement margin for sensitivity.
80 | 
81 |     Returns:
82 |         PI scores at `x_pred`.
83 |     """
84 |     y_max = jnp.max(ys, where=mask.astype(bool), initial=-jnp.inf)
85 |     mu, std = predict(gp_params, xs, ys, mask, xt=x_pred)
86 |     z = (mu - y_max - xi) / (std + 1e-3)
87 |     return norm.cdf(z)
88 | 
89 | 
90 | def upper_confidence_bounds(
91 |     x_pred: jnp.ndarray,
92 |     xs: jax.Array,
93 |     ys: jax.Array,
94 |     mask: jax.Array,
95 |     gp_params: GPParams,
96 |     kappa: float = 0.01,
97 | ):
98 |     r"""
99 |     Upper Confidence Bound (UCB) acquisition function.
100 | 
101 |     Promotes exploration by favoring points with high predictive uncertainty.
102 | 
103 |     The formula is:
104 | 
105 |     .. math::
106 | 
107 |         UCB(x) = \mu(x) + \kappa \cdot \sigma(x)
108 | 
109 |     Args:
110 |         x_pred: Candidate input locations to evaluate.
111 |         xs: Observed inputs.
112 |         ys: Observed function values.
113 |         mask: Boolean mask indicating valid entries in `ys`.
114 |         gp_params: Gaussian Process hyperparameters.
115 |         kappa: Weighting factor for uncertainty.
116 | 
117 |     Returns:
118 |         UCB scores at `x_pred`.
119 |     """
120 |     mu, std = predict(gp_params, xs, ys, mask, xt=x_pred)
121 |     return mu + kappa * std
122 | 
123 | 
124 | def lower_confidence_bounds(
125 |     x_pred: jnp.ndarray,
126 |     xs: jax.Array,
127 |     ys: jax.Array,
128 |     mask: jax.Array,
129 |     gp_params: GPParams,
130 |     kappa: float = 2.576,
131 | ):
132 |     r"""
133 |     Lower Confidence Bound (LCB) acquisition function.
134 | 
135 |     Useful for minimization tasks. Encourages sampling in uncertain regions
136 |     with low predicted values.
137 | 
138 |     The formula is:
139 | 
140 |     .. math::
141 | 
142 |         LCB(x) = \mu(x) - \kappa \cdot \sigma(x)
143 | 
144 |     Args:
145 |         x_pred: Candidate input locations to evaluate.
146 |         xs: Observed inputs.
147 |         ys: Observed function values.
148 |         mask: Boolean mask indicating valid entries in `ys`.
149 |         gp_params: Gaussian Process hyperparameters.
150 |         kappa: Weighting factor for uncertainty.
151 | 
152 |     Returns:
153 |         LCB scores at `x_pred`.
154 |     """
155 |     mu, std = predict(gp_params, xs, ys, mask, xt=x_pred)
156 |     return mu - kappa * std
157 | 
--------------------------------------------------------------------------------
/bayex/domain.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import jax
3 | 
4 | 
5 | class Domain:
6 |     def __init__(self, dtype):
7 |         self.dtype = dtype
8 | 
9 |     def __hash__(self):
10 |         return hash(self.dtype)
11 | 
12 |     def __eq__(self, other):
13 |         return self.dtype == other.dtype
14 | 
15 |     def transform(self, x: jax.Array):
16 |         raise NotImplementedError
17 | 
18 |     def sample(self, key: jax.Array, shape: Tuple):
19 |         raise NotImplementedError
20 | 
21 | 
22 | class Real(Domain):
23 |     """
24 |     Continuous real-valued domain with clipping.
25 | 
26 |     Represents a parameter that can take real values within [lower, upper].
27 |     """
28 | 
29 |     def __init__(self, lower, upper):
30 |         """
31 |         Initializes a real domain with bounds.
32 | 
33 |         Args:
34 |             lower: Lower bound (inclusive).
35 |             upper: Upper bound (inclusive).
36 |         """
37 |         assert isinstance(lower, (float, int)), "Lower bound must be a float or an int"
38 |         assert isinstance(upper, (float, int)), "Upper bound must be a float or an int"
39 |         assert lower < upper, "Lower bound must be less than upper bound"
40 | 
41 |         self.lower = float(lower)
42 |         self.upper = float(upper)
43 |         super().__init__(dtype=jax.numpy.float32)
44 | 
45 |     def __hash__(self):
46 |         return hash((self.lower, self.upper))
47 | 
48 |     def __eq__(self, other):
49 |         return self.lower == other.lower and self.upper == other.upper
50 | 
51 |     def transform(self, x: jax.Array):
52 |         """
53 |         Clips values to the domain range [lower, upper].
54 | 
55 |         Args:
56 |             x: Input values.
57 | 
58 |         Returns:
59 |             Clipped values within bounds.
60 |         """
61 |         return jax.numpy.clip(x, self.lower, self.upper)
62 | 
63 |     def sample(self, key: jax.Array, shape: Tuple):
64 |         """
65 |         Samples uniformly from the domain.
66 | 
67 |         Args:
68 |             key: JAX PRNGKey.
69 |             shape: Desired output shape.
70 | 
71 |         Returns:
72 |             Sampled values clipped to the domain.
73 |         """
74 |         samples = jax.random.uniform(key, shape, minval=self.lower, maxval=self.upper)
75 |         return self.transform(samples)
76 | 
77 | 
78 | class Integer(Domain):
79 |     """
80 |     Discrete integer-valued domain with rounding and clipping.
81 | 
82 |     Represents a parameter that can take integer values within [lower, upper].
83 |     """
84 | 
85 |     def __init__(self, lower, upper):
86 |         """
87 |         Initializes an integer domain with bounds.
88 | 
89 |         Args:
90 |             lower: Lower integer bound (inclusive).
91 |             upper: Upper integer bound (inclusive).
92 | """ 93 | assert isinstance(lower, int), "Lower bound must be an integer" 94 | assert isinstance(upper, int), "Upper bound must be an integer" 95 | assert lower < upper, "Lower bound must be less than upper bound" 96 | 97 | self.lower = int(lower) 98 | self.upper = int(upper) 99 | super().__init__(dtype=jax.numpy.int32) 100 | 101 | def __hash__(self): 102 | return hash((self.lower, self.upper)) 103 | 104 | def __eq__(self, other): 105 | return self.lower == other.lower and self.upper == other.upper 106 | 107 | def transform(self, x: jax.Array): 108 | """ 109 | Rounds and clips values to the integer domain. 110 | 111 | Args: 112 | x: Input values. 113 | 114 | Returns: 115 | Rounded and clipped values as float32. 116 | """ 117 | return jax.numpy.clip(jax.numpy.round(x), self.lower, self.upper).astype(jax.numpy.float32) 118 | 119 | def sample(self, key: jax.Array, shape: Tuple): 120 | """ 121 | Samples integers uniformly from the domain. 122 | 123 | Args: 124 | key: JAX PRNGKey. 125 | shape: Desired output shape. 126 | 127 | Returns: 128 | Sampled values clipped to valid integer range. 129 | """ 130 | samples = jax.random.randint(key, shape, minval=self.lower, maxval=self.upper + 1) 131 | return self.transform(samples) 132 | 133 | 134 | class ParamSpace: 135 | """ 136 | Internal class that manages a collection of named parameter domains. 137 | 138 | This utility encapsulates logic for sampling, transforming, and handling 139 | structured parameter inputs defined by a mapping of variable names to Domain 140 | instances (e.g., Real, Integer). 141 | 142 | Example: 143 | >>> space = ParamSpace({ 144 | ... "x1": Real(0.0, 1.0), 145 | ... "x2": Integer(1, 5) 146 | ... }) 147 | >>> key = jax.random.PRNGKey(0) 148 | >>> samples = space.sample_tree(key, (128,)) 149 | >>> xs = space.transform_tree(samples) 150 | 151 | Notes: 152 | This class is intended for internal use by the optimizer and should not 153 | be exposed as part of the public API. 154 | """ 155 | 156 | def __init__(self, space: dict): 157 | self.space = space 158 | 159 | def sample_params(self, key: jax.Array, shape: Tuple) -> dict: 160 | keys = jax.random.split(key, len(self.space)) 161 | return {name: self.space[name].sample(k, shape) for name, k in zip(self.space, keys)} 162 | 163 | def to_array(self, tree: dict) -> jax.Array: 164 | """ 165 | Transforms a batch of parameter values into a 2D array suitable for GP input. 166 | 167 | Applies each domain's `.transform()` to its corresponding parameter values. 168 | 169 | Args: 170 | tree: A dictionary of parameter name → array of raw values. 171 | 172 | Returns: 173 | A JAX array of shape (batch_size, num_params) with transformed values. 174 | """ 175 | return jax.numpy.stack([self.space[k].transform(tree[k]) for k in self.space], axis=1) 176 | 177 | 178 | def to_dict(self, xs: jax.Array) -> dict: 179 | """ 180 | Converts a stacked parameter matrix back into named parameter trees. 181 | 182 | Typically used after optimization in transformed space. 183 | 184 | Args: 185 | xs: A 2D JAX array of shape (batch_size, num_params), with each column 186 | corresponding to a parameter. 187 | 188 | Returns: 189 | A dictionary mapping parameter names to individual 1D arrays. 
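
        Example (illustrative, reusing the ``space`` from the class docstring):

            >>> xs = jax.numpy.array([[0.2, 3.0]])
            >>> space.to_dict(xs)  # {'x1': Array([0.2], ...), 'x2': Array([3.], ...)}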
190 | """ 191 | return {k: self.space[k].transform(xs[:, i]) for i, k in enumerate(self.space)} -------------------------------------------------------------------------------- /bayex/optimizer.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Union, NamedTuple, Callable 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import optax 8 | 9 | import bayex.acq as boacq 10 | from bayex.gp import GPParams, posterior_fit 11 | from bayex.domain import ParamSpace 12 | 13 | 14 | class OptimizerState(NamedTuple): 15 | """ 16 | Container for the state of the Bayesian optimizer. 17 | 18 | Attributes: 19 | params (dict): Dictionary mapping parameter names to their 20 | corresponding padded JAX arrays of observed values. 21 | ys (jax.Array or np.ndarray): Array of objective values associated 22 | with the observed parameters. Includes padding. 23 | best_score (float): Best observed objective value so far. 24 | best_params (dict): Parameter configuration corresponding to 25 | the best_score. 26 | mask (jax.Array): Boolean array indicating which entries in 27 | `params` and `ys` are valid (i.e., not padding). 28 | gp_params (GPParams): Parameters of the Gaussian Process 29 | fitted to the observations. 30 | """ 31 | params: dict 32 | ys: Union[jax.Array, np.ndarray] 33 | best_score: float 34 | best_params: dict 35 | mask: jax.Array 36 | gp_params: GPParams 37 | 38 | 39 | def _optimize_suggestion(params: dict, fun: Callable, max_iter: int = 10): 40 | """ 41 | Applies local optimization (L-BFGS) to a given starting point. 42 | 43 | This function refines candidate points proposed by the acquisition 44 | function by performing a fixed number of gradient-based optimization 45 | steps using Optax's L-BFGS optimizer. 46 | 47 | Args: 48 | params (jax.Array): Initial point in input space to optimize. 49 | fun (Callable): Objective function to maximize. It must return a 50 | scalar value and support automatic differentiation. 51 | max_iter (int): Maximum number of L-BFGS steps to apply. 52 | 53 | Returns: 54 | jax.Array: The optimized point after `max_iter` iterations. 55 | """ 56 | 57 | # L-BFGS optimizer is used for minimization, so to maximize acquisition 58 | # we need to negate the function value. 59 | opt = optax.lbfgs() 60 | value_and_grad_fun = optax.value_and_grad_from_state(lambda x: -fun(x)) 61 | 62 | def step(carry, _): 63 | params, state = carry 64 | value, grad = value_and_grad_fun(params, state=state) 65 | updates, state = opt.update(grad, state, params, 66 | value=value, grad=grad, value_fn=fun) 67 | params = optax.apply_updates(params, updates) 68 | params = jnp.clip(params, -1e6, 1e6) 69 | return (params, state), None 70 | 71 | init_carry = (params, opt.init(params)) 72 | (final_params, _), __ = jax.lax.scan(step, init_carry, None, length=max_iter) 73 | return final_params 74 | 75 | 76 | class Optimizer: 77 | """ 78 | Bayesian optimizer using Gaussian Processes and acquisition functions. 79 | 80 | This class manages the optimization loop for expensive black-box functions 81 | by modeling them with a Gaussian Process and selecting samples via 82 | acquisition functions such as EI, PI, UCB, or LCB. 83 | """ 84 | 85 | def __init__(self, domain: dict, acq: str = 'EI', maximize: bool = False): 86 | """ 87 | Initializes the optimizer. 88 | 89 | Args: 90 | domain: A dict mapping parameter names to domain objects (e.g., Real, Integer). 
91 |             acq: Acquisition function ('EI', 'PI', 'UCB', 'LCB').
92 |             maximize: Whether to maximize or minimize the objective.
93 |         """
94 |         self.domain = domain
95 |         self.sign = 1 if maximize else -1
96 |         self.param_space = ParamSpace(domain)
97 | 
98 |         if acq == 'EI':
99 |             self.acq = jax.jit(boacq.expected_improvement)
100 |         elif acq == 'PI':
101 |             self.acq = jax.jit(boacq.probability_improvement)
102 |         elif acq == 'UCB':
103 |             self.acq = jax.jit(boacq.upper_confidence_bounds)
104 |         elif acq == 'LCB':
105 |             self.acq = jax.jit(boacq.lower_confidence_bounds)
106 |         else:
107 |             raise ValueError(f"Acquisition function {acq} is not implemented")
108 | 
109 |     def init(self, ys: jax.Array, params: dict, noise_scale: float = -8.0):
110 |         """
111 |         Initializes the optimizer state from initial data.
112 | 
113 |         Args:
114 |             ys: Objective values for the initial parameters.
115 |             params: Dict of parameter arrays (same keys as domain).
116 |             noise_scale: Initial (pre-softplus) noise hyperparameter for the GP.
117 | 
118 |         Returns:
119 |             Initialized OptimizerState.
120 |         """
121 |         # Create padded jax arrays for the parameters and scores to keep JAX recompilations at bay.
122 |         num_entries = len(ys)
123 |         pad_value = int(np.ceil(len(ys) / 10) * 10)
124 | 
125 |         # Convert to jax arrays if they are not already
126 |         ys = jnp.asarray(ys)
127 |         ys = self.sign * ys
128 |         params = jax.tree.map(lambda x: jnp.asarray(x), params)
129 | 
130 |         # Define padded arrays for the inputs and the outputs
131 |         mask = jnp.zeros(shape=(pad_value,), dtype=jnp.bool_).at[:num_entries].set(True)
132 |         ys = jnp.zeros(shape=(pad_value,), dtype=ys.dtype).at[:num_entries].set(ys)
133 | 
134 |         _params = {}
135 |         for key, entries in params.items():
136 |             # Assert that the parameter is in the domain dictionary
137 |             assert key in self.domain, f"Parameter {key} is not in the domain"
138 | 
139 |             # Get dtype from the domain and create a padded array
140 |             dtype = self.domain[key].dtype
141 |             values = jnp.zeros(shape=(pad_value,), dtype=dtype).at[:num_entries].set(entries)
142 |             _params[key] = values
143 | 
144 |         # From the given observations, find the best one (either maximum or minimum)
145 |         # and return the initial optimizer state.
146 |         best_score = float(jnp.max(ys[mask]))
147 |         best_params_idx = jnp.argmax(ys[mask])
148 |         best_params = jax.tree.map(lambda x: x[mask][best_params_idx], _params)
149 | 
150 |         # Initialize the Gaussian process state
151 |         gp_params = GPParams(
152 |             noise=jnp.full((1, 1), 1. * noise_scale),
153 |             amplitude=jnp.zeros((1, 1)),
154 |             lengthscale=jnp.zeros((1, len(_params)))
155 |         )
156 | 
157 |         # Fit to the current observations
158 |         xs = jnp.stack([self.domain[key].transform(_params[key]) for key in _params], axis=1)
159 |         gp_params = posterior_fit(ys, xs, mask=mask, params=gp_params)
160 | 
161 |         ys = self.sign * ys
162 |         best_score = self.sign * best_score
163 |         opt_state = OptimizerState(params=_params, ys=ys, best_score=best_score,
164 |                                    best_params=best_params, mask=mask, gp_params=gp_params)
165 | 
166 |         return opt_state
167 | 
168 |     @partial(jax.jit, static_argnames=('self', 'size'))
169 |     def sample(self, key, state, size=10_000):
170 |         """
171 |         Samples new candidate parameters by maximizing the acquisition function.
172 | 
173 |         Args:
174 |             key: JAX PRNG key for random sampling.
175 |             state: Current optimizer state.
176 |             size: Number of random candidates to draw before the best ones
177 |                 are locally refined with L-BFGS.
178 | 
179 |         Returns:
180 |             The parameter configuration (dict) with the highest acquisition value.
181 |         """
182 |         # Sample 'size' elements of each distribution.
183 | samples = self.param_space.sample_params(key, (size,)) 184 | xs_samples = self.param_space.to_array(samples) 185 | 186 | # Prepare the data for the Gaussian process prediction. 187 | xs = self.param_space.to_array(state.params) 188 | ys = self.sign * state.ys 189 | mask = state.mask 190 | gp_params = state.gp_params 191 | 192 | # Compute the acquisition function values for the sampled points. 193 | acq_vals = self.acq(xs_samples, xs, ys, mask, gp_params) 194 | 195 | # Of those, find the best 50 points and optimize them using BFGS. 196 | top_idxs = jnp.argsort(acq_vals)[-50:] 197 | init_points = xs_samples[top_idxs] 198 | f = lambda x: jnp.squeeze(self.acq(x[None, :], xs, ys, mask, gp_params)) 199 | optimized = jax.vmap(lambda x: _optimize_suggestion(x, f, max_iter=10))(init_points) 200 | opt_vals = self.acq(optimized, xs, ys, mask, gp_params) 201 | 202 | # Return the best point from the optimized points and the sampled points. 203 | all_points = jnp.concatenate((optimized, xs_samples), axis=0) 204 | all_vals = jnp.concatenate((opt_vals, acq_vals), axis=0) 205 | chosen_suggestion = jnp.argmax(all_vals) 206 | 207 | best_params = self.param_space.to_dict(all_points[chosen_suggestion][None]) 208 | return best_params 209 | 210 | 211 | def expand(self, opt_state: OptimizerState): 212 | """ 213 | Expands internal buffers if no space is available. 214 | 215 | Args: 216 | opt_state: Current optimizer state. 217 | 218 | Returns: 219 | OptimizerState with expanded storage. 220 | """ 221 | current = jnp.sum(opt_state.mask) 222 | 223 | if current == len(opt_state.mask): 224 | pad_value = int(np.ceil(len(opt_state.mask)*2 / 10) * 10) 225 | diff = pad_value - len(opt_state.mask) 226 | mask = jnp.pad(opt_state.mask, (0, diff)) 227 | ys = jnp.pad(opt_state.ys, (0, diff)) 228 | params = {} 229 | for key in opt_state.params: 230 | params[key] = jnp.pad(opt_state.params[key], (0, diff)) 231 | else: 232 | mask = opt_state.mask 233 | ys = opt_state.ys 234 | params = opt_state.params 235 | 236 | opt_state = OptimizerState(params=params, ys=ys, best_score=opt_state.best_score, 237 | best_params=opt_state.best_params, mask=mask, 238 | gp_params=opt_state.gp_params) 239 | return opt_state 240 | 241 | 242 | def fit(self, opt_state, y, new_params): 243 | """ 244 | Updates optimizer state with a new observation. 245 | 246 | Args: 247 | opt_state: Current optimizer state. 248 | y: New objective value. 249 | new_params: Parameters that produced y. 250 | 251 | Returns: 252 | Updated OptimizerState. 
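
        Note:
            When the padded buffers are full, this call first grows them via
            ``expand``, which changes array shapes and therefore triggers a
            fresh JIT compilation of ``_fit``.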
253 | """ 254 | opt_state = self.expand(opt_state) # Prompts recompilation 255 | opt_state = self._fit(opt_state, y, new_params) 256 | return opt_state 257 | 258 | 259 | @partial(jax.jit, static_argnums=(0,)) 260 | def _fit(self, opt_state, y, new_params): 261 | last_idx = jnp.arange(len(opt_state.mask)) == jnp.argmin(opt_state.mask) 262 | mask = jnp.asarray(jnp.where(last_idx, True, opt_state.mask)) 263 | ys = jnp.where(last_idx, y, opt_state.ys) 264 | params = jax.tree_util.tree_map(lambda x, y: jnp.where(last_idx, y, x), opt_state.params, new_params) 265 | 266 | xs = jnp.stack([self.domain[key].transform(params[key]) 267 | for key in params], axis=1) 268 | ys = self.sign * ys 269 | gp_params = posterior_fit(ys, xs, mask=mask, params=opt_state.gp_params) 270 | 271 | best_score = jnp.max(ys, where=mask, initial=-jnp.inf) 272 | best_params_idx = jnp.argmax(jnp.where(mask, ys, -jnp.inf)) 273 | best_params = jax.tree_util.tree_map(lambda x: x[best_params_idx], params) 274 | 275 | ys = self.sign * ys 276 | best_score = self.sign * best_score 277 | opt_state = OptimizerState(params=params, ys=ys, best_score=best_score, 278 | best_params=best_params, mask=mask, 279 | gp_params=gp_params) 280 | return opt_state 281 | 282 | --------------------------------------------------------------------------------
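
Usage sketch (illustrative, not a file in the repository): the loop below mirrors the README example but mixes a `Real` and an `Integer` parameter and minimizes the objective. It relies only on the public API shown above (`bayex.Optimizer`, `bayex.domain`, and the `init`/`sample`/`fit` loop); the objective function and its bounds are made up for demonstration.

```python
import jax
import bayex

# Toy objective (hypothetical): smooth in x, discrete in n.
def objective(x, n):
    return (x - 0.3) ** 2 + 0.1 * (n - 3) ** 2

domain = {
    'x': bayex.domain.Real(-1.0, 1.0),
    'n': bayex.domain.Integer(1, 10),
}
optimizer = bayex.Optimizer(domain=domain, maximize=False, acq='EI')

# Seed the GP with a few prior evaluations, as in the README.
params = {'x': [-0.5, 0.0, 0.5], 'n': [2, 5, 8]}
ys = [objective(x, n) for x, n in zip(params['x'], params['n'])]
opt_state = optimizer.init(ys, params)

key = jax.random.key(0)
for step in range(30):
    subkey = jax.random.fold_in(key, step)
    new_params = optimizer.sample(subkey, opt_state)
    y_new = objective(**new_params)
    opt_state = optimizer.fit(opt_state, y_new, new_params)

print(opt_state.best_params, opt_state.best_score)
```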