├── ax ├── metrics │ ├── curve.py │ ├── tests │ │ ├── test_curve.py │ │ └── __init__.py │ ├── chemistry_data.zip │ ├── l2norm.py │ ├── __init__.py │ ├── hartmann6.py │ └── branin.py ├── utils │ ├── testing │ │ ├── manifest.py │ │ ├── __init__.py │ │ ├── metrics │ │ │ ├── __init__.py │ │ │ └── backend_simulator_map.py │ │ ├── tests │ │ │ └── __init__.py │ │ ├── utils_testing_stubs.py │ │ ├── test_init_files.py │ │ └── torch_stubs.py │ ├── __init__.py │ ├── stats │ │ ├── __init__.py │ │ └── tests │ │ │ └── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── tests │ │ │ ├── __init__.py │ │ │ ├── test_serialization.py │ │ │ ├── test_func_enum.py │ │ │ ├── test_docutils.py │ │ │ ├── test_kwargutils.py │ │ │ └── test_typeutils.py │ │ ├── typeutils_nonnative.py │ │ ├── mock.py │ │ ├── typeutils_torch.py │ │ ├── deprecation.py │ │ ├── random.py │ │ ├── func_enum.py │ │ ├── timeutils.py │ │ ├── docutils.py │ │ └── base.py │ ├── notebook │ │ ├── __init__.py │ │ └── plotting.py │ ├── report │ │ ├── __init__.py │ │ ├── tests │ │ │ ├── __init__.py │ │ │ └── test_render.py │ │ └── resources │ │ │ ├── __init__.py │ │ │ ├── simple_template.html │ │ │ ├── sufficient_statistic.html │ │ │ └── base_template.html │ ├── tutorials │ │ └── __init__.py │ ├── flake8_plugins │ │ └── __init__.py │ ├── measurement │ │ ├── __init__.py │ │ └── tests │ │ │ └── __init__.py │ └── sensitivity │ │ ├── __init__.py │ │ └── tests │ │ └── __init__.py ├── benchmark │ ├── problems │ │ ├── synthetic │ │ │ ├── __init__.py │ │ │ ├── hss │ │ │ │ └── __init__.py │ │ │ └── discretized │ │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── hpo │ │ │ └── __init__.py │ │ ├── runtime_funcs.py │ │ └── hd_embedding.py │ ├── __init__.py │ ├── methods │ │ ├── __init__.py │ │ └── sobol.py │ ├── benchmark_test_functions │ │ └── __init__.py │ ├── tests │ │ ├── __init__.py │ │ ├── methods │ │ │ └── __init__.py │ │ └── problems │ │ │ └── __init__.py │ ├── benchmark_step_runtime_function.py │ ├── benchmark_trial_metadata.py │ 
└── benchmark_test_function.py ├── plot │ ├── __init__.py │ ├── css │ │ ├── __init__.py │ │ └── base.css │ ├── js │ │ ├── __init__.py │ │ ├── common │ │ │ ├── __init__.py │ │ │ ├── css.js │ │ │ ├── plotly_requires.js │ │ │ ├── plotly_offline.js │ │ │ └── plotly_online.js │ │ └── generic_plotly.js │ └── tests │ │ ├── __init__.py │ │ ├── test_parallel_coordinates.py │ │ ├── test_helper.py │ │ ├── test_slices.py │ │ └── test_diagnostic.py ├── core │ ├── tests │ │ ├── __init__.py │ │ ├── test_risk_measures.py │ │ ├── test_auxiliary.py │ │ └── test_map_metric.py │ ├── map_metric.py │ ├── auxiliary.py │ └── __init__.py ├── exceptions │ ├── __init__.py │ ├── model.py │ ├── constants.py │ ├── storage.py │ └── data_provider.py ├── preview │ ├── __init__.py │ ├── api │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── instantiation │ │ │ │ └── __init__.py │ │ │ └── storage.py │ │ ├── protocols │ │ │ ├── __init__.py │ │ │ ├── metric.py │ │ │ └── runner.py │ │ ├── types.py │ │ └── __init__.py │ └── modelbridge │ │ └── __init__.py ├── analysis │ ├── tests │ │ ├── __init__.py │ │ └── test_utils.py │ ├── plotly │ │ ├── tests │ │ │ └── __init__.py │ │ ├── surface │ │ │ ├── __init__.py │ │ │ └── utils.py │ │ ├── arm_effects │ │ │ └── __init__.py │ │ ├── __init__.py │ │ └── plotly_analysis.py │ ├── markdown │ │ └── __init__.py │ ├── __init__.py │ └── healthcheck │ │ ├── __init__.py │ │ ├── healthcheck_analysis.py │ │ ├── tests │ │ └── test_should_generate_candidates.py │ │ └── should_generate_candidates.py ├── models │ ├── discrete │ │ ├── __init__.py │ │ └── eb_thompson.py │ ├── random │ │ ├── __init__.py │ │ └── uniform.py │ ├── tests │ │ ├── __init__.py │ │ ├── test_base.py │ │ ├── test_discrete.py │ │ └── test_randomforest.py │ ├── torch │ │ ├── __init__.py │ │ ├── tests │ │ │ └── __init__.py │ │ └── botorch_modular │ │ │ ├── __init__.py │ │ │ └── input_constructors │ │ │ └── __init__.py │ ├── __init__.py │ ├── types.py │ └── winsorization_config.py ├── runners │ ├── tests │ │ ├── 
__init__.py │ │ └── test_single_running_trial_mixin.py │ ├── __init__.py │ └── synthetic.py ├── service │ ├── utils │ │ ├── __init__.py │ │ ├── best_point_utils.py │ │ └── early_stopping.py │ ├── tests │ │ ├── __init__.py │ │ └── test_early_stopping.py │ └── __init__.py ├── modelbridge │ ├── tests │ │ └── __init__.py │ ├── transforms │ │ ├── __init__.py │ │ ├── tests │ │ │ └── test_rounding_transform.py │ │ ├── search_space_to_float.py │ │ └── deprecated_transform_mixin.py │ └── __init__.py ├── storage │ ├── json_store │ │ ├── tests │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── load.py │ │ └── save.py │ ├── sqa_store │ │ ├── tests │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── structs.py │ │ ├── timestamp.py │ │ └── reduced_state.py │ ├── __init__.py │ └── tests │ │ ├── test_registry_bundle.py │ │ └── test_botorch_modular_registry.py ├── early_stopping │ ├── tests │ │ └── __init__.py │ ├── __init__.py │ └── strategies │ │ └── __init__.py ├── global_stopping │ ├── tests │ │ └── __init__.py │ ├── __init__.py │ └── strategies │ │ └── __init__.py └── __init__.py ├── website ├── static │ ├── .nojekyll │ ├── CNAME │ ├── img │ │ ├── favicon.png │ │ ├── oss_logo.png │ │ ├── favicon │ │ │ └── favicon.ico │ │ ├── ax_wireframe.svg │ │ ├── database-solid.svg │ │ ├── th-large-solid.svg │ │ ├── dice-solid.svg │ │ ├── ax.svg │ │ ├── ax_lockup.svg │ │ ├── ax_logo_lockup.svg │ │ └── ax_lockup_white.svg │ └── js │ │ ├── mathjax.js │ │ └── plotUtils.js ├── versioned_docs │ └── .gitkeep ├── versioned_sidebars │ └── .gitkeep ├── sidebars.json ├── package.json ├── core │ └── Footer.js └── tutorials.json ├── .github ├── ISSUE_TEMPLATE │ ├── config.yaml │ ├── 3_general_support.yaml │ ├── 2_feature_request.yaml │ └── 1_bug_report.yaml └── workflows │ ├── tutorials.yml │ ├── lint.yml │ ├── cron_pinned.yml │ ├── build-and-test.yml │ ├── reusable_tutorials.yml │ └── deploy.yml ├── docs ├── assets │ ├── gp_opt.png │ ├── bo_1d_opt.gif │ ├── mab_probs.png │ ├── mab_regret.png │ ├── 
gp_posterior.png │ ├── mab_animate.gif │ ├── bandit_allocation.png │ └── example_shrinkage.png ├── algo-overview.md └── why-ax.md ├── CHANGELOG.md ├── requirements-fmt.txt ├── sphinx └── source │ ├── ax.rst │ ├── health_check.rst │ ├── global_stopping.rst │ ├── index.rst │ ├── runners.rst │ ├── exceptions.rst │ ├── early_stopping.rst │ ├── metrics.rst │ ├── preview.rst │ └── service.rst ├── pyproject.toml ├── pytest.ini ├── scripts ├── import_ax.py └── patch_site_config.py ├── .flake8 ├── .pre-commit-config.yaml └── LICENSE /ax/metrics/curve.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ax/utils/testing/manifest.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /website/static/.nojekyll: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ax/metrics/tests/test_curve.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /website/static/CNAME: -------------------------------------------------------------------------------- 1 | ax.dev 2 | -------------------------------------------------------------------------------- /website/versioned_docs/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /website/versioned_sidebars/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yaml: 
-------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | -------------------------------------------------------------------------------- /docs/assets/gp_opt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/gp_opt.png -------------------------------------------------------------------------------- /docs/assets/bo_1d_opt.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/bo_1d_opt.gif -------------------------------------------------------------------------------- /docs/assets/mab_probs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/mab_probs.png -------------------------------------------------------------------------------- /docs/assets/mab_regret.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/mab_regret.png -------------------------------------------------------------------------------- /ax/metrics/chemistry_data.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/ax/metrics/chemistry_data.zip -------------------------------------------------------------------------------- /docs/assets/gp_posterior.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/gp_posterior.png -------------------------------------------------------------------------------- /docs/assets/mab_animate.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/mab_animate.gif -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | Ax uses GitHub tags for managing releases. See changelog [here](https://github.com/facebook/Ax/releases). 2 | -------------------------------------------------------------------------------- /website/static/img/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/website/static/img/favicon.png -------------------------------------------------------------------------------- /website/static/img/oss_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/website/static/img/oss_logo.png -------------------------------------------------------------------------------- /docs/assets/bandit_allocation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/bandit_allocation.png -------------------------------------------------------------------------------- /docs/assets/example_shrinkage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/example_shrinkage.png -------------------------------------------------------------------------------- /website/static/img/favicon/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/website/static/img/favicon/favicon.ico -------------------------------------------------------------------------------- /requirements-fmt.txt: 
-------------------------------------------------------------------------------- 1 | # generated by `pyfmt --requirements` 2 | black==24.4.2 3 | ruff-api==0.1.0 4 | stdlibs==2024.1.28 5 | ufmt==2.8.0 6 | usort==1.0.8.post1 7 | -------------------------------------------------------------------------------- /docs/algo-overview.md: -------------------------------------------------------------------------------- 1 | --- 2 | id: algo-overview 3 | title: Overview 4 | --- 5 | 6 | Ax supports: 7 | * Bandit optimization 8 | * Empirical Bayes with Thompson sampling 9 | * Bayesian optimization 10 | -------------------------------------------------------------------------------- /sphinx/source/ax.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | ax 5 | =================================== 6 | 7 | .. automodule:: ax 8 | :members: 9 | :noindex: 10 | 11 | .. currentmodule:: ax 12 | -------------------------------------------------------------------------------- /ax/benchmark/problems/synthetic/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /ax/plot/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /ax/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/core/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/exceptions/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/plot/css/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/plot/js/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/preview/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/stats/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/analysis/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /ax/benchmark/methods/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | -------------------------------------------------------------------------------- /ax/benchmark/problems/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | -------------------------------------------------------------------------------- /ax/metrics/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/models/discrete/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/models/random/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/models/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/models/torch/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/plot/js/common/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/runners/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /ax/service/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/common/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/notebook/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/report/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/testing/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/tutorials/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/analysis/plotly/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/benchmark/problems/hpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | -------------------------------------------------------------------------------- /ax/modelbridge/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /ax/models/torch/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/preview/api/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/preview/modelbridge/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/common/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/flake8_plugins/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/measurement/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/report/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/sensitivity/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/testing/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /ax/utils/testing/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=34.4", "wheel", "setuptools_scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.usort] 6 | first_party_detection = false 7 | 8 | [tool.ufmt] 9 | formatter = "ruff-api" 10 | -------------------------------------------------------------------------------- /ax/storage/json_store/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/storage/sqa_store/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/measurement/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/report/resources/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/sensitivity/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/benchmark/benchmark_test_functions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | -------------------------------------------------------------------------------- /ax/benchmark/problems/synthetic/hss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # pyre-strict 7 | -------------------------------------------------------------------------------- /ax/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /ax/preview/api/utils/instantiation/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/benchmark/problems/synthetic/discretized/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | -------------------------------------------------------------------------------- /ax/plot/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /ax/service/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /ax/benchmark/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /ax/utils/stats/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /ax/benchmark/tests/methods/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /ax/early_stopping/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /ax/global_stopping/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /ax/modelbridge/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /ax/benchmark/tests/problems/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /ax/models/torch/botorch_modular/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /website/sidebars.json: -------------------------------------------------------------------------------- 1 | { 2 | "docs": { 3 | "Introduction": ["why-ax"], 4 | "Getting Started": ["installation", "api", "glossary"], 5 | "Algorithms": ["bayesopt", "banditopt"], 6 | "Components": ["core", "trial-evaluation", "data", "models", "storage"] 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /ax/models/torch/botorch_modular/input_constructors/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /ax/utils/report/resources/simple_template.html: -------------------------------------------------------------------------------- 1 | 4 | {% extends "base_template.html" %} 5 | {% block content %} 6 | {% for element in html_elements %} 7 | {{element}} 8 | {% endfor %} 9 | {% endblock %} 10 | -------------------------------------------------------------------------------- /ax/plot/js/generic_plotly.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | Plotly.newPlot( 9 | {{id}}, 10 | {{data}}, 11 | {{layout}}, 12 | {"showLink": false} 13 | ); 14 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | # Configuration for pytest. 2 | [pytest] 3 | filterwarnings = 4 | # Filter out parameters and sklearn deprecation warnings. 5 | ignore::DeprecationWarning:.*paramz.* 6 | ignore::DeprecationWarning:.*sklearn* 7 | # Filter out numpy non-integer indices warning. 8 | ignore::DeprecationWarning:.*using a non-integer array as obj in delete* 9 | -------------------------------------------------------------------------------- /ax/early_stopping/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | 10 | from ax.early_stopping import strategies 11 | 12 | __all__ = ["strategies"] 13 | -------------------------------------------------------------------------------- /ax/global_stopping/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | from ax.global_stopping import strategies 11 | 12 | __all__ = ["strategies"] 13 | -------------------------------------------------------------------------------- /website/static/img/ax_wireframe.svg: -------------------------------------------------------------------------------- 1 | 06_OneColor_White -------------------------------------------------------------------------------- /ax/service/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from ax.service.managed_loop import OptimizationLoop, optimize 10 | 11 | 12 | __all__ = ["OptimizationLoop", "optimize"] 13 | -------------------------------------------------------------------------------- /ax/storage/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | from ax.storage.json_store import load as json_load, save as json_save 10 | 11 | 12 | __all__ = ["json_save", "json_load"] 13 | -------------------------------------------------------------------------------- /ax/plot/js/common/css.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | var css = document.createElement('style'); 9 | css.type = 'text/css'; 10 | css.innerHTML = "{{css}}"; 11 | document.getElementsByTagName("head")[0].appendChild(css); 12 | -------------------------------------------------------------------------------- /scripts/import_ax.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import ax 8 | from ax.service.ax_client import AxClient 9 | 10 | 11 | if __name__ == "__main__": 12 | assert ax is not None 13 | assert AxClient is not None 14 | -------------------------------------------------------------------------------- /ax/analysis/plotly/surface/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # pyre-strict 7 | 8 | from ax.analysis.plotly.surface.contour import ContourPlot 9 | from ax.analysis.plotly.surface.slice import SlicePlot 10 | 11 | __all__ = ["ContourPlot", "SlicePlot"] 12 | -------------------------------------------------------------------------------- /ax/plot/js/common/plotly_requires.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | require(['plotly'], function(Plotly) { 9 | window.PLOTLYENV = window.PLOTLYENV || {}; 10 | window.PLOTLYENV.BASE_URL = 'https://plot.ly'; 11 | {{script}} 12 | }); 13 | -------------------------------------------------------------------------------- /ax/analysis/markdown/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | 8 | from ax.analysis.markdown.markdown_analysis import ( 9 | MarkdownAnalysis, 10 | MarkdownAnalysisCard, 11 | ) 12 | 13 | __all__ = ["MarkdownAnalysis", "MarkdownAnalysisCard"] 14 | -------------------------------------------------------------------------------- /sphinx/source/health_check.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | ax.health_check 5 | =============== 6 | 7 | .. automodule:: ax.health_check 8 | .. currentmodule:: ax.health_check 9 | 10 | Ax Experiment Health Checks 11 | --------------------------- 12 | 13 | Search Space 14 | ~~~~~~~~~~~~ 15 | 16 | .. 
automodule:: ax.health_check.search_space 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | -------------------------------------------------------------------------------- /website/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "scripts": { 3 | "examples": "docusaurus-examples", 4 | "start": "docusaurus-start", 5 | "build": "docusaurus-build", 6 | "publish-gh-pages": "docusaurus-publish", 7 | "write-translations": "docusaurus-write-translations", 8 | "version": "docusaurus-version", 9 | "rename-version": "docusaurus-rename-version" 10 | }, 11 | "devDependencies": { 12 | "docusaurus": "^1.7.2" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | 3 | # E704: the linter doesn't parse types properly 4 | # T499, T484: silence mypy since using pyre for typechecking 5 | # W503: black and flake8 disagree on how to place operators 6 | # E231: black and flake8 disagree on whitespace after ',' 7 | # E203: black and flake8 disagree on whitespace before ':' 8 | ignore = T484, T499, W503, E704, E231, E203 9 | 10 | # Black really wants lines to be 88 chars... 11 | max-line-length = 88 12 | -------------------------------------------------------------------------------- /ax/plot/js/common/plotly_offline.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | if (!window.Plotly) { 9 | define('plotly', function(require, exports, module) { 10 | {{library}} 11 | }); 12 | require(['plotly'], function(Plotly) { 13 | window.Plotly = Plotly; 14 | }); 15 | } 16 | -------------------------------------------------------------------------------- /ax/preview/api/protocols/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from ax.preview.api.protocols.metric import IMetric 10 | from ax.preview.api.protocols.runner import IRunner 11 | 12 | __all__ = [ 13 | "IMetric", 14 | "IRunner", 15 | ] 16 | -------------------------------------------------------------------------------- /ax/storage/json_store/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from ax.storage.json_store.load import load_experiment as json_load 10 | from ax.storage.json_store.save import save_experiment as json_save 11 | 12 | 13 | __all__ = ["json_load", "json_save"] 14 | -------------------------------------------------------------------------------- /ax/preview/api/types.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # pyre-strict 7 | 8 | from typing import Mapping 9 | 10 | TParameterValue = int | float | str | bool 11 | TParameterization = Mapping[str, TParameterValue] 12 | 13 | # Metric name => mean | (mean, sem) 14 | TOutcome = Mapping[str, float | tuple[float, float]] 15 | -------------------------------------------------------------------------------- /ax/runners/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | # flake8: noqa F401 10 | from ax.runners.simulated_backend import SimulatedBackendRunner 11 | from ax.runners.synthetic import SyntheticRunner 12 | 13 | 14 | __all__ = ["SimulatedBackendRunner", "SyntheticRunner"] 15 | -------------------------------------------------------------------------------- /ax/plot/js/common/plotly_online.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | requirejs.config({ 9 | paths: { 10 | plotly: ['https://cdn.plot.ly/plotly-latest.min'], 11 | }, 12 | }); 13 | if (!window.Plotly) { 14 | require(['plotly'], function(plotly) { 15 | window.Plotly = plotly; 16 | }); 17 | } 18 | -------------------------------------------------------------------------------- /ax/analysis/plotly/arm_effects/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # pyre-unsafe 7 | 8 | from ax.analysis.plotly.arm_effects.insample_effects import InSampleEffectsPlot 9 | from ax.analysis.plotly.arm_effects.predicted_effects import PredictedEffectsPlot 10 | 11 | __all__ = ["PredictedEffectsPlot", "InSampleEffectsPlot"] 12 | -------------------------------------------------------------------------------- /.github/workflows/tutorials.yml: -------------------------------------------------------------------------------- 1 | name: Tutorials 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: [ main ] 7 | paths: 8 | - "tutorials/**" 9 | pull_request: 10 | branches: [ main ] 11 | paths: 12 | - "tutorials/**" 13 | 14 | 15 | jobs: 16 | 17 | build-tutorials-with-latest-botorch: 18 | name: Tutorials with latest BoTorch 19 | uses: ./.github/workflows/reusable_tutorials.yml 20 | with: 21 | smoke_test: true 22 | pinned_botorch: false 23 | -------------------------------------------------------------------------------- /ax/metrics/l2norm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import numpy as np 10 | import numpy.typing as npt 11 | from ax.metrics.noisy_function import NoisyFunctionMetric 12 | 13 | 14 | class L2NormMetric(NoisyFunctionMetric): 15 | def f(self, x: npt.NDArray) -> float: 16 | return np.sqrt((x**2).sum()) 17 | -------------------------------------------------------------------------------- /ax/global_stopping/strategies/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy 10 | from ax.global_stopping.strategies.improvement import ImprovementGlobalStoppingStrategy 11 | 12 | 13 | __all__ = [ 14 | "BaseGlobalStoppingStrategy", 15 | "ImprovementGlobalStoppingStrategy", 16 | ] 17 | -------------------------------------------------------------------------------- /ax/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | 8 | from ax.analysis.analysis import ( 9 | Analysis, 10 | AnalysisCard, 11 | AnalysisCardLevel, 12 | display_cards, 13 | ) 14 | from ax.analysis.summary import Summary 15 | from ax.analysis.markdown import * # noqa 16 | from ax.analysis.plotly import * # noqa 17 | 18 | __all__ = ["Analysis", "AnalysisCard", "AnalysisCardLevel", "display_cards", "Summary"] 19 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | workflow_dispatch: 9 | 10 | 11 | jobs: 12 | 13 | lint: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: "3.10" 22 | 23 | - name: Install dependencies 24 | run: pip install pre-commit 25 | 26 | - name: Run pre-commit 27 | run: pre-commit run --all-files --show-diff-on-failure 28 | 
-------------------------------------------------------------------------------- /ax/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | # flake8: noqa F401 10 | from ax.metrics.branin import BraninMetric 11 | from ax.metrics.chemistry import ChemistryMetric 12 | from ax.metrics.factorial import FactorialMetric 13 | from ax.metrics.sklearn import SklearnMetric 14 | 15 | __all__ = [ 16 | "BraninMetric", 17 | "ChemistryMetric", 18 | "FactorialMetric", 19 | "SklearnMetric", 20 | ] 21 | -------------------------------------------------------------------------------- /ax/storage/sqa_store/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | # necessary to import this file so SQLAlchemy knows about the event listeners 10 | # see https://fburl.com/8mn7yjt2 11 | from ax.storage.sqa_store import validation 12 | from ax.storage.sqa_store.load import load_experiment as sqa_load 13 | from ax.storage.sqa_store.save import save_experiment as sqa_save 14 | 15 | 16 | __all__ = ["sqa_load", "sqa_save"] 17 | 18 | del validation 19 | -------------------------------------------------------------------------------- /sphinx/source/global_stopping.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | ax.global_stopping 5 | ================== 6 | 7 | .. automodule:: ax.global_stopping 8 | .. 
currentmodule:: ax.global_stopping 9 | 10 | Strategies 11 | ---------- 12 | 13 | Base Strategies 14 | ~~~~~~~~~~~~~~~ 15 | 16 | .. automodule:: ax.global_stopping.strategies.base 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | 21 | `ImprovementGlobalStoppingStrategy` 22 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 23 | 24 | .. automodule:: ax.global_stopping.strategies.improvement 25 | :members: 26 | :undoc-members: 27 | :show-inheritance: 28 | -------------------------------------------------------------------------------- /ax/benchmark/problems/runtime_funcs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | 8 | from collections.abc import Mapping 9 | 10 | from ax.core.arm import Arm 11 | from ax.core.types import TParamValue 12 | 13 | 14 | def int_from_params( 15 | params: Mapping[str, TParamValue], n_possibilities: int = 10 16 | ) -> int: 17 | """ 18 | Get an int between 0 and n_possibilities - 1, using a hash of the parameters. 19 | """ 20 | arm_hash = Arm.md5hash(parameters=params) 21 | return int(arm_hash[-1], base=16) % n_possibilities 22 | -------------------------------------------------------------------------------- /ax/benchmark/benchmark_step_runtime_function.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # pyre-strict 7 | 8 | from collections.abc import Mapping 9 | from typing import Protocol, runtime_checkable 10 | 11 | from ax.core.types import TParamValue 12 | 13 | 14 | @runtime_checkable 15 | class TBenchmarkStepRuntimeFunction(Protocol): 16 | def __call__(self, params: Mapping[str, TParamValue]) -> float: 17 | """ 18 | Return the runtime for each step. 19 | 20 | Each step within an arm will take the same amount of time. 21 | """ 22 | ... 23 | -------------------------------------------------------------------------------- /ax/utils/testing/utils_testing_stubs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | from ax.utils.testing.backend_simulator import BackendSimulator, BackendSimulatorOptions 11 | 12 | 13 | def get_backend_simulator_with_trials() -> BackendSimulator: 14 | options = BackendSimulatorOptions( 15 | internal_clock=0.0, use_update_as_start_time=True, max_concurrency=2 16 | ) 17 | sim = BackendSimulator(options=options) 18 | sim.run_trial(0, 2) 19 | sim.run_trial(1, 1) 20 | sim.run_trial(2, 10) 21 | return sim 22 | -------------------------------------------------------------------------------- /sphinx/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Ax documentation index file, created by 2 | sphinx-quickstart on Sat Mar 2 00:03:32 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | API Reference 7 | ============= 8 | 9 | .. 
toctree:: 10 | :maxdepth: 2 11 | 12 | analysis 13 | ax 14 | benchmark 15 | core 16 | early_stopping 17 | exceptions 18 | global_stopping 19 | health_check 20 | metrics 21 | modelbridge 22 | models 23 | plot 24 | preview 25 | runners 26 | service 27 | storage 28 | telemetry 29 | utils 30 | 31 | 32 | Indices and tables 33 | ================== 34 | 35 | * :ref:`genindex` 36 | * :ref:`modindex` 37 | * :ref:`search` 38 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: check-requirements-versions 5 | name: Check pre-commit formatting versions 6 | entry: python scripts/check_pre_commit_reqs.py 7 | language: python 8 | always_run: true 9 | pass_filenames: false 10 | additional_dependencies: 11 | - PyYAML 12 | 13 | - repo: https://github.com/omnilib/ufmt 14 | rev: v2.8.0 15 | hooks: 16 | - id: ufmt 17 | additional_dependencies: 18 | - black==24.4.2 19 | - usort==1.0.8.post1 20 | - ruff-api==0.1.0 21 | - stdlibs==2024.1.28 22 | args: [format] 23 | 24 | - repo: https://github.com/pycqa/flake8 25 | rev: 7.0.0 26 | hooks: 27 | - id: flake8 28 | -------------------------------------------------------------------------------- /ax/utils/report/resources/sufficient_statistic.html: -------------------------------------------------------------------------------- 1 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | {% for cell in cells %} 14 | {% if cell.html %} 15 |
16 |

{{cell.caption}}

17 |
18 |
19 | {{cell.html}} 20 |
21 | {% endif %} 22 | {% endfor %} 23 |
24 | 25 | 26 | -------------------------------------------------------------------------------- /ax/metrics/hartmann6.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import numpy.typing as npt 10 | from ax.metrics.noisy_function import NoisyFunctionMetric 11 | from ax.utils.common.typeutils import checked_cast 12 | from ax.utils.measurement.synthetic_functions import aug_hartmann6, hartmann6 13 | 14 | 15 | class Hartmann6Metric(NoisyFunctionMetric): 16 | def f(self, x: npt.NDArray) -> float: 17 | return checked_cast(float, hartmann6(x)) 18 | 19 | 20 | class AugmentedHartmann6Metric(NoisyFunctionMetric): 21 | def f(self, x: npt.NDArray) -> float: 22 | return checked_cast(float, aug_hartmann6(x)) 23 | -------------------------------------------------------------------------------- /website/static/js/mathjax.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | window.MathJax = { 9 | tex2jax: { 10 | inlineMath: [['$', '$'], ['\\(', '\\)']], 11 | displayMath: [['$$', '$$'], ['\\[', '\\]']], 12 | processEscapes: true, 13 | processEnvironments: true, 14 | }, 15 | // Center justify equations in code and markdown cells. Note that this 16 | // doesn't work with Plotly though, hence the !important declaratio 17 | // below. 
18 | displayAlign: 'center', 19 | 'HTML-CSS': { 20 | styles: { 21 | '.MathJax_Display': {margin: 0, 'text-align': 'center !important'}, 22 | }, 23 | linebreaks: {automatic: true}, 24 | }, 25 | }; 26 | -------------------------------------------------------------------------------- /website/static/img/database-solid.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 8 | 11 | 12 | -------------------------------------------------------------------------------- /.github/workflows/cron_pinned.yml: -------------------------------------------------------------------------------- 1 | name: Replicate Nightly Cron with Pinned BoTorch 2 | 3 | on: 4 | 5 | workflow_dispatch: 6 | 7 | jobs: 8 | 9 | tests-and-coverage-minimal: 10 | name: Tests with pinned BoTorch & minimal dependencies 11 | uses: ./.github/workflows/reusable_test.yml 12 | with: 13 | pinned_botorch: true 14 | minimal_dependencies: true 15 | secrets: inherit 16 | 17 | tests-and-coverage-full: 18 | name: Tests with pinned BoTorch & full dependencies 19 | uses: ./.github/workflows/reusable_test.yml 20 | with: 21 | pinned_botorch: true 22 | minimal_dependencies: false 23 | secrets: inherit 24 | 25 | build-tutorials: 26 | name: Build tutorials with pinned BoTorch 27 | uses: ./.github/workflows/reusable_tutorials.yml 28 | with: 29 | smoke_test: false 30 | pinned_botorch: true 31 | -------------------------------------------------------------------------------- /ax/models/types.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | from typing import Any, Union 10 | 11 | from ax.core.optimization_config import OptimizationConfig 12 | from ax.models.winsorization_config import WinsorizationConfig 13 | from botorch.acquisition import AcquisitionFunction 14 | 15 | # pyre-ignore [33]: `TConfig` cannot alias to a type containing `Any`. 16 | TConfig = dict[ 17 | str, 18 | Union[ 19 | int, 20 | float, 21 | str, 22 | AcquisitionFunction, 23 | list[str], 24 | dict[int, Any], 25 | dict[str, Any], 26 | OptimizationConfig, 27 | WinsorizationConfig, 28 | None, 29 | ], 30 | ] 31 | -------------------------------------------------------------------------------- /ax/exceptions/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from ax.exceptions.core import AxError 10 | 11 | 12 | class ModelError(AxError): 13 | """Raised when an error occurs during modeling.""" 14 | 15 | pass 16 | 17 | 18 | class CVNotSupportedError(AxError): 19 | """Raised when cross validation is applied to a model which doesn't 20 | support it. 21 | """ 22 | 23 | pass 24 | 25 | 26 | class ModelBridgeMethodNotImplementedError(AxError, NotImplementedError): 27 | """Raised when a ``ModelBridge`` method is not implemented by subclasses. 28 | 29 | NOTE: ``ModelBridge`` may catch and silently discard this error. 30 | """ 31 | 32 | pass 33 | -------------------------------------------------------------------------------- /ax/modelbridge/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | # flake8: noqa F401 10 | from ax.modelbridge import transforms 11 | from ax.modelbridge.base import ModelBridge 12 | from ax.modelbridge.factory import ( 13 | get_factorial, 14 | get_sobol, 15 | get_thompson, 16 | get_uniform, 17 | Models, 18 | ) 19 | from ax.modelbridge.map_torch import MapTorchModelBridge 20 | from ax.modelbridge.torch import TorchModelBridge 21 | 22 | __all__ = [ 23 | "MapTorchModelBridge", 24 | "ModelBridge", 25 | "Models", 26 | "TorchModelBridge", 27 | "get_factorial", 28 | "get_sobol", 29 | "get_thompson", 30 | "get_uniform", 31 | "transforms", 32 | ] 33 | -------------------------------------------------------------------------------- /ax/models/tests/test_base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | from ax.models.base import Model 10 | from ax.utils.common.testutils import TestCase 11 | 12 | 13 | class BaseModelTest(TestCase): 14 | def test_base_model(self) -> None: 15 | model = Model() 16 | raw_state = {"foo": "bar", "two": 3.0} 17 | self.assertEqual(model.serialize_state(raw_state), raw_state) 18 | self.assertEqual(model.deserialize_state(raw_state), raw_state) 19 | self.assertEqual(model._get_state(), {}) 20 | with self.assertRaisesRegex( 21 | NotImplementedError, "Feature importance not available" 22 | ): 23 | model.feature_importances() 24 | -------------------------------------------------------------------------------- /ax/storage/sqa_store/structs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from collections.abc import Callable 10 | from typing import NamedTuple 11 | 12 | from ax.storage.sqa_store.decoder import Decoder 13 | from ax.storage.sqa_store.encoder import Encoder 14 | from ax.storage.sqa_store.sqa_config import SQAConfig 15 | 16 | 17 | class DBSettings(NamedTuple): 18 | """ 19 | Defines behavior for loading/saving experiment to/from db. 20 | Either creator or url must be specified as a way to connect to the SQL db. 21 | """ 22 | 23 | creator: Callable | None = None 24 | decoder: Decoder = Decoder(config=SQAConfig()) 25 | encoder: Encoder = Encoder(config=SQAConfig()) 26 | url: str | None = None 27 | -------------------------------------------------------------------------------- /ax/utils/common/typeutils_nonnative.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from typing import Any 10 | 11 | import numpy as np 12 | 13 | 14 | # pyre-fixme[3]: Return annotation cannot be `Any`. 15 | # pyre-fixme[2]: Parameter annotation cannot be `Any`. 16 | def numpy_type_to_python_type(value: Any) -> Any: 17 | """If `value` is a Numpy int or float, coerce to a Python int or float. 18 | This is necessary because some of our transforms return Numpy values. 19 | """ 20 | if isinstance(value, np.integer): 21 | value = int(value) # pragma: nocover (covered by generator tests) 22 | if isinstance(value, np.floating): 23 | value = float(value) # pragma: nocover (covered by generator tests) 24 | return value 25 | -------------------------------------------------------------------------------- /ax/analysis/plotly/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # pyre-strict 7 | 8 | from ax.analysis.plotly.cross_validation import CrossValidationPlot 9 | from ax.analysis.plotly.interaction import InteractionPlot 10 | from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot 11 | from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard 12 | from ax.analysis.plotly.scatter import ScatterPlot 13 | from ax.analysis.plotly.surface.contour import ContourPlot 14 | from ax.analysis.plotly.surface.slice import SlicePlot 15 | 16 | __all__ = [ 17 | "ContourPlot", 18 | "CrossValidationPlot", 19 | "InteractionPlot", 20 | "PlotlyAnalysis", 21 | "PlotlyAnalysisCard", 22 | "ParallelCoordinatesPlot", 23 | "ScatterPlot", 24 | "SlicePlot", 25 | ] 26 | -------------------------------------------------------------------------------- /ax/plot/tests/test_parallel_coordinates.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # pyre-strict 7 | 8 | 9 | from ax.plot.base import AxPlotConfig, AxPlotTypes 10 | from ax.plot.parallel_coordinates import plot_parallel_coordinates 11 | from ax.utils.common.testutils import TestCase 12 | from ax.utils.testing.core_stubs import get_branin_experiment 13 | 14 | 15 | class ParallelCoordinatesTest(TestCase): 16 | def test_ParallelCoordinates(self) -> None: 17 | exp = get_branin_experiment(with_batch=True) 18 | exp.trials[0].run() 19 | 20 | # Assert that each type of plot can be constructed successfully 21 | plot = plot_parallel_coordinates(experiment=exp) 22 | 23 | self.assertIsInstance(plot, AxPlotConfig) 24 | self.assertEqual(plot.plot_type, AxPlotTypes.GENERIC) 25 | -------------------------------------------------------------------------------- /website/static/img/th-large-solid.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 8 | 13 | 14 | -------------------------------------------------------------------------------- /ax/preview/api/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | from ax.preview.api.client import Client 10 | from ax.preview.api.configs import ( 11 | ChoiceParameterConfig, 12 | ExperimentConfig, 13 | GenerationStrategyConfig, 14 | OrchestrationConfig, 15 | ParameterScaling, 16 | ParameterType, 17 | RangeParameterConfig, 18 | StorageConfig, 19 | ) 20 | from ax.preview.api.types import TOutcome, TParameterization 21 | 22 | __all__ = [ 23 | "Client", 24 | "ChoiceParameterConfig", 25 | "ExperimentConfig", 26 | "GenerationStrategyConfig", 27 | "OrchestrationConfig", 28 | "ParameterScaling", 29 | "ParameterType", 30 | "RangeParameterConfig", 31 | "StorageConfig", 32 | "TOutcome", 33 | "TParameterization", 34 | ] 35 | -------------------------------------------------------------------------------- /ax/utils/common/tests/test_serialization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | from typing import NamedTuple 10 | 11 | from ax.utils.common.serialization import named_tuple_to_dict 12 | from ax.utils.common.testutils import TestCase 13 | 14 | 15 | class TestSerializationUtils(TestCase): 16 | def test_named_tuple_to_dict(self) -> None: 17 | class Foo(NamedTuple): 18 | x: int 19 | y: str 20 | 21 | foo = Foo(x=5, y="g") 22 | self.assertEqual(named_tuple_to_dict(foo), {"x": 5, "y": "g"}) 23 | 24 | bar = {"x": 5, "foo": foo, "y": [(1, True), foo]} 25 | self.assertEqual( 26 | named_tuple_to_dict(bar), 27 | {"x": 5, "foo": {"x": 5, "y": "g"}, "y": [(1, True), {"x": 5, "y": "g"}]}, 28 | ) 29 | -------------------------------------------------------------------------------- /ax/metrics/branin.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | import numpy.typing as npt 10 | from ax.metrics.noisy_function import NoisyFunctionMetric 11 | from ax.utils.common.typeutils import checked_cast 12 | from ax.utils.measurement.synthetic_functions import aug_branin, branin 13 | 14 | 15 | class BraninMetric(NoisyFunctionMetric): 16 | def f(self, x: npt.NDArray) -> float: 17 | x1, x2 = x 18 | return checked_cast(float, branin(x1=x1, x2=x2)) 19 | 20 | 21 | class NegativeBraninMetric(BraninMetric): 22 | def f(self, x: npt.NDArray) -> float: 23 | fpos = super().f(x) 24 | return -fpos 25 | 26 | 27 | class AugmentedBraninMetric(NoisyFunctionMetric): 28 | def f(self, x: npt.NDArray) -> float: 29 | return checked_cast(float, aug_branin(x)) 30 | -------------------------------------------------------------------------------- /ax/benchmark/methods/sobol.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # pyre-strict 7 | 8 | 9 | from ax.benchmark.benchmark_method import BenchmarkMethod 10 | from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy 11 | from ax.modelbridge.registry import Models 12 | 13 | 14 | def get_sobol_generation_strategy() -> GenerationStrategy: 15 | return GenerationStrategy( 16 | name="Sobol", 17 | steps=[ 18 | GenerationStep(model=Models.SOBOL, num_trials=-1), 19 | ], 20 | ) 21 | 22 | 23 | def get_sobol_benchmark_method( 24 | distribute_replications: bool, 25 | batch_size: int = 1, 26 | ) -> BenchmarkMethod: 27 | return BenchmarkMethod( 28 | generation_strategy=get_sobol_generation_strategy(), 29 | batch_size=batch_size, 30 | distribute_replications=distribute_replications, 31 | ) 32 | -------------------------------------------------------------------------------- /ax/exceptions/constants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | TS_MIN_WEIGHT_ERROR = """\ 10 | No arms generated by Thompson Sampling had weight > min_weight. \ 11 | The minimum weight required is {min_weight:2.4}, and the \ 12 | maximum weight of any arm generated is {max_weight:2.4}. 13 | """ 14 | 15 | TS_NO_FEASIBLE_ARMS_ERROR = """\ 16 | Less than 1% of samples have a feasible arm. \ 17 | Check your outcome constraints. 18 | """ 19 | 20 | CHOLESKY_ERROR_ANNOTATION = ( 21 | "Cholesky errors typically occur when the same or very similar " 22 | "arms are suggested repeatedly. This can mean the model has " 23 | "already converged and you should avoid running further trials. 
" 24 | "It will also help to convert integer or categorical parameters " 25 | "to float ranges where reasonable.\nOriginal error: " 26 | ) 27 | -------------------------------------------------------------------------------- /ax/core/tests/test_risk_measures.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from ax.core.risk_measures import RiskMeasure 10 | from ax.utils.common.testutils import TestCase 11 | 12 | 13 | class TestRiskMeasure(TestCase): 14 | def test_risk_measure(self) -> None: 15 | rm = RiskMeasure( 16 | risk_measure="VaR", 17 | options={"alpha": 0.8, "n_w": 5}, 18 | ) 19 | self.assertEqual(rm.risk_measure, "VaR") 20 | self.assertEqual(rm.options, {"alpha": 0.8, "n_w": 5}) 21 | 22 | # Test repr. 23 | expected_repr = ( 24 | "RiskMeasure(risk_measure=VaR, options={'alpha': 0.8, 'n_w': 5})" 25 | ) 26 | self.assertEqual(str(rm), expected_repr) 27 | 28 | # Test clone. 29 | rm_clone = rm.clone() 30 | self.assertEqual(str(rm), str(rm_clone)) 31 | -------------------------------------------------------------------------------- /sphinx/source/runners.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | ax.runners 5 | ========== 6 | 7 | .. automodule:: ax.runners 8 | .. currentmodule:: ax.runners 9 | 10 | BoTorch Test Problem 11 | ~~~~~~~~~~~~~~~~~~~~ 12 | 13 | .. automodule:: ax.runners.botorch_test_problem 14 | :members: 15 | :undoc-members: 16 | :show-inheritance: 17 | 18 | SingleRunningTrialMixin 19 | ~~~~~~~~~~~~~~~~~~~~~~~ 20 | 21 | .. 
automodule:: ax.runners.single_running_trial_mixin 22 | :members: 23 | :undoc-members: 24 | :show-inheritance: 25 | 26 | Synthetic Runner 27 | ~~~~~~~~~~~~~~~~ 28 | 29 | .. automodule:: ax.runners.synthetic 30 | :members: 31 | :undoc-members: 32 | :show-inheritance: 33 | 34 | Simulated Backend Runner 35 | ~~~~~~~~~~~~~~~~~~~~~~~~ 36 | 37 | .. automodule:: ax.runners.simulated_backend 38 | :members: 39 | :undoc-members: 40 | :show-inheritance: 41 | 42 | TorchX Runner 43 | ~~~~~~~~~~~~~ 44 | 45 | .. automodule:: ax.runners.torchx 46 | :members: 47 | :undoc-members: 48 | :show-inheritance: 49 | -------------------------------------------------------------------------------- /ax/analysis/healthcheck/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | 8 | from ax.analysis.healthcheck.can_generate_candidates import ( 9 | CanGenerateCandidatesAnalysis, 10 | ) 11 | 12 | from ax.analysis.healthcheck.constraints_feasibility import ( 13 | ConstraintsFeasibilityAnalysis, 14 | ) 15 | from ax.analysis.healthcheck.healthcheck_analysis import ( 16 | HealthcheckAnalysis, 17 | HealthcheckAnalysisCard, 18 | HealthcheckStatus, 19 | ) 20 | 21 | from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis 22 | from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates 23 | 24 | __all__ = [ 25 | "ConstraintsFeasibilityAnalysis", 26 | "CanGenerateCandidatesAnalysis", 27 | "HealthcheckAnalysis", 28 | "HealthcheckAnalysisCard", 29 | "HealthcheckStatus", 30 | "ShouldGenerateCandidates", 31 | "SearchSpaceAnalysis", 32 | ] 33 | -------------------------------------------------------------------------------- /ax/storage/sqa_store/timestamp.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import datetime 10 | 11 | from sqlalchemy.engine.interfaces import Dialect 12 | from sqlalchemy.types import Integer, TypeDecorator 13 | 14 | 15 | class IntTimestamp(TypeDecorator): 16 | impl = Integer 17 | cache_ok = True 18 | 19 | # pyre-fixme[15]: `process_bind_param` overrides method defined in 20 | # `TypeDecorator` inconsistently. 21 | def process_bind_param( 22 | self, value: datetime.datetime | None, dialect: Dialect 23 | ) -> int | None: 24 | if value is None: 25 | return None 26 | else: 27 | return int(value.timestamp()) 28 | 29 | def process_result_value( 30 | self, value: int | None, dialect: Dialect 31 | ) -> datetime.datetime | None: 32 | return None if value is None else datetime.datetime.fromtimestamp(value) 33 | -------------------------------------------------------------------------------- /ax/utils/testing/test_init_files.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | import os 10 | from glob import glob 11 | 12 | from ax.utils.common.testutils import TestCase 13 | 14 | DIRS_TO_SKIP = ["ax/fb", "ax/github", "tests"] 15 | 16 | 17 | class InitTest(TestCase): 18 | def test_InitFiles(self) -> None: 19 | """__init__.py files are necessary for the inclusion of the directories 20 | in pip builds.""" 21 | for root, _, files in os.walk("./ax", topdown=False): 22 | if any(s in root for s in DIRS_TO_SKIP): 23 | continue 24 | if len(glob(f"{root}/**/*.py", recursive=True)) > 0: 25 | with self.subTest(root): 26 | self.assertTrue( 27 | "__init__.py" in files, 28 | "directory " + root + " does not contain a .__init__.py file", 29 | ) 30 | -------------------------------------------------------------------------------- /sphinx/source/exceptions.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | ax.exceptions 5 | ============= 6 | 7 | .. automodule:: ax.exceptions 8 | .. currentmodule:: ax.exceptions 9 | 10 | 11 | Constants 12 | ~~~~~~~~~ 13 | 14 | .. automodule:: ax.exceptions.constants 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | 19 | 20 | Core 21 | ~~~~ 22 | 23 | .. automodule:: ax.exceptions.core 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | Data 29 | ~~~~ 30 | 31 | .. automodule:: ax.exceptions.data_provider 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | Generation Strategy 37 | ~~~~~~~~~~~~~~~~~~~ 38 | 39 | .. automodule:: ax.exceptions.generation_strategy 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | Model 45 | ~~~~~ 46 | 47 | .. automodule:: ax.exceptions.model 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | Storage 53 | ~~~~~~~ 54 | 55 | .. 
automodule:: ax.exceptions.storage 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | -------------------------------------------------------------------------------- /ax/preview/api/utils/storage.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from ax.preview.api.configs import StorageConfig 7 | from ax.storage.sqa_store.decoder import Decoder 8 | from ax.storage.sqa_store.encoder import Encoder 9 | from ax.storage.sqa_store.sqa_config import SQAConfig 10 | from ax.storage.sqa_store.structs import DBSettings 11 | 12 | 13 | def db_settings_from_storage_config( 14 | storage_config: StorageConfig, 15 | ) -> DBSettings: 16 | """Construct DBSettings (expected by WithDBSettingsBase) from StorageConfig.""" 17 | if (bundle := storage_config.registry_bundle) is not None: 18 | encoder = bundle.encoder 19 | decoder = bundle.decoder 20 | else: 21 | encoder = Encoder(config=SQAConfig()) 22 | decoder = Decoder(config=SQAConfig()) 23 | 24 | return DBSettings( 25 | creator=storage_config.creator, 26 | url=storage_config.url, 27 | encoder=encoder, 28 | decoder=decoder, 29 | ) 30 | -------------------------------------------------------------------------------- /ax/early_stopping/strategies/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | from ax.early_stopping.strategies.base import ( 10 | BaseEarlyStoppingStrategy, 11 | EarlyStoppingTrainingData, 12 | ModelBasedEarlyStoppingStrategy, 13 | ) 14 | from ax.early_stopping.strategies.logical import ( 15 | AndEarlyStoppingStrategy, 16 | LogicalEarlyStoppingStrategy, 17 | OrEarlyStoppingStrategy, 18 | ) 19 | from ax.early_stopping.strategies.percentile import PercentileEarlyStoppingStrategy 20 | from ax.early_stopping.strategies.threshold import ThresholdEarlyStoppingStrategy 21 | 22 | 23 | __all__ = [ 24 | "BaseEarlyStoppingStrategy", 25 | "EarlyStoppingTrainingData", 26 | "ModelBasedEarlyStoppingStrategy", 27 | "PercentileEarlyStoppingStrategy", 28 | "ThresholdEarlyStoppingStrategy", 29 | "AndEarlyStoppingStrategy", 30 | "OrEarlyStoppingStrategy", 31 | "LogicalEarlyStoppingStrategy", 32 | ] 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ax/analysis/healthcheck/healthcheck_analysis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | import json 8 | from enum import IntEnum 9 | 10 | from ax.analysis.analysis import Analysis, AnalysisCard 11 | from ax.core.experiment import Experiment 12 | from ax.core.generation_strategy_interface import GenerationStrategyInterface 13 | 14 | 15 | class HealthcheckStatus(IntEnum): 16 | PASS = 0 17 | FAIL = 1 18 | WARNING = 2 19 | 20 | 21 | class HealthcheckAnalysisCard(AnalysisCard): 22 | blob_annotation = "healthcheck" 23 | 24 | def get_status(self) -> HealthcheckStatus: 25 | return HealthcheckStatus(json.loads(self.blob)["status"]) 26 | 27 | 28 | class HealthcheckAnalysis(Analysis): 29 | """ 30 | An analysis that performs a health check. 31 | """ 32 | 33 | def compute( 34 | self, 35 | experiment: Experiment | None = None, 36 | generation_strategy: GenerationStrategyInterface | None = None, 37 | ) -> HealthcheckAnalysisCard: ... 38 | -------------------------------------------------------------------------------- /ax/core/tests/test_auxiliary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # pyre-strict 7 | 8 | from ax.core.auxiliary import AuxiliaryExperiment 9 | from ax.utils.common.testutils import TestCase 10 | from ax.utils.testing.core_stubs import get_experiment, get_experiment_with_data 11 | 12 | 13 | class AuxiliaryExperimentTest(TestCase): 14 | def test_AuxiliaryExperiment(self) -> None: 15 | for get_exp_func in [get_experiment, get_experiment_with_data]: 16 | exp = get_exp_func() 17 | data = exp.lookup_data() 18 | 19 | # Test init 20 | aux_exp = AuxiliaryExperiment(experiment=exp) 21 | self.assertEqual(aux_exp.experiment, exp) 22 | self.assertEqual(aux_exp.data, data) 23 | 24 | another_aux_exp = AuxiliaryExperiment( 25 | experiment=exp, data=exp.lookup_data() 26 | ) 27 | self.assertEqual(another_aux_exp.experiment, exp) 28 | self.assertEqual(another_aux_exp.data, data) 29 | -------------------------------------------------------------------------------- /sphinx/source/early_stopping.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | ax.early_stopping 5 | ================= 6 | 7 | .. automodule:: ax.early_stopping 8 | .. currentmodule:: ax.early_stopping 9 | 10 | Strategies 11 | ---------- 12 | 13 | Base Strategies 14 | ~~~~~~~~~~~~~~~ 15 | 16 | .. automodule:: ax.early_stopping.strategies.base 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | 21 | Logical Strategies 22 | ~~~~~~~~~~~~~~~~~~ 23 | 24 | .. automodule:: ax.early_stopping.strategies.logical 25 | :members: 26 | :undoc-members: 27 | :show-inheritance: 28 | 29 | `PercentileEarlyStoppingStrategy` 30 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 31 | 32 | .. automodule:: ax.early_stopping.strategies.percentile 33 | :members: 34 | :undoc-members: 35 | :show-inheritance: 36 | 37 | `ThresholdEarlyStoppingStrategy` 38 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 39 | 40 | .. 
automodule:: ax.early_stopping.strategies.threshold 41 | :members: 42 | :undoc-members: 43 | :show-inheritance: 44 | 45 | Utils 46 | ----- 47 | 48 | .. automodule:: ax.early_stopping.utils 49 | :members: 50 | :undoc-members: 51 | :show-inheritance: 52 | -------------------------------------------------------------------------------- /ax/plot/css/base.css: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | .hovertext { 9 | opacity: 0.85; 10 | } 11 | 12 | .plot-menu div { 13 | margin-top: 5px; 14 | margin-bottom: 5px; 15 | } 16 | 17 | .plot-menu select { 18 | height: 30px; 19 | padding: 6px 10px; 20 | background-color: #fff; 21 | border: 1px solid #D1D1D1; 22 | border-radius: 4px; 23 | box-shadow: none; 24 | box-sizing: border-box; 25 | margin-left: 15px; 26 | margin-bottom: 0px; 27 | } 28 | 29 | .plot-menu select:focus { 30 | border: 1px solid #33C3F0; 31 | outline: 0; 32 | } 33 | 34 | .plot-menu button { 35 | display: inline-block; 36 | height: 30px; 37 | padding: 0 10px; 38 | color: #555; 39 | text-align: center; 40 | font-size: 12px; 41 | font-weight: 600; 42 | text-transform: uppercase; 43 | text-decoration: none; 44 | white-space: nowrap; 45 | background-color: transparent; 46 | border-radius: 4px; 47 | border: 1px solid #bbb; 48 | cursor: pointer; 49 | box-sizing: border-box; 50 | background-color: #fff; 51 | margin-left: 15px; 52 | } 53 | -------------------------------------------------------------------------------- /ax/modelbridge/transforms/tests/test_rounding_transform.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import numpy as np 10 | from ax.modelbridge.transforms.rounding import ( 11 | randomized_onehot_round, 12 | strict_onehot_round, 13 | ) 14 | from ax.utils.common.testutils import TestCase 15 | 16 | 17 | class RoundingTest(TestCase): 18 | def test_OneHotRound(self) -> None: 19 | self.assertTrue( 20 | np.allclose( 21 | strict_onehot_round(np.array([0.1, 0.5, 0.3])), np.array([0, 1, 0]) 22 | ) 23 | ) 24 | # One item should be set to one at random. 25 | self.assertEqual( 26 | np.count_nonzero( 27 | np.isclose( 28 | randomized_onehot_round(np.array([0.0, 0.0, 0.0])), 29 | np.array([1, 1, 1]), 30 | ) 31 | ), 32 | 1, 33 | ) 34 | # Negative value is not selected. 35 | self.assertEqual(randomized_onehot_round(np.array([0.0, -1.0, 0.0]))[1], 0.0) 36 | -------------------------------------------------------------------------------- /ax/utils/common/tests/test_func_enum.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from ax.utils.common.func_enum import FuncEnum 10 | from ax.utils.common.testutils import TestCase 11 | 12 | 13 | def morph_into_salamander(how_soon: int) -> bool: 14 | if how_soon <= 0: # Now is the time! 15 | return True 16 | else: # Not morphing yet 17 | return False 18 | 19 | 20 | class AnimalAbilities(FuncEnum): 21 | # ꒰(˶• ᴗ •˶)꒱ 22 | AXOLOTL_MORPH = "morph_into_salamander" 23 | 24 | 25 | class EqualityTest(TestCase): 26 | def test_basic(self) -> None: 27 | self.assertEqual( # Check underlying function correctness. 
28 | AnimalAbilities.AXOLOTL_MORPH._get_function_for_value(), 29 | morph_into_salamander, 30 | ) 31 | 32 | def test_call(self) -> None: 33 | # Should be too early to morph... 34 | self.assertFalse(AnimalAbilities.AXOLOTL_MORPH(how_soon=1)) 35 | # Should've morphed yesterday! 36 | self.assertTrue(AnimalAbilities.AXOLOTL_MORPH(how_soon=-1)) 37 | -------------------------------------------------------------------------------- /ax/core/map_metric.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from __future__ import annotations 10 | 11 | from ax.core.map_data import MapData, MapKeyInfo 12 | from ax.core.metric import Metric, MetricFetchE 13 | from ax.utils.common.result import Result 14 | 15 | MapMetricFetchResult = Result[MapData, MetricFetchE] 16 | 17 | 18 | class MapMetric(Metric): 19 | """Base class for representing metrics that return `MapData`. 20 | 21 | The `fetch_trial_data` method is the essential method to override when 22 | subclassing, which specifies how to retrieve a Metric, for a given trial. 23 | 24 | A MapMetric must return a MapData object, which requires (at minimum) the following: 25 | https://ax.dev/api/_modules/ax/core/data.html#Data.required_columns 26 | 27 | Attributes: 28 | lower_is_better: Flag for metrics which should be minimized. 29 | properties: Properties specific to a particular metric. 
30 | """ 31 | 32 | data_constructor: type[MapData] = MapData 33 | map_key_info: MapKeyInfo[float] = MapKeyInfo(key="step", default_value=0.0) 34 | -------------------------------------------------------------------------------- /ax/analysis/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | 8 | from ax.analysis.utils import choose_analyses 9 | from ax.utils.common.testutils import TestCase 10 | from ax.utils.testing.core_stubs import ( 11 | get_branin_experiment, 12 | get_branin_experiment_with_multi_objective, 13 | ) 14 | 15 | 16 | class TestUtils(TestCase): 17 | def test_choose_analyses(self) -> None: 18 | analyses = choose_analyses(experiment=get_branin_experiment()) 19 | self.assertEqual( 20 | {analysis.name for analysis in analyses}, 21 | { 22 | "ParallelCoordinatesPlot", 23 | "InteractionPlot", 24 | "Summary", 25 | "CrossValidationPlot", 26 | }, 27 | ) 28 | 29 | # Multi-objective case 30 | analyses = choose_analyses( 31 | experiment=get_branin_experiment_with_multi_objective() 32 | ) 33 | self.assertEqual( 34 | {analysis.name for analysis in analyses}, 35 | {"InteractionPlot", "ScatterPlot", "Summary", "CrossValidationPlot"}, 36 | ) 37 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/3_general_support.yaml: -------------------------------------------------------------------------------- 1 | name: General Support 2 | description: Get advice on an experiment you're currently running, ask questions about methods, and receive general help with Ax. 
3 | labels: ["question"] 4 | title: "[GENERAL SUPPORT]: " 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thank you for reaching out -- we will do our best to respond to your inquiry promptly. 10 | - type: textarea 11 | id: question 12 | attributes: 13 | label: Question 14 | description: Provide a detailed description of the problem you're facing or the question you would like help answering. 15 | validations: 16 | required: true 17 | - type: textarea 18 | id: snippet 19 | attributes: 20 | label: Please provide any relevant code snippet if applicable. 21 | description: This will be automatically formatted into code, so no need for backticks. 22 | render: shell 23 | - type: checkboxes 24 | id: terms 25 | attributes: 26 | label: Code of Conduct 27 | description: By submitting this issue you agree to follow Ax's [Code of Conduct](https://github.com/facebook/Ax/blob/main/CODE_OF_CONDUCT.md). 28 | options: 29 | - label: I agree to follow this Ax's Code of Conduct 30 | required: true 31 | -------------------------------------------------------------------------------- /ax/utils/common/mock.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from collections.abc import Callable 10 | from contextlib import contextmanager 11 | from typing import Any, TypeVar 12 | 13 | from unittest.mock import MagicMock, patch 14 | 15 | 16 | T = TypeVar("T") 17 | C = TypeVar("C") 18 | 19 | 20 | @contextmanager 21 | def mock_patch_method_original( 22 | mock_path: str, 23 | original_method: Callable[..., T], 24 | ) -> MagicMock: 25 | """Context manager for patching a method returning type T on class C, 26 | to track calls to it while still executing the original method. 
There 27 | is not a native way to do this with `mock.patch`. 28 | """ 29 | 30 | def side_effect(self: C, *args: Any, **kwargs: Any) -> T: 31 | # pyre-ignore[16]: Anonymous callable has no attribute `self` 32 | # (We can ignore because we expect C to be a class). 33 | side_effect.self = self 34 | return original_method(self, *args, **kwargs) 35 | 36 | patcher = patch(mock_path, autospec=True, side_effect=side_effect) 37 | yield patcher.start() 38 | patcher.stop() 39 | -------------------------------------------------------------------------------- /ax/utils/report/tests/test_render.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from ax.utils.common.testutils import TestCase 10 | from ax.utils.report.render import ( 11 | h2_html, 12 | h3_html, 13 | link_html, 14 | list_item_html, 15 | p_html, 16 | render_report_elements, 17 | table_cell_html, 18 | table_heading_cell_html, 19 | table_html, 20 | table_row_html, 21 | unordered_list_html, 22 | ) 23 | 24 | 25 | class RenderTest(TestCase): 26 | def test_RenderReportElements(self) -> None: 27 | elements = [ 28 | p_html("foobar"), 29 | h2_html("foobar"), 30 | h3_html("foobar"), 31 | list_item_html("foobar"), 32 | unordered_list_html(["foo", "bar"]), 33 | link_html("foo", "bar"), 34 | table_cell_html("foobar"), 35 | table_cell_html("foobar", width="100px"), 36 | table_heading_cell_html("foobar"), 37 | table_row_html(["foo", "bar"]), 38 | table_html(["foo", "bar"]), 39 | ] 40 | render_report_elements("test", elements) 41 | -------------------------------------------------------------------------------- /ax/utils/common/typeutils_torch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import json 10 | 11 | import torch 12 | from ax.utils.common.typeutils import checked_cast 13 | 14 | 15 | def torch_type_to_str(value: torch.dtype | torch.device | torch.Size) -> str: 16 | """Converts torch types, commonly used in Ax, to string representations.""" 17 | if isinstance(value, torch.dtype): 18 | return str(value) 19 | if isinstance(value, torch.device): 20 | return checked_cast(str, value.type) 21 | if isinstance(value, torch.Size): 22 | return json.dumps(list(value)) 23 | raise ValueError(f"Object {value} was of unexpected torch type.") 24 | 25 | 26 | def torch_type_from_str( 27 | identifier: str, type_name: str 28 | ) -> torch.dtype | torch.device | torch.Size: 29 | if type_name == "device": 30 | return torch.device(identifier) 31 | if type_name == "dtype": 32 | return getattr(torch, identifier[6:]) 33 | if type_name == "Size": 34 | return torch.Size(json.loads(identifier)) 35 | raise ValueError(f"Unexpected type: {type_name} for identifier: {identifier}.") 36 | -------------------------------------------------------------------------------- /ax/storage/json_store/load.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | import json 10 | from collections.abc import Callable 11 | from typing import Any 12 | 13 | from ax.core.experiment import Experiment 14 | from ax.storage.json_store.decoder import object_from_json 15 | from ax.storage.json_store.registry import ( 16 | CORE_CLASS_DECODER_REGISTRY, 17 | CORE_DECODER_REGISTRY, 18 | ) 19 | from ax.utils.common.serialization import TDecoderRegistry 20 | 21 | 22 | def load_experiment( 23 | filepath: str, 24 | decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, 25 | # pyre-fixme[2]: Parameter annotation cannot contain `Any`. 26 | class_decoder_registry: dict[ 27 | str, Callable[[dict[str, Any]], Any] 28 | ] = CORE_CLASS_DECODER_REGISTRY, 29 | ) -> Experiment: 30 | """Load experiment from file. 31 | 32 | 1) Read file. 33 | 2) Convert dictionary to Ax experiment instance. 34 | """ 35 | with open(filepath) as file: 36 | json_experiment = json.loads(file.read()) 37 | return object_from_json( 38 | json_experiment, decoder_registry, class_decoder_registry 39 | ) 40 | -------------------------------------------------------------------------------- /.github/workflows/build-and-test.yml: -------------------------------------------------------------------------------- 1 | name: Build and Test Workflow 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: [ main ] 7 | pull_request: 8 | branches: [ main ] 9 | 10 | jobs: 11 | tests-and-coverage: 12 | name: Tests with latest BoTorch 13 | uses: ./.github/workflows/reusable_test.yml 14 | with: 15 | pinned_botorch: false 16 | secrets: inherit 17 | 18 | docs: 19 | runs-on: ubuntu-latest 20 | 21 | steps: 22 | - uses: actions/checkout@v4 23 | - name: Set up Python 24 | uses: actions/setup-python@v5 25 | with: 26 | python-version: "3.10" 27 | - name: Install dependencies 28 | env: 29 | ALLOW_BOTORCH_LATEST: true 30 | ALLOW_LATEST_GPYTORCH_LINOP: true 31 | run: | 32 | # use latest Botorch 33 | pip install git+https://github.com/cornellius-gp/gpytorch.git 34 | pip 
install git+https://github.com/pytorch/botorch.git 35 | pip install -e ".[unittest]" 36 | - name: Validate Sphinx 37 | run: | 38 | python scripts/validate_sphinx.py -p "${pwd}" 39 | - name: Run Sphinx 40 | # run even if previous step (validate Sphinx) failed 41 | if: ${{ always() }} 42 | run: | 43 | # warnings no longer treated as errors. 44 | sphinx-build -T --keep-going sphinx/source sphinx/build 45 | -------------------------------------------------------------------------------- /ax/benchmark/benchmark_trial_metadata.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | 8 | from collections.abc import Mapping 9 | from dataclasses import dataclass 10 | 11 | import pandas as pd 12 | 13 | from ax.utils.testing.backend_simulator import BackendSimulator 14 | 15 | 16 | @dataclass(kw_only=True, frozen=True) 17 | class BenchmarkTrialMetadata: 18 | """ 19 | Data pertaining to one trial evaluation. 20 | 21 | Args: 22 | df: A dict mapping each metric name to a Pandas DataFrame with columns 23 | ["metric_name", "arm_name", "mean", "sem", and "step"]. The "sem" is 24 | always present in this df even if noise levels are unobserved; 25 | ``BenchmarkMetric`` and ``BenchmarkMapMetric`` hide that data if it 26 | should not be observed, and ``BenchmarkMapMetric``s drop data from 27 | time periods that that are not observed based on the (simulated) 28 | trial progression. 29 | backend_simulator: Optionally, the backend simulator that is tracking 30 | the trial's status. 
31 | """ 32 | 33 | dfs: Mapping[str, pd.DataFrame] 34 | backend_simulator: BackendSimulator | None = None 35 | -------------------------------------------------------------------------------- /ax/utils/common/tests/test_docutils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from ax.utils.common.docutils import copy_doc 10 | from ax.utils.common.testutils import TestCase 11 | 12 | 13 | def has_doc() -> None: 14 | """I have a docstring""" 15 | 16 | 17 | def has_no_doc() -> None: 18 | pass 19 | 20 | 21 | class TestDocUtils(TestCase): 22 | def test_transfer_doc(self) -> None: 23 | @copy_doc(has_doc) 24 | # pyre-fixme[3]: Return type must be annotated. 25 | def inherits_doc(): 26 | pass 27 | 28 | self.assertEqual(inherits_doc.__doc__, "I have a docstring") 29 | 30 | def test_fail_when_already_has_doc(self) -> None: 31 | with self.assertRaises(ValueError): 32 | 33 | @copy_doc(has_doc) 34 | # pyre-fixme[3]: Return type must be annotated. 35 | def inherits_doc(): 36 | """I already have a doc string""" 37 | pass 38 | 39 | def test_fail_when_no_doc_to_copy(self) -> None: 40 | with self.assertRaises(ValueError): 41 | 42 | @copy_doc(has_no_doc) 43 | # pyre-fixme[3]: Return type must be annotated. 44 | def f(): 45 | pass 46 | -------------------------------------------------------------------------------- /ax/benchmark/benchmark_test_function.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # pyre-strict 7 | 8 | from abc import ABC, abstractmethod 9 | from collections.abc import Mapping, Sequence 10 | from dataclasses import dataclass 11 | 12 | from ax.core.types import TParamValue 13 | from torch import Tensor 14 | 15 | 16 | @dataclass(kw_only=True) 17 | class BenchmarkTestFunction(ABC): 18 | """ 19 | The basic Ax class for generating deterministic data to benchmark against. 20 | 21 | (Noise - if desired - is added by the runner.) 22 | 23 | Args: 24 | outcome_names: Names of the outcomes. 25 | n_steps: Number of data points produced per metric and per evaluation. 1 26 | if data is not time-series. If data is time-series, this will 27 | eventually become the number of values on a `MapMetric` for 28 | evaluations that run to completion. 29 | """ 30 | 31 | outcome_names: Sequence[str] 32 | n_steps: int = 1 33 | 34 | @abstractmethod 35 | def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor: 36 | """ 37 | Evaluate noiselessly. 38 | 39 | Returns: 40 | A 2d tensor of shape (len(self.outcome_names), self.n_steps). 41 | """ 42 | ... 43 | -------------------------------------------------------------------------------- /ax/utils/report/resources/base_template.html: -------------------------------------------------------------------------------- 1 | 4 | 5 | 6 | {{experiment_name}} 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | {% if headfoot %} 16 |
17 |

18 | {% if trial_index is none %} 19 | Report for {{experiment_name}} 20 | {% else %} 21 | Report for {{ batch_noun }} #{{ trial_index }} of 22 | {{experiment_name}} 23 | {% endif %} 24 |

25 |

Powered by Ax

26 |
27 | {% endif %} 28 |
29 | {% block content %} 30 | {% endblock %} 31 |
32 | {% if headfoot %} 33 | 34 | {% endif %} 35 |
36 | 37 | 38 | -------------------------------------------------------------------------------- /website/static/img/dice-solid.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 8 | 16 | 17 | -------------------------------------------------------------------------------- /ax/utils/common/tests/test_kwargutils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | from logging import Logger 11 | from unittest.mock import patch 12 | 13 | from ax.utils.common.kwargs import warn_on_kwargs 14 | from ax.utils.common.logger import get_logger 15 | from ax.utils.common.testutils import TestCase 16 | 17 | logger: Logger = get_logger("ax.utils.common.kwargs") 18 | 19 | 20 | class TestWarnOnKwargs(TestCase): 21 | def test_it_warns_if_kwargs_are_passed(self) -> None: 22 | with patch.object(logger, "warning") as mock_warning: 23 | 24 | def callable_arg() -> None: 25 | return 26 | 27 | warn_on_kwargs(callable_with_kwargs=callable_arg, foo="") 28 | mock_warning.assert_called_once_with( 29 | "Found unexpected kwargs: %s while calling %s " 30 | "from JSON. These kwargs will be ignored.", 31 | {"foo": ""}, 32 | callable_arg, 33 | ) 34 | 35 | def test_it_does_not_warn_if_no_kwargs_are_passed(self) -> None: 36 | with patch.object(logger, "warning") as mock_warning: 37 | warn_on_kwargs(callable_with_kwargs=lambda: None) 38 | mock_warning.assert_not_called() 39 | -------------------------------------------------------------------------------- /ax/plot/tests/test_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | from ax.plot.helper import arm_name_to_sort_key, extend_range 11 | from ax.utils.common.testutils import TestCase 12 | 13 | 14 | class HelperTest(TestCase): 15 | def test_extend_range(self) -> None: 16 | with self.assertRaises(ValueError): 17 | extend_range(lower=1, upper=-1) 18 | self.assertEqual(extend_range(lower=-1, upper=1), (-1.2, 1.2)) 19 | self.assertEqual(extend_range(lower=-1, upper=0, percent=30), (-1.3, 0.3)) 20 | self.assertEqual(extend_range(lower=0, upper=1, percent=50), (-0.5, 1.5)) 21 | 22 | def test_arm_name_to_sort_key(self) -> None: 23 | arm_names = ["0_0", "1_10", "1_2", "10_0", "control"] 24 | sorted_names = sorted(arm_names, key=arm_name_to_sort_key, reverse=True) 25 | expected = ["control", "0_0", "1_2", "1_10", "10_0"] 26 | self.assertEqual(sorted_names, expected) 27 | 28 | arm_names = ["0_0", "0", "1_10", "3_2_x", "3_x", "1_2", "control"] 29 | sorted_names = sorted(arm_names, key=arm_name_to_sort_key, reverse=True) 30 | expected = ["control", "3_x", "3_2_x", "0", "0_0", "1_2", "1_10"] 31 | self.assertEqual(sorted_names, expected) 32 | -------------------------------------------------------------------------------- /ax/utils/common/deprecation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | import warnings 10 | 11 | 12 | def _validate_force_random_search( 13 | no_bayesian_optimization: bool | None = None, 14 | force_random_search: bool = False, 15 | exception_cls: type[Exception] = ValueError, 16 | ) -> None: 17 | """Helper function to validate interaction between `force_random_search` 18 | and `no_bayesian_optimization` (supported until deprecation in [T199632397]) 19 | """ 20 | if no_bayesian_optimization is not None: 21 | # users are effectively permitted to continue using 22 | # `no_bayesian_optimization` so long as it doesn't 23 | # conflict with `force_random_search` 24 | if no_bayesian_optimization != force_random_search: 25 | raise exception_cls( 26 | "Conflicting values for `force_random_search` " 27 | "and `no_bayesian_optimization`! " 28 | "Please only specify `force_random_search`." 29 | ) 30 | warnings.warn( 31 | "`no_bayesian_optimization` is deprecated. Please use " 32 | "`force_random_search` in the future.", 33 | DeprecationWarning, 34 | stacklevel=2, 35 | ) 36 | -------------------------------------------------------------------------------- /ax/preview/api/protocols/metric.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | 8 | 9 | from typing import Any, Mapping 10 | 11 | from ax.preview.api.protocols.utils import _APIMetric 12 | from pyre_extensions import override 13 | 14 | 15 | class IMetric(_APIMetric): 16 | """ 17 | Metrics automate the process of fetching data from external systems. They are used 18 | in conjunction with Runners in the run_n_trials method to facilitate closed-loop 19 | experimentation. 
20 | """ 21 | 22 | def __init__(self, name: str) -> None: 23 | super().__init__(name=name) 24 | 25 | @override 26 | def fetch( 27 | self, 28 | trial_index: int, 29 | trial_metadata: Mapping[str, Any], 30 | ) -> tuple[int, float | tuple[float, float]]: 31 | """ 32 | Given trial metadata (the mapping returned from IRunner.run), fetches 33 | readings for the metric. 34 | 35 | Readings are returned as a pair (progression, outcome), where progression is 36 | an integer representing the progression of the trial (e.g. number of epochs 37 | for a training job, timestamp for a time series, etc.), and outcome is either 38 | direct reading or a (mean, sem) pair for the metric. 39 | """ 40 | ... 41 | -------------------------------------------------------------------------------- /ax/runners/synthetic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from collections.abc import Iterable 10 | from typing import Any 11 | 12 | from ax.core.base_trial import BaseTrial, TrialStatus 13 | from ax.core.runner import Runner 14 | 15 | 16 | class SyntheticRunner(Runner): 17 | """Class for synthetic or dummy runner. 18 | 19 | Currently acts as a shell runner, only creating a name. 
20 | """ 21 | 22 | def __init__(self, dummy_metadata: str | None = None) -> None: 23 | self.dummy_metadata = dummy_metadata 24 | 25 | def run(self, trial: BaseTrial) -> dict[str, Any]: 26 | deployed_name = ( 27 | trial.experiment.name + "_" + str(trial.index) 28 | if trial.experiment.has_name 29 | else str(trial.index) 30 | ) 31 | metadata = {"name": deployed_name} 32 | 33 | # Add dummy metadata if needed for testing 34 | if self.dummy_metadata: 35 | metadata["dummy_metadata"] = self.dummy_metadata 36 | return metadata 37 | 38 | def poll_trial_status( 39 | self, trials: Iterable[BaseTrial] 40 | ) -> dict[TrialStatus, set[int]]: 41 | return {TrialStatus.COMPLETED: {t.index for t in trials}} 42 | 43 | @property 44 | def run_metadata_report_keys(self) -> list[str]: 45 | return ["name"] 46 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/2_feature_request.yaml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: Request a feature or improvement to Ax. 3 | labels: ["enhancement"] 4 | title: "[FEATURE REQUEST]: " 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thank you for taking the time to request a feature! We strive to make Ax a useful and rich library for our users. 10 | - type: textarea 11 | id: motivation 12 | attributes: 13 | label: Motivation 14 | description: Provide a detailed description of the problem you would like to solve via this new feature or improvement. 15 | validations: 16 | required: true 17 | - type: textarea 18 | id: pitch 19 | attributes: 20 | label: Describe the solution you'd like to see implemented in Ax. 21 | validations: 22 | required: true 23 | - type: textarea 24 | id: alternatives 25 | attributes: 26 | label: Describe any alternatives you've considered to the above solution. 
27 | - type: textarea 28 | id: related 29 | attributes: 30 | label: Is this related to an existing issue in Ax or another repository? If so please include links to those Issues here. 31 | - type: checkboxes 32 | id: terms 33 | attributes: 34 | label: Code of Conduct 35 | description: By submitting this issue, you agree to follow Ax's [Code of Conduct](https://github.com/facebook/Ax/blob/main/CODE_OF_CONDUCT.md). 36 | options: 37 | - label: I agree to follow Ax's Code of Conduct 38 | required: true 39 | -------------------------------------------------------------------------------- /ax/exceptions/storage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | from ax.exceptions.core import AxError 11 | 12 | 13 | STORAGE_DOCS_SUFFIX = ( 14 | "Please see our storage tutorial (https://ax.dev/docs/storage.html) " 15 | "for more details ('Customizing' section will be " 16 | "relevant for saving Ax object subclasses)." 
17 | ) 18 | 19 | 20 | class JSONDecodeError(AxError): 21 | """Raised when an error occurs during JSON decoding.""" 22 | 23 | pass 24 | 25 | 26 | class JSONEncodeError(AxError): 27 | """Raised when an error occurs during JSON encoding.""" 28 | 29 | pass 30 | 31 | 32 | class SQADecodeError(AxError): 33 | """Raised when an error occurs during SQA decoding.""" 34 | 35 | pass 36 | 37 | 38 | class SQAEncodeError(AxError): 39 | """Raised when an error occurs during SQA encoding.""" 40 | 41 | pass 42 | 43 | 44 | class ImmutabilityError(AxError): 45 | """Raised when an attempt is made to update an immutable object.""" 46 | 47 | pass 48 | 49 | 50 | class IncorrectDBConfigurationError(AxError): 51 | """Raised when an attempt is made to save or load an object, but 52 | the current engine and session factory are set up incorrectly to 53 | process the call (e.g. current session factory will connect to a 54 | wrong database for the call). 55 | """ 56 | 57 | pass 58 | -------------------------------------------------------------------------------- /ax/models/winsorization_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from dataclasses import dataclass 10 | 11 | 12 | @dataclass 13 | class WinsorizationConfig: 14 | """Dataclass for storing Winsorization configuration parameters. 15 | 16 | Attributes: 17 | lower_quantile_margin: Winsorization will increase any metric value below this 18 | quantile to this quantile's value. 19 | upper_quantile_margin: Winsorization will decrease any metric value above this 20 | quantile to this quantile's value.
NOTE: this quantile will be inverted before 21 | any operations, e.g., a value of 0.2 will decrease values above the 80th 22 | percentile to the value of the 80th percentile. 23 | lower_boundary: If this value is less than the metric value corresponding to 24 | ``lower_quantile_margin``, set metric values below ``lower_boundary`` to 25 | ``lower_boundary`` and leave larger values unaffected. 26 | upper_boundary: If this value is greater than the metric value corresponding to 27 | ``upper_quantile_margin``, set metric values above ``upper_boundary`` to 28 | ``upper_boundary`` and leave smaller values unaffected. 29 | """ 30 | 31 | lower_quantile_margin: float = 0.0 32 | upper_quantile_margin: float = 0.0 33 | lower_boundary: float | None = None 34 | upper_boundary: float | None = None 35 | -------------------------------------------------------------------------------- /ax/core/auxiliary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | 8 | from __future__ import annotations 9 | 10 | from enum import Enum, unique 11 | from typing import TYPE_CHECKING 12 | 13 | from ax.core.data import Data 14 | from ax.utils.common.base import SortableBase 15 | 16 | 17 | if TYPE_CHECKING: 18 | # import as module to make sphinx-autodoc-typehints happy 19 | from ax import core # noqa F401 20 | 21 | 22 | class AuxiliaryExperiment(SortableBase): 23 | """Class for defining an auxiliary experiment.""" 24 | 25 | def __init__( 26 | self, 27 | experiment: core.experiment.Experiment, 28 | data: Data | None = None, 29 | ) -> None: 30 | """ 31 | Lightweight container of an experiment, and its data, 32 | that will be used as auxiliary information for another experiment.
33 | """ 34 | self.experiment = experiment 35 | self.data: Data = data or experiment.lookup_data() # NOTE(review): `or` falls back to lookup_data() whenever `data` is falsy, not only when it is None (an empty Data may count as falsy if Data defines __bool__/__len__) -- confirm intended 36 | 37 | def _unique_id(self) -> str: 38 | # While there can be multiple `AuxiliarySource`-s made from the same 39 | # experiment (and thus sharing the experiment name), the uniqueness 40 | # here is only needed w.r.t. parent object ("main experiment", for which 41 | # this will be an auxiliary source for). 42 | return self.experiment.name 43 | 44 | 45 | @unique 46 | class AuxiliaryExperimentPurpose(Enum): 47 | pass 48 | -------------------------------------------------------------------------------- /.github/workflows/reusable_tutorials.yml: -------------------------------------------------------------------------------- 1 | name: Reusable Tutorials Workflow 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | smoke_test: 7 | required: true 8 | type: boolean 9 | pinned_botorch: 10 | required: true 11 | type: boolean 12 | workflow_call: 13 | inputs: 14 | smoke_test: 15 | required: true 16 | type: boolean 17 | pinned_botorch: 18 | required: true 19 | type: boolean 20 | 21 | jobs: 22 | 23 | build-tutorials: 24 | name: Tutorials 25 | runs-on: ubuntu-latest 26 | steps: 27 | - uses: actions/checkout@v4 28 | - name: Set up Python 29 | uses: actions/setup-python@v5 30 | with: 31 | python-version: "3.10" 32 | 33 | - if: ${{ inputs.pinned_botorch }} 34 | name: Install dependencies with pinned BoTorch 35 | run: | 36 | pip install -e ".[tutorial]" 37 | 38 | - if: ${{ !inputs.pinned_botorch }} 39 | name: Install dependencies with latest BoTorch 40 | env: 41 | ALLOW_BOTORCH_LATEST: true 42 | ALLOW_LATEST_GPYTORCH_LINOP: true 43 | run: | 44 | pip install git+https://github.com/cornellius-gp/gpytorch.git 45 | pip install git+https://github.com/pytorch/botorch.git 46 | pip install -e ".[tutorial]" 47 | 48 | - if: ${{ inputs.smoke_test }} 49 | name: Build tutorials 
with smoke test 50 | run: | 51 | python scripts/make_tutorials.py -w $(pwd) -e -s 52 | - if: ${{ !inputs.smoke_test }} 53
| name: Build tutorials without smoke test 54 | run: | 55 | python scripts/make_tutorials.py -w $(pwd) -e 56 | -------------------------------------------------------------------------------- /scripts/patch_site_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import re 9 | 10 | 11 | def patch_config( # Rewrite a Docusaurus siteConfig.js in place: optionally set baseUrl and/or switch off Algolia. 12 | config_file: str, base_url: str | None = None, disable_algolia: bool = True 13 | ) -> None: 14 | config = open(config_file).read() # NOTE(review): this file handle is never explicitly closed; consider `with open(config_file) as f:` 15 | 16 | if base_url is not None: 17 | config = re.sub("baseUrl = '/';", f"baseUrl = '{base_url}';", config) # NOTE(review): the replacement string is interpreted by re.sub, so backslashes in base_url would be treated as escapes 18 | if disable_algolia is True: 19 | config = re.sub( 20 | "const includeAlgolia = true;", "const includeAlgolia = false;", config 21 | ) 22 | 23 | with open(config_file, "w") as outfile: 24 | outfile.write(config) 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser( 29 | description="Path Docusaurus siteConfig.js file when building site." # NOTE(review): "Path" is likely a typo for "Patch"
30 | ) 31 | parser.add_argument( 32 | "-f", 33 | "--config_file", 34 | metavar="path", 35 | required=True, 36 | help="Path to configuration file.", 37 | ) 38 | parser.add_argument( 39 | "-b", 40 | "--base_url", 41 | type=str, 42 | required=False, 43 | help="Value for baseUrl.", 44 | default=None, 45 | ) 46 | parser.add_argument( 47 | "--disable_algolia", 48 | required=False, 49 | action="store_true", 50 | help="Disable algolia.", 51 | ) 52 | args = parser.parse_args() 53 | patch_config(args.config_file, args.base_url, args.disable_algolia) # NOTE: from the CLI, Algolia is only disabled when --disable_algolia is passed (store_true defaults to False), unlike the function default of True 54 | -------------------------------------------------------------------------------- /ax/utils/common/random.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import random 10 | from collections.abc import Generator 11 | from contextlib import contextmanager 12 | 13 | import numpy as np 14 | import torch 15 | 16 | 17 | def set_rng_seed(seed: int) -> None: 18 | """Sets seeds for random number generators from numpy, pytorch, 19 | and the native random module. 20 | 21 | Args: 22 | seed: The random number generator seed. 23 | """ 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | 28 | 29 | @contextmanager 30 | def with_rng_seed(seed: int | None) -> Generator[None, None, None]: 31 | """Context manager that sets the random number generator seeds 32 | to a given value and restores the previous state on exit. 33 | 34 | If the seed is None, the context manager does nothing. This makes 35 | it possible to use the context manager without having to change 36 | the code based on whether the seed is specified. 37 | 38 | Args: 39 | seed: The random number generator seed.
40 | """ 41 | if seed is None: 42 | yield 43 | else: 44 | old_state_native = random.getstate() 45 | old_state_numpy = np.random.get_state() 46 | try: 47 | with torch.random.fork_rng(): # fork_rng restores the torch RNG state on exit; the native and numpy states are restored in `finally` below 48 | set_rng_seed(seed) 49 | yield 50 | finally: 51 | random.setstate(old_state_native) 52 | np.random.set_state(old_state_numpy) 53 | -------------------------------------------------------------------------------- /ax/storage/tests/test_registry_bundle.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | 8 | from ax.benchmark.benchmark_metric import BenchmarkMetric 9 | from ax.metrics.branin import BraninMetric 10 | from ax.runners.synthetic import SyntheticRunner 11 | from ax.storage.registry_bundle import RegistryBundle 12 | from ax.utils.common.testutils import TestCase 13 | 14 | 15 | class RegistryBundleTest(TestCase): 16 | def test_from_registry_bundles(self) -> None: 17 | left = RegistryBundle( 18 | metric_clss={BraninMetric: None}, 19 | runner_clss={SyntheticRunner: None}, 20 | json_encoder_registry={}, 21 | json_class_encoder_registry={}, 22 | json_decoder_registry={}, 23 | json_class_decoder_registry={}, 24 | ) 25 | 26 | right = RegistryBundle( 27 | metric_clss={BenchmarkMetric: None}, 28 | runner_clss={SyntheticRunner: None}, 29 | json_encoder_registry={}, 30 | json_class_encoder_registry={}, 31 | json_decoder_registry={}, 32 | json_class_decoder_registry={}, 33 | ) 34 | 35 | self.assertIn(BraninMetric, left.encoder_registry) 36 | self.assertNotIn(BenchmarkMetric, left.encoder_registry) 37 | 38 | combined = RegistryBundle.from_registry_bundles(left, right) 39 | 40 | self.assertIn(BraninMetric, combined.encoder_registry) 41 | self.assertIn(SyntheticRunner, 
combined.encoder_registry) 42 | self.assertIn(BenchmarkMetric,
combined.encoder_registry) 43 | -------------------------------------------------------------------------------- /website/static/js/plotUtils.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | // helper functions used across multiple plots 9 | function rgb(rgb_array) { 10 | return 'rgb(' + rgb_array.join() + ')'; 11 | } 12 | 13 | function copy_and_reverse(arr) { 14 | const copy = arr.slice(); 15 | copy.reverse(); 16 | return copy; 17 | } 18 | 19 | function axis_range(grid, is_log) { 20 | return is_log ? 21 | [Math.log10(Math.min(...grid)), Math.log10(Math.max(...grid))]: 22 | [Math.min(...grid), Math.max(...grid)]; 23 | } 24 | 25 | function relativize_data(f, sd, rel, arm_data, metric) { 26 | // if relative, extract status quo & compute ratio 27 | const f_final = rel === true ? [] : f; 28 | const sd_final = rel === true ?
[]: sd; 29 | 30 | if (rel === true) { 31 | const f_sq = ( 32 | arm_data['in_sample'][arm_data['status_quo_name']]['y'][metric] 33 | ); 34 | const sd_sq = ( 35 | arm_data['in_sample'][arm_data['status_quo_name']]['se'][metric] 36 | ); 37 | 38 | for (let i = 0; i < f.length; i++) { 39 | res = relativize(f[i], sd[i], f_sq, sd_sq); // NOTE(review): `res` is never declared, so it leaks as an implicit global (a ReferenceError in strict mode) 40 | f_final.push(100 * res[0]); 41 | sd_final.push(100 * res[1]); 42 | } 43 | } 44 | 45 | return [f_final, sd_final]; 46 | } 47 | 48 | function relativize(m_t, sem_t, m_c, sem_c) { // Relative change of m_t vs. control m_c with a second-order correction term, plus its standard error (presumably a delta-method approximation) 49 | r_hat = ( // NOTE(review): undeclared -> implicit global 50 | (m_t - m_c) / Math.abs(m_c) - 51 | Math.pow(sem_c, 2) * m_t / Math.pow(Math.abs(m_c), 3) 52 | ); 53 | variance = ( // NOTE(review): undeclared -> implicit global 54 | (Math.pow(sem_t, 2) + Math.pow((m_t / m_c * sem_c), 2)) / 55 | Math.pow(m_c, 2) 56 | ) 57 | return [r_hat, Math.sqrt(variance)]; 58 | } 59 | -------------------------------------------------------------------------------- /ax/utils/common/func_enum.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | from enum import Enum, unique 8 | from importlib import import_module 9 | from typing import Any, Callable 10 | 11 | from ax.exceptions.core import UnsupportedError 12 | 13 | 14 | @unique 15 | class FuncEnum(Enum): 16 | """A base class for all enums with the following structure: string values that 17 | map to names of functions, which reside in the same module as the enum.""" 18 | 19 | # pyre-ignore[3]: Input constructors will be used to make different inputs, 20 | # so we need to allow `Any` return type here. 
21 | def __call__(self, **kwargs: Any) -> Any: 22 | """Defines a method, by which the members of this enum can be called, 23 | e.g.
``MyFunctions.F(**kwargs)``, which will call the corresponding 24 | function registered by the name ``F`` in the enum.""" 25 | return self._get_function_for_value()(**kwargs) 26 | 27 | # pyre-ignore[31]: Expression `typing.Callable[([...], typing.Any)]` 28 | # is not a valid type. 29 | def _get_function_for_value(self) -> Callable[[...], Any]: 30 | """Retrieve the function from the enum's own module whose name matches the 31 | value of this enum member.""" 32 | try: 33 | return getattr(import_module(self.__module__), self.value) 34 | except AttributeError: 35 | raise UnsupportedError( 36 | f"{self.value} is not defined as a method in " 37 | f"`{self.__module__}`. Please add the method " 38 | "to the file." 39 | ) 40 | -------------------------------------------------------------------------------- /ax/analysis/healthcheck/tests/test_should_generate_candidates.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.

# pyre-unsafe

from random import randint

from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckStatus
from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates
from ax.utils.common.testutils import TestCase


class TestShouldGenerateCandidates(TestCase):
    """Checks the card produced by the ``ShouldGenerateCandidates`` healthcheck."""

    def _check_card(
        self,
        should_generate: bool,
        reason: str,
        expected_status: HealthcheckStatus,
    ) -> None:
        # Build a card for a random trial index and verify every reported field.
        index = randint(0, 10)
        card = ShouldGenerateCandidates(
            should_generate=should_generate,
            reason=reason,
            trial_index=index,
        ).compute()
        self.assertEqual(card.get_status(), expected_status)
        self.assertEqual(card.level, AnalysisCardLevel.CRITICAL)
        self.assertEqual(card.subtitle, reason)
        self.assertEqual(card.attributes["trial_index"], index)

    def test_should(self) -> None:
        self._check_card(
            should_generate=True,
            reason="Something reassuring",
            expected_status=HealthcheckStatus.PASS,
        )

    def test_should_not(self) -> None:
        self._check_card(
            should_generate=False,
            reason="Something concerning",
            expected_status=HealthcheckStatus.WARNING,
        )


# ----------------------------------------------------------------------------
# /ax/plot/tests/test_slices.py:
# ----------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import plotly.graph_objects as go
from ax.modelbridge.registry import Models
from ax.plot.base import AxPlotConfig
from ax.plot.slice import (
    interact_slice,
    interact_slice_plotly,
    plot_slice,
    plot_slice_plotly,
)
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_experiment
from ax.utils.testing.mock import mock_botorch_optimize


class SlicesTest(TestCase):
    """Smoke test: every slice-plot helper builds an object of the expected type."""

    @mock_botorch_optimize
    def test_Slices(self) -> None:
        experiment = get_branin_experiment(with_batch=True)
        experiment.trials[0].run()
        model = Models.BOTORCH_MODULAR(
            # Model bridge kwargs
            experiment=experiment,
            data=experiment.fetch_data(),
        )
        # pyre-fixme[16]: `ModelBridge` has no attribute `parameters`.
        param_name = model.parameters[0]
        metric_name = list(model.metric_names)[0]

        # Assert that each type of plot can be constructed successfully.
        self.assertIsInstance(
            plot_slice_plotly(model, param_name, metric_name), go.Figure
        )
        self.assertIsInstance(interact_slice_plotly(model), go.Figure)
        self.assertIsInstance(
            plot_slice(model, param_name, metric_name), AxPlotConfig
        )
        self.assertIsInstance(interact_slice(model), AxPlotConfig)


# ----------------------------------------------------------------------------
# /ax/modelbridge/transforms/search_space_to_float.py:
# ----------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


from ax.core.arm import Arm
from ax.core.observation import ObservationFeatures
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.base import Transform


class SearchSpaceToFloat(Transform):
    """Collapses any search space into a single float parameter, ``HASH_PARAM``,
    whose value is derived from each arm's signature.

    NOTE: Different parameterizations can map to the same value (collisions), so
    this transform must not be used where unique observation features need to be
    preserved. It exists so that forward transforms can be applied to any search
    space, whatever its parameterization.

    Transform is done in-place.
    """

    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        # The incoming space is ignored; every arm is represented in [0, 1e12].
        hash_param = RangeParameter(
            name="HASH_PARAM",
            parameter_type=ParameterType.FLOAT,
            lower=0.0,
            upper=1e12,
        )
        return SearchSpace(parameters=[hash_param])

    def transform_observation_features(
        self, observation_features: list[ObservationFeatures]
    ) -> list[ObservationFeatures]:
        for features in observation_features:
            # Read the hex arm signature as an integer and fold it into
            # [0, 1e12) so it fits inside the HASH_PARAM range.
            signature_int = int(Arm(parameters=features.parameters).signature, 16)
            features.parameters = {
                "HASH_PARAM": float(signature_int % 1_000_000_000_000)
            }
        return observation_features


# ----------------------------------------------------------------------------
# /ax/utils/notebook/plotting.py:
# ----------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from logging import Logger

from ax.plot.base import AxPlotConfig, AxPlotTypes
from ax.plot.render import _js_requires, _wrap_js, plot_config_to_html
from ax.utils.common.logger import get_logger
from IPython.display import display
from plotly.offline import init_notebook_mode, iplot

logger: Logger = get_logger(__name__)


def init_notebook_plotting(offline: bool = False) -> None:
    """Set up notebook plotting in online or offline mode.

    Injects the Plotly JS requirements into the current cell as raw HTML, logs
    guidance for users, and finally calls Plotly's ``init_notebook_mode``.
    """
    requires_html = _wrap_js(_js_requires(offline=offline))
    display({"text/html": requires_html}, raw=True)
    logger.info("Injecting Plotly library into cell. Do not overwrite or delete cell.")
    logger.info(
        """
    Please see
    (https://ax.dev/tutorials/visualizations.html#Fix-for-plots-that-are-not-rendering)
    if visualizations are not rendering.
    """.strip()
    )
    init_notebook_mode()


def render(plot_config: AxPlotConfig, inject_helpers: bool = False) -> None:
    """Display ``plot_config`` in the notebook according to its plot type.

    GENERIC configs go through Plotly's ``iplot``; HTML configs are displayed
    as-is; anything else is converted to HTML first.
    """
    plot_type = plot_config.plot_type
    if plot_type == AxPlotTypes.GENERIC:
        iplot(plot_config.data)
        return
    if plot_type == AxPlotTypes.HTML:
        assert "text/html" in plot_config.data
        display(plot_config.data, raw=True)
        return
    html = plot_config_to_html(plot_config, inject_helpers=inject_helpers)
    display({"text/html": html}, raw=True)


# ----------------------------------------------------------------------------
# /ax/__init__.py:
# ----------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from ax.core import (
    Arm,
    BatchTrial,
    ChoiceParameter,
    ComparisonOp,
    Data,
    Experiment,
    FixedParameter,
    GeneratorRun,
    Metric,
    MultiObjective,
    MultiObjectiveOptimizationConfig,
    Objective,
    ObjectiveThreshold,
    OptimizationConfig,
    OrderConstraint,
    OutcomeConstraint,
    Parameter,
    ParameterConstraint,
    ParameterType,
    RangeParameter,
    Runner,
    SearchSpace,
    SumConstraint,
    Trial,
)
from ax.modelbridge import Models
from ax.service import OptimizationLoop, optimize
from ax.storage import json_load, json_save

try:
    # The previous body was `pass`, which left `__version__` unbound and made
    # the fallback below unreachable. `ax/version.py` is presumably generated
    # at build/install time (it is not shown in this listing), so fall back to
    # a placeholder whenever it cannot be imported.
    from ax.version import version as __version__
except Exception:  # pragma: no cover
    __version__ = "Unknown"


__all__ = [
    "Arm",
    "BatchTrial",
    "ChoiceParameter",
    "ComparisonOp",
    "Data",
    "Experiment",
    "FixedParameter",
    "GeneratorRun",
    "Metric",
    "Models",
    "MultiObjective",
    "MultiObjectiveOptimizationConfig",
    "Objective",
    "ObjectiveThreshold",
    "OptimizationConfig",
    "OptimizationLoop",
    "OrderConstraint",
    "OutcomeConstraint",
    "Parameter",
    "ParameterConstraint",
    "ParameterType",
    "RangeParameter",
    "Runner",
    "SearchSpace",
    "SumConstraint",
    "Trial",
    "optimize",
    "json_save",
    "json_load",
]


# ----------------------------------------------------------------------------
# /ax/modelbridge/transforms/deprecated_transform_mixin.py:
# ----------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from logging import Logger
from typing import Any

from ax.utils.common.logger import get_logger

logger: Logger = get_logger(__name__)


class DeprecatedTransformMixin:
    """
    Mixin for transforms that were deprecated in favor of a replacement.

    Constructing the deprecated transform logs a warning naming the replacement
    and then runs the replacement's constructor. The deprecated transform must
    list this mixin first and the replacement second among its bases:

        class DeprecatedTransform(DeprecatedTransformMixin, NewTransform):
            ...

    :meta private:
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Warn that this transform is deprecated, then construct the replacement.
        """
        cls = type(self)
        # The replacement transform is the second direct base (see docstring).
        replacement_name = cls.__bases__[1].__name__
        logger.warning(self.warn_deprecated_message(cls.__name__, replacement_name))

        super().__init__(*args, **kwargs)

    @staticmethod
    def warn_deprecated_message(
        deprecated_transform_name: str, new_transform_name: str
    ) -> str:
        """
        Build the warning text telling users that ``deprecated_transform_name``
        is deprecated and that ``new_transform_name`` is used instead.
        """
        return (
            f"`{deprecated_transform_name}` transform has been deprecated and will "
            f"be removed in a future release. Using `{new_transform_name}` instead."
        )


# ----------------------------------------------------------------------------
# /ax/utils/testing/torch_stubs.py:
# ----------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from typing import Any

import torch


def get_torch_test_data(
    dtype: torch.dtype = torch.float,
    cuda: bool = False,
    constant_noise: bool = True,
    task_features: list[int] | None = None,
    offset: float = 0.0,
) -> tuple[
    list[torch.Tensor],
    list[torch.Tensor],
    list[torch.Tensor],
    list[tuple[float, float]],
    list[int],
    list[str],
    list[str],
]:
    """Return a tiny single-outcome dataset for torch model tests.

    Every numeric value (inputs, outcomes, non-constant noise, bounds) is
    shifted by ``offset``.

    Returns:
        ``(Xs, Ys, Yvars, bounds, task_features, feature_names, metric_names)``
        where ``Xs`` holds one 2x3 tensor, ``Ys`` and ``Yvars`` one 2x1 tensor
        each, ``bounds`` has one ``(lower, upper)`` pair per feature, and
        ``task_features`` is ``[]`` when not given.
    """
    tkwargs: dict[str, Any] = {
        "device": torch.device("cuda" if cuda else "cpu"),
        "dtype": dtype,
    }
    x_rows = [
        [value + offset for value in row] for row in ((1.0, 2.0, 3.0), (2.0, 3.0, 4.0))
    ]
    y_rows = [[value + offset] for value in (3.0, 4.0)]
    noise = (
        torch.ones(2, 1, **tkwargs)
        if constant_noise
        else torch.tensor([[value + offset] for value in (0.0, 2.0)], **tkwargs)
    )
    bounds = [
        (low + offset, high + offset)
        for low, high in ((0.0, 1.0), (1.0, 4.0), (2.0, 5.0))
    ]
    return (
        [torch.tensor(x_rows, **tkwargs)],
        [torch.tensor(y_rows, **tkwargs)],
        [noise],
        bounds,
        [] if task_features is None else task_features,
        ["x1", "x2", "x3"],
        ["y"],
    )


# ----------------------------------------------------------------------------
# /ax/storage/sqa_store/reduced_state.py:
# ----------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


from ax.storage.sqa_store.sqa_classes import SQAGeneratorRun
from sqlalchemy.orm import defaultload, lazyload, strategy_options
from sqlalchemy.orm.attributes import InstrumentedAttribute


GR_LARGE_MODEL_ATTRS: list[InstrumentedAttribute] = [  # pyre-ignore[9]
    SQAGeneratorRun.model_kwargs,
    SQAGeneratorRun.bridge_kwargs,
    SQAGeneratorRun.model_state_after_gen,
    SQAGeneratorRun.gen_metadata,
]


GR_PARAMS_METRICS_COLS = [
    "parameters",
    "parameter_constraints",
    "metrics",
]


def get_query_options_to_defer_immutable_duplicates() -> list[strategy_options.Load]:
    """Build query options that lazy-load the generator-run relationships that
    are duplicated on every trial (search space attributes and metrics).

    Experiments with an immutable search space and optimization config do not
    need these attributes loaded.
    """
    return [
        lazyload(f"generator_runs.{column}") for column in GR_PARAMS_METRICS_COLS
    ]


def get_query_options_to_defer_large_model_cols() -> list[strategy_options.Load]:
    """Build query options that defer the model-state-related columns of
    generator runs.

    These columns can be large and are not needed on every generator run when
    loading an experiment and generation strategy in reduced state.
    """
    options: list[strategy_options.Load] = []
    for attribute in GR_LARGE_MODEL_ATTRS:
        options.append(defaultload("generator_runs").defer(attribute.key))
    return options


# ----------------------------------------------------------------------------
# /ax/utils/common/timeutils.py:
# ----------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from collections.abc import Generator
from datetime import datetime, timedelta
from time import time

import pandas as pd


DS_FRMT = "%Y-%m-%d"  # Format to use for parsing DS strings.


def to_ds(ts: datetime) -> str:
    """Convert a `datetime` to a DS string (``YYYY-MM-DD``)."""
    return datetime.strftime(ts, DS_FRMT)


def to_ts(ds: str) -> datetime:
    """Convert a DS string (``YYYY-MM-DD``) to a `datetime`."""
    return datetime.strptime(ds, DS_FRMT)


def _ts_to_pandas(ts: int) -> pd.Timestamp:
    """Convert int timestamp into pandas timestamp."""
    return pd.Timestamp(datetime.fromtimestamp(ts))


def _pandas_ts_to_int(ts: pd.Timestamp) -> int:
    """Convert a pandas timestamp into a Unix timestamp (seconds since epoch).

    Naive timestamps are interpreted in local time by ``datetime.timestamp``.
    NOTE: despite the annotation, the value returned is a float.
    """
    # pyre-fixme[7]: Expected `int` but got `float`.
    return ts.to_pydatetime().timestamp()


def current_timestamp_in_millis() -> int:
    """Grab current timestamp in milliseconds as an int."""
    return int(round(time() * 1000))


def timestamps_in_range(
    start: datetime, end: datetime, delta: timedelta
) -> Generator[datetime, None, None]:
    """Generator of timestamps in range [start, end], at intervals
    delta.

    Raises:
        ValueError: If ``delta`` is not positive. The loop would otherwise
            never advance past ``end``. Because this is a generator, the error
            surfaces on the first iteration.
    """
    if delta <= timedelta(0):
        raise ValueError(f"`delta` must be a positive timedelta, got {delta}.")
    curr = start
    while curr <= end:
        yield curr
        curr += delta


def unixtime_to_pandas_ts(ts: float) -> pd.Timestamp:
    """Convert float unixtime into pandas timestamp (UTC)."""
    return pd.to_datetime(ts, unit="s")


# ----------------------------------------------------------------------------
# /ax/utils/testing/metrics/backend_simulator_map.py:
# ----------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any

from ax.core.base_trial import BaseTrial
from ax.core.map_metric import MapMetricFetchResult
from ax.metrics.noisy_function_map import NoisyFunctionMapMetric


class BackendSimulatorTimestampMapMetric(NoisyFunctionMapMetric):
    """Metric that reads trial timing from the experiment runner's
    ``BackendSimulator`` and reports map data keyed by ``timestamp``."""

    def fetch_trial_data(
        self, trial: BaseTrial, noisy: bool = True, **kwargs: Any
    ) -> MapMetricFetchResult:
        """Fetch data for one trial."""
        simulator = trial.experiment.runner.simulator  # pyre-ignore[16]
        sim_trial = simulator.get_sim_trial_by_index(trial.index)
        completed_at = sim_trial.sim_completed_time
        # Trials without a completion time are observed up to the simulator's
        # current time.
        end_time = simulator.time if completed_at is None else completed_at
        timestamps = self.convert_to_timestamps(
            start_time=sim_trial.sim_start_time, end_time=end_time
        )
        return NoisyFunctionMapMetric.fetch_trial_data(
            self,
            trial=trial,
            noisy=noisy,
            **kwargs,
            map_keys=["timestamp"],
            timestamp=timestamps,
        )

    def convert_to_timestamps(self, start_time: float, end_time: float) -> list[float]:
        """Return the intermediate timestamps between ``start_time`` and
        ``end_time`` at which observations exist. Subclasses must implement
        this; the base version always raises ``NotImplementedError``."""
        raise NotImplementedError


# ----------------------------------------------------------------------------
# /ax/service/tests/test_early_stopping.py:
# ----------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


from ax.service.utils import early_stopping as early_stopping_utils
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
    DummyEarlyStoppingStrategy,
    get_branin_experiment,
)


class TestEarlyStoppingUtils(TestCase):
    """Covers early-stopping utility behavior that the main `AxClient` test
    suite (`TestServiceAPI`) does not exercise."""

    def setUp(self) -> None:
        super().setUp()
        self.branin_experiment = get_branin_experiment()

    def test_should_stop_trials_early(self) -> None:
        expected: dict[int, str | None] = {
            index: "Stopped due to testing." for index in (1, 3)
        }
        trial_indices = [1, 2, 3]
        actual = early_stopping_utils.should_stop_trials_early(
            early_stopping_strategy=DummyEarlyStoppingStrategy(expected),
            # pyre-fixme[6]: For 2nd param expected `Set[int]` but got `List[int]`.
            trial_indices=trial_indices,
            experiment=self.branin_experiment,
        )
        self.assertEqual(actual, expected)

    def test_should_stop_trials_early_no_strategy(self) -> None:
        trial_indices = [1, 2, 3]
        actual = early_stopping_utils.should_stop_trials_early(
            early_stopping_strategy=None,
            # pyre-fixme[6]: For 2nd param expected `Set[int]` but got `List[int]`.
            trial_indices=trial_indices,
            experiment=self.branin_experiment,
        )
        self.assertEqual(actual, {})


# ----------------------------------------------------------------------------
# /ax/utils/common/docutils.py:
# ----------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""Support functions for Sphinx et al."""

from collections.abc import Callable
from typing import Any, TypeVar


_T = TypeVar("_T")


# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
# pyre-ignore[34]: T77127616
def copy_doc(src: Callable[..., Any]) -> Callable[[_T], _T]:
    """Decorator factory that copies the docstring of ``src`` onto the
    decorated object.

    ``sphinx`` imports the Python modules to read docstrings, so a copied
    docstring shows up both in ``sphinx`` output and in ``help``.

    Raises:
        ValueError: If ``src`` has no docstring, or (when the returned
            decorator is applied) if the target already has one.

    .. code:: python

        class Cat(Mammal):

            @property
            @copy_doc(Mammal.is_feline)
            def is_feline(self) -> bool:
                ...
    """
    # Fetching the doc via the class a method is bound to (``__self__``) is
    # not possible: decorators run before ``__self__`` is assigned. A
    # class-level decorator that fills every missing docstring was also
    # avoided so that copied docstrings remain syntactically visible.
    if src.__doc__ is None:
        raise ValueError(f"{src.__qualname__} has no docstring to copy")

    def _decorate(dst: _T) -> _T:
        if dst.__doc__ is not None:
            # pyre-fixme[16]: `_T` has no attribute `__qualname__`.
            raise ValueError(f"{dst.__qualname__} already has a docstring")
        dst.__doc__ = src.__doc__
        return dst

    return _decorate


# ----------------------------------------------------------------------------
# /ax/exceptions/data_provider.py:
# ----------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
6 | 7 | # pyre-strict 8 | 9 | from collections.abc import Iterable 10 | from typing import Any 11 | 12 | 13 | class DataProviderError(Exception): 14 | """Base Exception for Ax DataProviders. 15 | 16 | The type of the data provider must be included. 17 | The raw error is stored in the data_provider_error section, 18 | and an Ax-friendly message is stored as the actual error message. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | message: str, 24 | data_provider: str, 25 | # pyre-fixme[2]: Parameter annotation cannot be `Any`. 26 | data_provider_error: Any, 27 | ) -> None: 28 | self.message = message 29 | self.data_provider = data_provider 30 | self.data_provider_error = data_provider_error 31 | 32 | def __str__(self) -> str: 33 | return ( 34 | "{message}. \n Error thrown by: {dp} data provider \n" 35 | + "Native {dp} data provider error: {dp_error}" 36 | ).format( 37 | dp=self.data_provider, 38 | message=self.message, 39 | dp_error=self.data_provider_error, 40 | ) 41 | 42 | 43 | class MissingDataError(Exception): 44 | def __init__(self, missing_trial_indexes: Iterable[int]) -> None: 45 | missing_trial_str = ", ".join([str(index) for index in missing_trial_indexes]) 46 | self.message: str = ( 47 | f"Unable to find data for the following trials: {missing_trial_str} " 48 | "consider updating the data fetching kwargs or manually fetching " 49 | "data via `refetch_data()`" 50 | ) 51 | 52 | def __str__(self) -> str: 53 | return self.message 54 | -------------------------------------------------------------------------------- /website/static/img/ax.svg: -------------------------------------------------------------------------------- 1 | 01_FullColor -------------------------------------------------------------------------------- /ax/models/tests/test_discrete.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import numpy as np 10 | from ax.models.discrete_base import DiscreteModel 11 | from ax.utils.common.testutils import TestCase 12 | 13 | 14 | class DiscreteModelTest(TestCase): 15 | def test_discrete_model_get_state(self) -> None: 16 | discrete_model = DiscreteModel() 17 | self.assertEqual(discrete_model._get_state(), {}) 18 | 19 | def test_discrete_model_feature_importances(self) -> None: 20 | discrete_model = DiscreteModel() 21 | with self.assertRaises(NotImplementedError): 22 | discrete_model.feature_importances() 23 | 24 | def test_DiscreteModelFit(self) -> None: 25 | discrete_model = DiscreteModel() 26 | discrete_model.fit( 27 | Xs=[[[0]]], 28 | Ys=[[0]], 29 | Yvars=[[1]], 30 | parameter_values=[[0, 1]], 31 | outcome_names=[], 32 | ) 33 | 34 | def test_discreteModelPredict(self) -> None: 35 | discrete_model = DiscreteModel() 36 | with self.assertRaises(NotImplementedError): 37 | discrete_model.predict([[0]]) 38 | 39 | def test_discreteModelGen(self) -> None: 40 | discrete_model = DiscreteModel() 41 | with self.assertRaises(NotImplementedError): 42 | discrete_model.gen( 43 | n=1, parameter_values=[[0, 1]], objective_weights=np.array([1]) 44 | ) 45 | 46 | def test_discreteModelCrossValidate(self) -> None: 47 | discrete_model = DiscreteModel() 48 | with self.assertRaises(NotImplementedError): 49 | discrete_model.cross_validate( 50 | Xs_train=[[[0]]], Ys_train=[[1]], Yvars_train=[[1]], X_test=[[1]] 51 | ) 52 | -------------------------------------------------------------------------------- /ax/models/tests/test_randomforest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import torch 10 | from ax.core.search_space import SearchSpaceDigest 11 | from ax.models.torch.randomforest import RandomForest 12 | from ax.utils.common.testutils import TestCase 13 | from botorch.utils.datasets import SupervisedDataset 14 | 15 | 16 | class RandomForestTest(TestCase): 17 | def test_RFModel(self) -> None: 18 | datasets = [ 19 | SupervisedDataset( 20 | X=torch.rand(10, 2), 21 | Y=torch.rand(10, 1), 22 | Yvar=torch.rand(10, 1), 23 | feature_names=["x1", "x2"], 24 | outcome_names=[f"y{i}"], 25 | ) 26 | for i in range(2) 27 | ] 28 | 29 | m = RandomForest(num_trees=5) 30 | m.fit( 31 | datasets=datasets, 32 | search_space_digest=SearchSpaceDigest( 33 | feature_names=["x1", "x2"], 34 | # pyre-fixme[6]: For 2nd param expected `List[Tuple[Union[float, 35 | # int], Union[float, int]]]` but got `List[Tuple[int, int]]`. 36 | bounds=[(0, 1)] * 2, 37 | ), 38 | ) 39 | self.assertEqual(len(m.models), 2) 40 | # pyre-fixme[16]: `RandomForestRegressor` has no attribute `estimators_`. 41 | self.assertEqual(len(m.models[0].estimators_), 5) 42 | 43 | f, cov = m.predict(torch.rand(5, 2)) 44 | self.assertEqual(f.shape, torch.Size((5, 2))) 45 | self.assertEqual(cov.shape, torch.Size((5, 2, 2))) 46 | 47 | f, cov = m.cross_validate(datasets=datasets, X_test=torch.rand(3, 2)) 48 | self.assertEqual(f.shape, torch.Size((3, 2))) 49 | self.assertEqual(cov.shape, torch.Size((3, 2, 2))) 50 | -------------------------------------------------------------------------------- /ax/models/discrete/eb_thompson.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | import logging 10 | 11 | import numpy as np 12 | from ax.models.discrete.thompson import ThompsonSampler 13 | from ax.utils.common.logger import get_logger 14 | from ax.utils.stats.statstools import positive_part_james_stein 15 | 16 | 17 | logger: logging.Logger = get_logger(__name__) 18 | 19 | 20 | class EmpiricalBayesThompsonSampler(ThompsonSampler): 21 | """Generator for Thompson sampling using Empirical Bayes estimates. 22 | 23 | The generator applies positive-part James-Stein Estimator to the data 24 | passed in via `fit` and then performs Thompson Sampling. 25 | """ 26 | 27 | def _fit_Ys_and_Yvars( 28 | self, Ys: list[list[float]], Yvars: list[list[float]], outcome_names: list[str] 29 | ) -> tuple[list[list[float]], list[list[float]]]: 30 | newYs = [] 31 | newYvars = [] 32 | for i, (Y, Yvar) in enumerate(zip(Ys, Yvars)): 33 | newY, newYvar = self._apply_shrinkage(Y, Yvar, i) 34 | newYs.append(newY) 35 | newYvars.append(newYvar) 36 | return newYs, newYvars 37 | 38 | def _apply_shrinkage( 39 | self, Y: list[float], Yvar: list[float], outcome: int 40 | ) -> tuple[list[float], list[float]]: 41 | npY = np.array(Y) 42 | npYvar = np.array(Yvar) 43 | npYsem = np.sqrt(Yvar) 44 | try: 45 | npY, npYsem = positive_part_james_stein(means=npY, sems=npYsem) 46 | except ValueError as e: 47 | logger.warning( 48 | str(e) + f" Raw (unshrunk) estimates used for outcome: {outcome}" 49 | ) 50 | Y = npY.tolist() 51 | npYvar = npYsem**2 52 | Yvar = npYvar.tolist() 53 | return Y, Yvar 54 | -------------------------------------------------------------------------------- /ax/runners/tests/test_single_running_trial_mixin.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | 10 | from ax.core.base_trial import TrialStatus 11 | from ax.runners.single_running_trial_mixin import SingleRunningTrialMixin 12 | from ax.runners.synthetic import SyntheticRunner 13 | from ax.utils.common.testutils import TestCase 14 | from ax.utils.testing.core_stubs import get_branin_experiment 15 | 16 | 17 | class SyntheticRunnerWithSingleRunningTrial(SingleRunningTrialMixin, SyntheticRunner): 18 | pass 19 | 20 | 21 | class SingleRunningTrialMixinTest(TestCase): 22 | def test_single_running_trial_mixin(self) -> None: 23 | runner = SyntheticRunnerWithSingleRunningTrial() 24 | exp = get_branin_experiment(with_trial=True, with_batch=True) 25 | exp.runner = runner 26 | trials = exp.trials.values() 27 | for trial in trials: 28 | trial.assign_runner() 29 | trial.run() 30 | trial_statuses = runner.poll_trial_status(trials=trials) 31 | self.assertEqual(trial_statuses[TrialStatus.COMPLETED], {0}) 32 | self.assertEqual(trial_statuses[TrialStatus.RUNNING], {1}) 33 | 34 | def test_no_trials(self) -> None: 35 | runner = SyntheticRunnerWithSingleRunningTrial() 36 | trial_statuses = runner.poll_trial_status(trials=[]) 37 | self.assertEqual(trial_statuses, {}) 38 | 39 | def test_abandoned_trial(self) -> None: 40 | runner = SyntheticRunnerWithSingleRunningTrial() 41 | exp = get_branin_experiment(with_trial=True) 42 | exp.runner = runner 43 | trial = exp.trials[0] 44 | trial.assign_runner() 45 | trial.mark_abandoned() 46 | trial_statuses = runner.poll_trial_status(trials=[trial]) 47 | self.assertEqual(trial_statuses, {}) 48 | -------------------------------------------------------------------------------- /ax/storage/json_store/save.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | import json 10 | from collections.abc import Callable 11 | from typing import Any 12 | 13 | from ax.core.experiment import Experiment 14 | from ax.storage.json_store.encoder import object_to_json 15 | from ax.storage.json_store.registry import ( 16 | CORE_CLASS_ENCODER_REGISTRY, 17 | CORE_ENCODER_REGISTRY, 18 | ) 19 | 20 | 21 | def save_experiment( 22 | experiment: Experiment, 23 | filepath: str, 24 | # pyre-fixme[2]: Parameter annotation cannot contain `Any`. 25 | # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use 26 | # `typing.Type` to avoid runtime subscripting errors. 27 | encoder_registry: dict[ 28 | type, Callable[[Any], dict[str, Any]] 29 | ] = CORE_ENCODER_REGISTRY, 30 | # pyre-fixme[2]: Parameter annotation cannot contain `Any`. 31 | # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use 32 | # `typing.Type` to avoid runtime subscripting errors. 33 | class_encoder_registry: dict[ 34 | type, Callable[[Any], dict[str, Any]] 35 | ] = CORE_CLASS_ENCODER_REGISTRY, 36 | ) -> None: 37 | """Save experiment to file. 38 | 39 | 1) Convert Ax experiment to JSON-serializable dictionary. 40 | 2) Write to file. 41 | """ 42 | if not isinstance(experiment, Experiment): 43 | raise ValueError("Can only save instances of Experiment") 44 | 45 | if not filepath.endswith(".json"): 46 | raise ValueError("Filepath must end in .json") 47 | 48 | json_experiment = object_to_json( 49 | experiment, 50 | encoder_registry=encoder_registry, 51 | class_encoder_registry=class_encoder_registry, 52 | ) 53 | with open(filepath, "w+") as file: 54 | file.write(json.dumps(json_experiment)) 55 | -------------------------------------------------------------------------------- /ax/analysis/healthcheck/should_generate_candidates.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-unsafe 7 | 8 | import json 9 | 10 | import pandas as pd 11 | from ax.analysis.analysis import AnalysisCardLevel 12 | 13 | from ax.analysis.healthcheck.healthcheck_analysis import ( 14 | HealthcheckAnalysis, 15 | HealthcheckAnalysisCard, 16 | HealthcheckStatus, 17 | ) 18 | from ax.core.experiment import Experiment 19 | from ax.core.generation_strategy_interface import GenerationStrategyInterface 20 | 21 | 22 | class ShouldGenerateCandidates(HealthcheckAnalysis): 23 | def __init__( 24 | self, 25 | should_generate: bool, 26 | reason: str, 27 | trial_index: int, 28 | ) -> None: 29 | self.should_generate = should_generate 30 | self.reason = reason 31 | self.trial_index = trial_index 32 | 33 | def compute( 34 | self, 35 | experiment: Experiment | None = None, 36 | generation_strategy: GenerationStrategyInterface | None = None, 37 | ) -> HealthcheckAnalysisCard: 38 | status = ( 39 | HealthcheckStatus.PASS 40 | if self.should_generate 41 | else HealthcheckStatus.WARNING 42 | ) 43 | return HealthcheckAnalysisCard( 44 | name=self.name, 45 | title=f"Ready to Generate Candidates for Trial {self.trial_index}", 46 | blob=json.dumps( 47 | { 48 | "status": status, 49 | } 50 | ), 51 | subtitle=self.reason, 52 | df=pd.DataFrame( 53 | { 54 | "status": [status], 55 | "reason": [self.reason], 56 | } 57 | ), 58 | level=AnalysisCardLevel.CRITICAL, 59 | attributes=self.attributes, 60 | ) 61 | -------------------------------------------------------------------------------- /ax/models/random/uniform.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | 10 | import numpy as np 11 | import numpy.typing as npt 12 | from ax.models.random.base import RandomModel 13 | 14 | 15 | class UniformGenerator(RandomModel): 16 | """This class specifies a uniform random generation algorithm. 17 | 18 | As a uniform generator does not make use of a model, it does not implement 19 | the fit or predict methods. 20 | 21 | See base `RandomModel` for a description of model attributes. 22 | """ 23 | 24 | def __init__( 25 | self, 26 | deduplicate: bool = True, 27 | seed: int | None = None, 28 | init_position: int = 0, 29 | generated_points: npt.NDArray | None = None, 30 | fallback_to_sample_polytope: bool = False, 31 | ) -> None: 32 | super().__init__( 33 | deduplicate=deduplicate, 34 | seed=seed, 35 | init_position=init_position, 36 | generated_points=generated_points, 37 | fallback_to_sample_polytope=fallback_to_sample_polytope, 38 | ) 39 | self._rs = np.random.RandomState(seed=self.seed) 40 | if self.init_position > 0: 41 | # Fast-forward the random state by generating & discarding samples. 42 | self._rs.uniform(size=(self.init_position)) 43 | 44 | def _gen_samples(self, n: int, tunable_d: int) -> npt.NDArray: 45 | """Generate samples from the scipy uniform distribution. 46 | 47 | Args: 48 | n: Number of samples to generate. 49 | tunable_d: Dimension of samples to generate. 50 | 51 | Returns: 52 | samples: An (n x d) array of random points. 53 | 54 | """ 55 | self.init_position += n * tunable_d 56 | return self._rs.uniform(size=(n, tunable_d)) 57 | -------------------------------------------------------------------------------- /ax/service/utils/best_point_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | from ax.core.experiment import Experiment 10 | from pyre_extensions import none_throws 11 | 12 | BASELINE_ARM_NAME = "baseline_arm" 13 | 14 | 15 | def select_baseline_name_default_first_trial( 16 | experiment: Experiment, baseline_arm_name: str | None 17 | ) -> tuple[str, bool]: 18 | """ 19 | Choose a baseline arm from arms on the experiment. Logic: 20 | 1. If ``baseline_arm_name`` provided, validate that arm exists 21 | and return that arm name. 22 | 2. If ``experiment.status_quo`` is set, return its arm name. 23 | 3. If there is at least one trial on the experiment, use the 24 | first trial's first arm as the baseline. 25 | 4. Error if 1-3 all don't apply. 26 | 27 | Returns: 28 | Tuple: 29 | baseline arm name (str) 30 | true when baseline selected from first arm of experiment (bool) 31 | raise ValueError if no valid baseline found 32 | """ 33 | 34 | arms_dict = experiment.arms_by_name 35 | 36 | if baseline_arm_name: 37 | if baseline_arm_name not in arms_dict: 38 | raise ValueError(f"Arm by name {baseline_arm_name=} not found.") 39 | return baseline_arm_name, False 40 | 41 | if experiment.status_quo and none_throws(experiment.status_quo).name in arms_dict: 42 | baseline_arm_name = none_throws(experiment.status_quo).name 43 | return baseline_arm_name, False 44 | 45 | if ( 46 | experiment.trials 47 | and experiment.trials[0].arms 48 | and experiment.trials[0].arms[0].name in arms_dict 49 | ): 50 | baseline_arm_name = experiment.trials[0].arms[0].name 51 | return baseline_arm_name, True 52 | 53 | else: 54 | raise ValueError("Could not find valid baseline arm.") 55 | -------------------------------------------------------------------------------- /ax/preview/api/protocols/runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | 8 | 9 | from typing import Any, Mapping 10 | 11 | from ax.core.base_trial import TrialStatus 12 | from ax.preview.api.protocols.utils import _APIRunner 13 | from ax.preview.api.types import TParameterization 14 | from pyre_extensions import override 15 | 16 | 17 | class IRunner(_APIRunner): 18 | @override 19 | def run_trial( 20 | self, trial_index: int, parameterization: TParameterization 21 | ) -> dict[str, Any]: 22 | """ 23 | Given an index and parameterization, run a trial and return a dictionary of any 24 | appropriate metadata. This metadata will be used to identify the trial when 25 | polling its status, stopping, fetching data, etc. This may hold information 26 | such as the trial's unique identifier on the system its running on, a 27 | directory where the trial is logging results to, etc. 28 | 29 | The metadata MUST be JSON-serializable (i.e. dict, list, str, int, float, bool, 30 | or None) so that Trials may be properly serialized in Ax. 31 | """ 32 | ... 33 | 34 | @override 35 | def poll_trial( 36 | self, trial_index: int, trial_metadata: Mapping[str, Any] 37 | ) -> TrialStatus: 38 | """ 39 | Given trial index and metadata, poll the status of the trial. 40 | """ 41 | ... 42 | 43 | @override 44 | def stop_trial( 45 | self, trial_index: int, trial_metadata: Mapping[str, Any] 46 | ) -> dict[str, Any]: 47 | """ 48 | Given trial index and metadata, stop the trial. Returns a dictionary of any 49 | appropriate metadata. 50 | 51 | The metadata MUST be JSON-serializable (i.e. dict, list, str, int, float, bool, 52 | or None) so that Trials may be properly serialized in Ax. 53 | """ 54 | ... 
55 | -------------------------------------------------------------------------------- /ax/core/tests/test_map_metric.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from ax.core.map_metric import MapMetric 10 | from ax.utils.common.testutils import TestCase 11 | from ax.utils.testing.core_stubs import get_map_data 12 | 13 | 14 | METRIC_STRING = "MapMetric('m1')" 15 | 16 | 17 | class MapMetricTest(TestCase): 18 | def test_Init(self) -> None: 19 | metric = MapMetric(name="m1", lower_is_better=False) 20 | self.assertEqual(str(metric), METRIC_STRING) 21 | 22 | def test_Eq(self) -> None: 23 | metric1 = MapMetric(name="m1", lower_is_better=False) 24 | metric2 = MapMetric(name="m1", lower_is_better=False) 25 | self.assertEqual(metric1, metric2) 26 | 27 | metric3 = MapMetric(name="m1", lower_is_better=True) 28 | self.assertNotEqual(metric1, metric3) 29 | 30 | def test_Clone(self) -> None: 31 | metric1 = MapMetric(name="m1", lower_is_better=False) 32 | self.assertEqual(metric1, metric1.clone()) 33 | 34 | def test_Sortable(self) -> None: 35 | metric1 = MapMetric(name="m1", lower_is_better=False) 36 | metric2 = MapMetric(name="m2", lower_is_better=False) 37 | self.assertTrue(metric1 < metric2) 38 | 39 | def test_WrapUnwrap(self) -> None: 40 | data = get_map_data() 41 | 42 | trial_multi = MapMetric._unwrap_trial_data_multi( 43 | results=MapMetric._wrap_trial_data_multi(data=data) 44 | ) 45 | self.assertEqual(trial_multi, data) 46 | 47 | experiment = MapMetric._unwrap_experiment_data( 48 | results=MapMetric._wrap_experiment_data(data=data) 49 | ) 50 | self.assertEqual(experiment, data) 51 | 52 | experiment_multi = MapMetric._unwrap_experiment_data_multi( 53 | 
results=MapMetric._wrap_experiment_data_multi(data=data) 54 | ) 55 | self.assertEqual(experiment_multi, data) 56 | -------------------------------------------------------------------------------- /sphinx/source/metrics.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | ax.metrics 5 | ========== 6 | 7 | .. automodule:: ax.metrics 8 | .. currentmodule:: ax.metrics 9 | 10 | 11 | Branin 12 | ~~~~~~ 13 | 14 | .. automodule:: ax.metrics.branin 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | 19 | Branin Map 20 | ~~~~~~~~~~ 21 | 22 | .. automodule:: ax.metrics.branin_map 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | 27 | Chemistry 28 | ~~~~~~~~~ 29 | 30 | .. automodule:: ax.metrics.chemistry 31 | :members: 32 | :undoc-members: 33 | :show-inheritance: 34 | 35 | Curve 36 | ~~~~~ 37 | 38 | .. automodule:: ax.metrics.curve 39 | :members: 40 | :undoc-members: 41 | :show-inheritance: 42 | 43 | Factorial 44 | ~~~~~~~~~ 45 | 46 | .. automodule:: ax.metrics.factorial 47 | :members: 48 | :undoc-members: 49 | :show-inheritance: 50 | 51 | Hartmann6 52 | ~~~~~~~~~ 53 | 54 | .. automodule:: ax.metrics.hartmann6 55 | :members: 56 | :undoc-members: 57 | :show-inheritance: 58 | 59 | L2 Norm 60 | ~~~~~~~ 61 | 62 | .. automodule:: ax.metrics.l2norm 63 | :members: 64 | :undoc-members: 65 | :show-inheritance: 66 | 67 | Noisy Functions 68 | ~~~~~~~~~~~~~~~ 69 | 70 | .. automodule:: ax.metrics.noisy_function 71 | :members: 72 | :undoc-members: 73 | :show-inheritance: 74 | 75 | Noisy Function Map 76 | ~~~~~~~~~~~~~~~~~~ 77 | 78 | .. automodule:: ax.metrics.noisy_function_map 79 | :members: 80 | :undoc-members: 81 | :show-inheritance: 82 | 83 | Sklearn 84 | ~~~~~~~ 85 | 86 | .. automodule:: ax.metrics.sklearn 87 | :members: 88 | :undoc-members: 89 | :show-inheritance: 90 | 91 | Tensorboard 92 | ~~~~~~~~~~~ 93 | 94 | .. 
automodule:: ax.metrics.tensorboard 95 | :members: 96 | :undoc-members: 97 | :show-inheritance: 98 | 99 | 100 | TorchX 101 | ~~~~~~ 102 | 103 | .. automodule:: ax.metrics.torchx 104 | :members: 105 | :undoc-members: 106 | :show-inheritance: 107 | -------------------------------------------------------------------------------- /ax/analysis/plotly/plotly_analysis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | 8 | 9 | import pandas as pd 10 | from ax.analysis.analysis import Analysis, AnalysisCard 11 | from ax.core.experiment import Experiment 12 | from ax.core.generation_strategy_interface import GenerationStrategyInterface 13 | from IPython.display import display, Markdown 14 | from plotly import graph_objects as go, io as pio 15 | 16 | 17 | class PlotlyAnalysisCard(AnalysisCard): 18 | blob_annotation = "plotly" 19 | 20 | def get_figure(self) -> go.Figure: 21 | return pio.from_json(self.blob) 22 | 23 | def _ipython_display_(self) -> None: 24 | """ 25 | IPython display hook. This is called when the AnalysisCard is printed in an 26 | IPython environment (ex. Jupyter). Here we want to display the Plotly figure. 27 | """ 28 | display(Markdown(f"## {self.title}\n\n### {self.subtitle}")) 29 | display(self.get_figure()) 30 | 31 | 32 | class PlotlyAnalysis(Analysis): 33 | """ 34 | An Analysis that computes a Plotly figure. 35 | """ 36 | 37 | def compute( 38 | self, 39 | experiment: Experiment | None = None, 40 | generation_strategy: GenerationStrategyInterface | None = None, 41 | ) -> PlotlyAnalysisCard: ... 
42 | 43 | def _create_plotly_analysis_card( 44 | self, 45 | title: str, 46 | subtitle: str, 47 | level: int, 48 | df: pd.DataFrame, 49 | fig: go.Figure, 50 | ) -> PlotlyAnalysisCard: 51 | """ 52 | Make a PlotlyAnalysisCard from this Analysis using provided fields and 53 | details about the Analysis class. 54 | """ 55 | return PlotlyAnalysisCard( 56 | name=self.name, 57 | attributes=self.attributes, 58 | title=title, 59 | subtitle=subtitle, 60 | level=level, 61 | df=df, 62 | blob=pio.to_json(fig), 63 | ) 64 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/1_bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: File a bug report. 3 | labels: ["bug"] 4 | title: "[Bug]: " 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thank you for taking the time to fill out a bug report. We strive to make Ax a useful and stable library for all our users. 10 | - type: textarea 11 | id: what-happened 12 | attributes: 13 | label: What happened? 14 | description: Provide a detailed description of the bug as well as the expected behavior. 15 | validations: 16 | required: true 17 | - type: textarea 18 | id: repro 19 | attributes: 20 | label: Please provide a minimal, reproducible example of the unexpected behavior. 21 | description: Follow [these guidelines](https://stackoverflow.com/help/minimal-reproducible-example) for writing your example. 22 | validations: 23 | required: true 24 | - type: textarea 25 | id: traceback 26 | attributes: 27 | label: Please paste any relevant traceback/logs produced by the example provided. 28 | description: This will be automatically formatted into code, so no need for backticks. 29 | render: shell 30 | - type: input 31 | id: ax-version 32 | attributes: 33 | label: Ax Version 34 | description: What version of Ax are you using? 
35 | validations: 36 | required: true 37 | - type: input 38 | id: python-version 39 | attributes: 40 | label: Python Version 41 | description: What version of Python are you using? 42 | validations: 43 | required: true 44 | - type: input 45 | id: os 46 | attributes: 47 | label: Operating System 48 | description: What operating system are you using? 49 | validations: 50 | required: true 51 | - type: checkboxes 52 | id: terms 53 | attributes: 54 | label: Code of Conduct 55 | description: By submitting this issue, you agree to follow Ax's [Code of Conduct](https://github.com/facebook/Ax/blob/main/CODE_OF_CONDUCT.md). 56 | options: 57 | - label: I agree to follow Ax's Code of Conduct 58 | required: true 59 | -------------------------------------------------------------------------------- /ax/service/utils/early_stopping.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | 8 | 9 | from ax.core.experiment import Experiment 10 | from ax.early_stopping.strategies import BaseEarlyStoppingStrategy 11 | from pyre_extensions import none_throws 12 | 13 | 14 | def should_stop_trials_early( 15 | early_stopping_strategy: BaseEarlyStoppingStrategy | None, 16 | trial_indices: set[int], 17 | experiment: Experiment, 18 | ) -> dict[int, str | None]: 19 | """Evaluate whether to early-stop running trials. 20 | 21 | Args: 22 | early_stopping_strategy: A ``BaseEarlyStoppingStrategy`` that determines 23 | whether a trial should be stopped given the state of an experiment. 24 | trial_indices: Indices of trials to consider for early stopping. 25 | experiment: The experiment containing the trials. 26 | 27 | Returns: 28 | A dictionary mapping trial indices that should be early stopped to 29 | (optional) messages with the associated reason. 
30 | """ 31 | if early_stopping_strategy is None: 32 | return {} 33 | 34 | early_stopping_strategy = none_throws(early_stopping_strategy) 35 | return early_stopping_strategy.should_stop_trials_early( 36 | trial_indices=trial_indices, experiment=experiment 37 | ) 38 | 39 | 40 | def get_early_stopping_metrics( 41 | experiment: Experiment, early_stopping_strategy: BaseEarlyStoppingStrategy | None 42 | ) -> list[str]: 43 | """A helper function that returns a list of metric names on which a given 44 | `early_stopping_strategy` is operating.""" 45 | if early_stopping_strategy is None: 46 | return [] 47 | if early_stopping_strategy.metric_names is not None: 48 | return list(early_stopping_strategy.metric_names) 49 | # TODO: generalize this to multi-objective ess 50 | default_objective, _ = early_stopping_strategy._default_objective_and_direction( 51 | experiment=experiment 52 | ) 53 | return [default_objective] 54 | -------------------------------------------------------------------------------- /sphinx/source/preview.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | ax.preview 5 | ========== 6 | 7 | .. automodule:: ax.preview 8 | .. currentmodule:: ax.preview 9 | 10 | A preview of future Ax API 11 | -------------------------- 12 | 13 | 14 | IMetric 15 | ~~~~~~~ 16 | 17 | .. automodule:: ax.preview.api.protocols.metric 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | IRunner 23 | ~~~~~~~ 24 | 25 | .. automodule:: ax.preview.api.protocols.runner 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | 30 | 31 | Utils 32 | ~~~~~~~ 33 | 34 | .. automodule:: ax.preview.api.protocols.utils 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | 40 | Client 41 | ~~~~~~ 42 | 43 | .. automodule:: ax.preview.api.client 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | 48 | 49 | Configs 50 | ~~~~~~~ 51 | 52 | .. 
automodule:: ax.preview.api.configs 53 | :members: 54 | :undoc-members: 55 | :show-inheritance: 56 | 57 | Types 58 | ~~~~~ 59 | 60 | .. automodule:: ax.preview.api.types 61 | :members: 62 | :undoc-members: 63 | :show-inheritance: 64 | 65 | From Config 66 | ~~~~~~~~~~~ 67 | 68 | .. automodule:: ax.preview.api.utils.instantiation.from_config 69 | :members: 70 | :undoc-members: 71 | :show-inheritance: 72 | 73 | From String 74 | ~~~~~~~~~~~ 75 | 76 | .. automodule:: ax.preview.api.utils.instantiation.from_string 77 | :members: 78 | :undoc-members: 79 | :show-inheritance: 80 | 81 | 82 | ModelBridge 83 | ~~~~~~~~~~~ 84 | 85 | .. automodule:: ax.preview.modelbridge 86 | :members: 87 | :undoc-members: 88 | :show-inheritance: 89 | 90 | Dispatch Utils 91 | ~~~~~~~~~~~~~~ 92 | 93 | .. automodule:: ax.preview.modelbridge.dispatch_utils 94 | :members: 95 | :undoc-members: 96 | :show-inheritance: 97 | 98 | Storage Utils 99 | ~~~~~~~~~~~~~ 100 | 101 | .. automodule:: ax.preview.api.utils.storage 102 | :members: 103 | :undoc-members: 104 | :show-inheritance: 105 | -------------------------------------------------------------------------------- /website/core/Footer.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | const React = require('react'); 9 | 10 | class Footer extends React.Component { 11 | 12 | docUrl(doc, language) { 13 | const baseUrl = this.props.config.baseUrl; 14 | const docsUrl = this.props.config.docsUrl; 15 | const docsPart = `${docsUrl ? `${docsUrl}/` : ''}`; 16 | const langPart = `${language ? `${language}/` : ''}`; 17 | return `${baseUrl}${docsPart}${langPart}${doc}`; 18 | } 19 | 20 | pageUrl(doc, language) { 21 | const baseUrl = this.props.config.baseUrl; 22 | return baseUrl + (language ? 
`${language}/` : '') + doc; 23 | } 24 | 25 | render() { 26 | const currentYear = new Date().getFullYear(); 27 | 28 | return ( 29 | 63 | ); 64 | } 65 | } 66 | 67 | module.exports = Footer; 68 | -------------------------------------------------------------------------------- /ax/core/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | # flake8: noqa F401 10 | from ax.core.arm import Arm 11 | from ax.core.batch_trial import BatchTrial 12 | from ax.core.data import Data 13 | from ax.core.experiment import Experiment 14 | from ax.core.generator_run import GeneratorRun 15 | from ax.core.metric import Metric 16 | from ax.core.objective import MultiObjective, Objective 17 | from ax.core.observation import ObservationFeatures 18 | from ax.core.optimization_config import ( 19 | MultiObjectiveOptimizationConfig, 20 | OptimizationConfig, 21 | ) 22 | from ax.core.outcome_constraint import ( 23 | ComparisonOp, 24 | ObjectiveThreshold, 25 | OutcomeConstraint, 26 | ) 27 | from ax.core.parameter import ( 28 | ChoiceParameter, 29 | FixedParameter, 30 | Parameter, 31 | ParameterType, 32 | RangeParameter, 33 | ) 34 | from ax.core.parameter_constraint import ( 35 | OrderConstraint, 36 | ParameterConstraint, 37 | SumConstraint, 38 | ) 39 | from ax.core.parameter_distribution import ParameterDistribution 40 | from ax.core.risk_measures import RiskMeasure 41 | from ax.core.runner import Runner 42 | from ax.core.search_space import SearchSpace 43 | from ax.core.trial import Trial 44 | 45 | 46 | __all__ = [ 47 | "Arm", 48 | "BatchTrial", 49 | "ChoiceParameter", 50 | "ComparisonOp", 51 | "Data", 52 | "Experiment", 53 | "FixedParameter", 54 | "GeneratorRun", 55 | "Metric", 56 | "MultiObjective", 57 | 
"MultiObjectiveOptimizationConfig", 58 | "Objective", 59 | "ObjectiveThreshold", 60 | "ObservationFeatures", 61 | "OptimizationConfig", 62 | "OrderConstraint", 63 | "OutcomeConstraint", 64 | "Parameter", 65 | "ParameterConstraint", 66 | "ParameterDistribution", 67 | "ParameterType", 68 | "RangeParameter", 69 | "RiskMeasure", 70 | "Runner", 71 | "SearchSpace", 72 | "SimpleExperiment", 73 | "SumConstraint", 74 | "Trial", 75 | ] 76 | -------------------------------------------------------------------------------- /ax/utils/common/tests/test_typeutils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | import numpy as np 11 | from ax.utils.common.testutils import TestCase 12 | from ax.utils.common.typeutils import ( 13 | checked_cast, 14 | checked_cast_dict, 15 | checked_cast_list, 16 | checked_cast_optional, 17 | ) 18 | from ax.utils.common.typeutils_nonnative import numpy_type_to_python_type 19 | 20 | 21 | class TestTypeUtils(TestCase): 22 | def test_checked_cast(self) -> None: 23 | self.assertEqual(checked_cast(float, 2.0), 2.0) 24 | with self.assertRaises(ValueError): 25 | checked_cast(float, 2) 26 | 27 | def test_checked_cast_with_error_override(self) -> None: 28 | self.assertEqual(checked_cast(float, 2.0), 2.0) 29 | with self.assertRaises(NotImplementedError): 30 | checked_cast( 31 | float, 2, exception=NotImplementedError("foo() doesn't support ints") 32 | ) 33 | 34 | def test_checked_cast_list(self) -> None: 35 | self.assertEqual(checked_cast_list(float, [1.0, 2.0]), [1.0, 2.0]) 36 | with self.assertRaises(ValueError): 37 | checked_cast_list(float, [1.0, 2]) 38 | 39 | def test_checked_cast_optional(self) -> None: 40 | self.assertEqual(checked_cast_optional(float, None), None) 
41 | with self.assertRaises(ValueError): 42 | checked_cast_optional(float, 2) 43 | 44 | def test_checked_cast_dict(self) -> None: 45 | self.assertEqual(checked_cast_dict(str, int, {"some": 1}), {"some": 1}) 46 | with self.assertRaises(ValueError): 47 | checked_cast_dict(str, int, {"some": 1.0}) 48 | with self.assertRaises(ValueError): 49 | checked_cast_dict(str, int, {1: 1}) 50 | 51 | def test_numpy_type_to_python_type(self) -> None: 52 | self.assertEqual(type(numpy_type_to_python_type(np.int64(2))), int) 53 | self.assertEqual(type(numpy_type_to_python_type(np.float64(2))), float) 54 | -------------------------------------------------------------------------------- /website/static/img/ax_lockup.svg: -------------------------------------------------------------------------------- 1 | 01_FullColor -------------------------------------------------------------------------------- /website/static/img/ax_logo_lockup.svg: -------------------------------------------------------------------------------- 1 | 01_FullColor -------------------------------------------------------------------------------- /website/static/img/ax_lockup_white.svg: -------------------------------------------------------------------------------- 1 | Ax_Identity_Lockup_white_font -------------------------------------------------------------------------------- /docs/why-ax.md: -------------------------------------------------------------------------------- 1 | --- 2 | id: why-ax 3 | title: Why Ax? 4 | sidebar_label: Why Ax? 5 | --- 6 | 7 | Developers and researchers alike face problems which confront them with a large space of possible ways to configure something –– whether those are "magic numbers" used for infrastructure or compiler flags, learning rates or other hyperparameters in machine learning, or images and calls-to-action used in marketing promotions. Selecting and tuning these configurations can often take time, resources, and can affect the quality of user experiences. 
Ax is a machine learning system to help automate this process, so that researchers and developers can get the most out of their software as efficiently as possible. 8 | 9 | Ax is a platform for optimizing any kind of experiment, including machine learning experiments, A/B tests, and simulations. Ax can optimize discrete configurations (e.g., variants of an A/B test) using multi-armed bandit optimization, and numeric (e.g., integer- or floating-point-valued) configurations using Bayesian optimization. This makes it suitable for a wide range of applications. 10 | 11 | Ax has been successfully applied to a variety of product, infrastructure, ML, and research applications at Facebook. 12 | 13 | # Unique capabilities 14 | - **Support for noisy functions**. Results of A/B tests and simulations with reinforcement learning agents often exhibit high amounts of noise. Ax supports [state-of-the-art algorithms](https://research.facebook.com/blog/2018/09/efficient-tuning-of-online-systems-using-bayesian-optimization/) that work better than traditional Bayesian optimization in high-noise settings. 15 | - **Customization**. Ax's developer API makes it easy to integrate custom data modeling and decision algorithms. This allows developers to build their own custom optimization services with minimal overhead. 16 | - **Multi-modal experimentation**. Ax has first-class support for running and combining data from different types of experiments, such as "offline" simulation data and "online" data from real-world experiments. 17 | - **Multi-objective optimization**. Ax supports multi-objective and constrained optimization, both of which are common in real-world problems, like improving load time without increasing data use. 18 | -------------------------------------------------------------------------------- /ax/analysis/plotly/surface/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc.
and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | import math 8 | 9 | import numpy as np 10 | from ax.core.parameter import ( 11 | ChoiceParameter, 12 | FixedParameter, 13 | Parameter, 14 | RangeParameter, 15 | TParamValue, 16 | ) 17 | 18 | 19 | def get_parameter_values(parameter: Parameter, density: int = 100) -> list[TParamValue]: 20 | """ 21 | Get a list of parameter values to predict over for a given parameter. 22 | """ 23 | 24 | # For RangeParameter use linspace for the range of the parameter 25 | if isinstance(parameter, RangeParameter): 26 | if parameter.log_scale: 27 | return np.logspace( 28 | math.log10(parameter.lower), math.log10(parameter.upper), density 29 | ).tolist() 30 | 31 | return np.linspace(parameter.lower, parameter.upper, density).tolist() 32 | 33 | # For ChoiceParameter use the values of the parameter directly 34 | if isinstance(parameter, ChoiceParameter) and parameter.is_ordered: 35 | return parameter.values 36 | 37 | raise ValueError( 38 | f"Parameter {parameter.name} must be a RangeParameter or " 39 | "ChoiceParameter with is_ordered=True to be used in surface plot." 40 | ) 41 | 42 | 43 | def select_fixed_value(parameter: Parameter) -> TParamValue: 44 | """ 45 | Select a fixed value for a parameter. Use mean for RangeParameter, "middle" value 46 | for ChoiceParameter, and value for FixedParameter. 47 | """ 48 | if isinstance(parameter, RangeParameter): 49 | return (parameter.lower * 1.0 + parameter.upper) / 2 50 | elif isinstance(parameter, ChoiceParameter): 51 | return parameter.values[len(parameter.values) // 2] 52 | elif isinstance(parameter, FixedParameter): 53 | return parameter.value 54 | else: 55 | raise ValueError(f"Got unexpected parameter type {parameter}.") 56 | 57 | 58 | def is_axis_log_scale(parameter: Parameter) -> bool: 59 | """ 60 | Check if the parameter is log scale. 
61 | """ 62 | return isinstance(parameter, RangeParameter) and parameter.log_scale 63 | -------------------------------------------------------------------------------- /ax/benchmark/problems/hd_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # pyre-strict 7 | 8 | import copy 9 | from typing import TypeVar 10 | 11 | from ax.benchmark.benchmark_problem import BenchmarkProblem 12 | from ax.core.parameter import ParameterType, RangeParameter 13 | from ax.core.search_space import SearchSpace 14 | 15 | TProblem = TypeVar("TProblem", bound=BenchmarkProblem) 16 | 17 | 18 | def embed_higher_dimension(problem: TProblem, total_dimensionality: int) -> TProblem: 19 | """ 20 | Return a new `BenchmarkProblem` with enough `RangeParameter`s added to the 21 | search space to make its total dimensionality equal to `total_dimensionality` 22 | and add `total_dimensionality` to its name. 23 | 24 | The search space of the original `problem` is within the search space of the 25 | new problem, and the constraints are copied from the original problem. 
26 | """ 27 | num_dummy_dimensions = total_dimensionality - len(problem.search_space.parameters) 28 | 29 | search_space = SearchSpace( 30 | parameters=[ 31 | *problem.search_space.parameters.values(), 32 | *[ 33 | RangeParameter( 34 | name=f"embedding_dummy_{i}", 35 | parameter_type=ParameterType.FLOAT, 36 | lower=0, 37 | upper=1, 38 | ) 39 | for i in range(num_dummy_dimensions) 40 | ], 41 | ], 42 | parameter_constraints=problem.search_space.parameter_constraints, 43 | ) 44 | 45 | # if problem name already has dimensionality in it, strip it 46 | def _is_dim_suffix(s: str) -> bool: 47 | return s[-1] == "d" and all(char in "0123456789" for char in s[:-1]) 48 | 49 | orig_name_without_dimensionality = "_".join( 50 | [substr for substr in problem.name.split("_") if not _is_dim_suffix(substr)] 51 | ) 52 | new_name = f"{orig_name_without_dimensionality}_{total_dimensionality}d" 53 | 54 | new_problem = copy.copy(problem) 55 | new_problem.name = new_name 56 | new_problem.search_space = search_space 57 | return new_problem 58 | -------------------------------------------------------------------------------- /ax/utils/common/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from __future__ import annotations 10 | 11 | import abc 12 | 13 | from ax.utils.common.equality import equality_typechecker, object_attribute_dicts_equal 14 | 15 | 16 | class Base: 17 | """Metaclass for core Ax classes. Provides an equality check and `db_id` 18 | property for SQA storage. 
19 | """ 20 | 21 | _db_id: int | None = None 22 | 23 | @property 24 | def db_id(self) -> int | None: 25 | return self._db_id 26 | 27 | @db_id.setter 28 | def db_id(self, db_id: int) -> None: 29 | self._db_id = db_id 30 | 31 | @equality_typechecker 32 | def __eq__(self, other: Base) -> bool: 33 | return object_attribute_dicts_equal( 34 | one_dict=self.__dict__, other_dict=other.__dict__ 35 | ) 36 | 37 | @equality_typechecker 38 | def _eq_skip_db_id_check(self, other: Base) -> bool: 39 | return object_attribute_dicts_equal( 40 | one_dict=self.__dict__, other_dict=other.__dict__, skip_db_id_check=True 41 | ) 42 | 43 | 44 | class SortableBase(Base, metaclass=abc.ABCMeta): 45 | """Extension to the base class that also provides an inequality check.""" 46 | 47 | @property 48 | @abc.abstractmethod 49 | def _unique_id(self) -> str: 50 | """Returns an identification string that can be used to uniquely 51 | identify this instance from others attached to the same parent 52 | object. For example, for ``Trials`` this can be their index, 53 | 54 | since that is unique w.r.t. to parent ``Experiment`` object. 55 | For ``GenerationNode``-s attached to a ``GenerationStrategy``, 56 | this can be their name since we ensure uniqueness of it upon 57 | ``GenerationStrategy`` instantiation. 58 | 59 | This method is needed to correctly update SQLAlchemy objects 60 | that appear as children of other objects, in lists or other 61 | sortable collections or containers. 
62 | """ 63 | pass 64 | 65 | def __lt__(self, other: SortableBase) -> bool: 66 | return self._unique_id < other._unique_id 67 | -------------------------------------------------------------------------------- /.github/workflows/deploy.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Deploy 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | tests-and-coverage-latest: 12 | name: Tests with latest BoTorch 13 | uses: ./.github/workflows/reusable_test.yml 14 | with: 15 | pinned_botorch: false 16 | secrets: inherit 17 | 18 | tests-and-coverage-pinned: 19 | name: Tests with pinned BoTorch 20 | uses: ./.github/workflows/reusable_test.yml 21 | with: 22 | pinned_botorch: true 23 | secrets: inherit 24 | 25 | publish-stable-website: 26 | 27 | needs: tests-and-coverage-pinned # only run if test step succeeds 28 | runs-on: ubuntu-latest 29 | 30 | steps: 31 | - uses: actions/checkout@v4 32 | - name: Set up Python 33 | uses: actions/setup-python@v5 34 | with: 35 | python-version: "3.10" 36 | - name: Install dependencies 37 | run: | 38 | # use stable Botorch 39 | pip install -e ".[tutorial]" 40 | - name: Publish latest website 41 | env: 42 | DOCUSAURUS_PUBLISH_TOKEN: ${{ secrets.DOCUSAURUS_PUBLISH_TOKEN }} 43 | run: | 44 | bash scripts/publish_site.sh -d -v ${{ github.event.release.tag_name }} 45 | 46 | deploy: 47 | 48 | needs: tests-and-coverage-pinned # only run if test step succeeds 49 | runs-on: ubuntu-latest 50 | 51 | steps: 52 | - uses: actions/checkout@v4 53 | - name: Set up Python 54 | uses: actions/setup-python@v5 55 | with: 56 | python-version: "3.10" 57 | - name: Install dependencies 58 | run: | 59 | # use stable Botorch 60 | pip install -e 
".[dev,mysql,notebook]" 61 | pip install --upgrade build setuptools setuptools_scm wheel 62 | - name: Fetch all history for all tags and branches 63 | run: git fetch --prune --unshallow 64 | - name: Build wheel 65 | run: | 66 | python -m build --sdist --wheel 67 | - name: Deploy to PyPI 68 | uses: pypa/gh-action-pypi-publish@release/v1 69 | with: 70 | user: __token__ 71 | password: ${{ secrets.PYPI_TOKEN }} 72 | verbose: true 73 | -------------------------------------------------------------------------------- /ax/plot/tests/test_diagnostic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import plotly.graph_objects as go 10 | from ax.modelbridge.cross_validation import cross_validate 11 | from ax.modelbridge.registry import Models 12 | from ax.plot.base import AxPlotConfig 13 | from ax.plot.diagnostic import ( 14 | interact_cross_validation, 15 | interact_cross_validation_plotly, 16 | ) 17 | from ax.utils.common.testutils import TestCase 18 | from ax.utils.testing.core_stubs import get_branin_experiment 19 | from ax.utils.testing.mock import mock_botorch_optimize 20 | 21 | 22 | class DiagnosticTest(TestCase): 23 | @mock_botorch_optimize 24 | def setUp(self) -> None: 25 | super().setUp() 26 | exp = get_branin_experiment(with_batch=True) 27 | exp.trials[0].run() 28 | self.model = Models.BOTORCH_MODULAR( 29 | # Model bridge kwargs 30 | experiment=exp, 31 | data=exp.fetch_data(), 32 | ) 33 | 34 | def test_cross_validation(self) -> None: 35 | for autoset_axis_limits in [False, True]: 36 | cv = cross_validate(self.model) 37 | # Assert that each type of plot can be constructed successfully 38 | label_dict = {"branin": "BrAnIn"} 39 | plot = interact_cross_validation_plotly( 40 | cv, 
label_dict=label_dict, autoset_axis_limits=autoset_axis_limits 41 | ) 42 | x_range = plot.layout.updatemenus[0].buttons[0].args[1]["xaxis.range"] 43 | y_range = plot.layout.updatemenus[0].buttons[0].args[1]["yaxis.range"] 44 | if autoset_axis_limits: 45 | self.assertTrue((len(x_range) == 2) and (x_range[0] < x_range[1])) 46 | self.assertTrue((len(y_range) == 2) and (y_range[0] < y_range[1])) 47 | else: 48 | self.assertIsNone(x_range) 49 | self.assertIsNone(y_range) 50 | 51 | self.assertIsInstance(plot, go.Figure) 52 | plot = interact_cross_validation( 53 | cv, label_dict=label_dict, autoset_axis_limits=autoset_axis_limits 54 | ) 55 | self.assertIsInstance(plot, AxPlotConfig) 56 | -------------------------------------------------------------------------------- /website/tutorials.json: -------------------------------------------------------------------------------- 1 | { 2 | "API Comparison": [ 3 | { 4 | "id": "gpei_hartmann_service", 5 | "title": "[RECOMMENDED] Service API" 6 | }, 7 | { 8 | "id": "gpei_hartmann_loop", 9 | "title": "Loop API" 10 | }, 11 | { 12 | "id": "gpei_hartmann_developer", 13 | "title": "Developer API" 14 | } 15 | ], 16 | "Deep Dives": [ 17 | { 18 | "id": "visualizations", 19 | "title": "Visualizations" 20 | }, 21 | { 22 | "id": "generation_strategy", 23 | "title": "Generation Strategy" 24 | }, 25 | { 26 | "id": "scheduler", 27 | "title": "Scheduler" 28 | }, 29 | { 30 | "id": "modular_botax", 31 | "title": "Modular `BoTorchModel`" 32 | } 33 | ], 34 | "Bayesian Optimization": [ 35 | { 36 | "id": "tune_cnn_service", 37 | "title": "Hyperparameter Optimization for PyTorch" 38 | }, 39 | { 40 | "id": "submitit", 41 | "title": "Hyperparameter Optimization on SLURM via SubmitIt" 42 | }, 43 | { 44 | "id": "multi_task", 45 | "title": "Multi-Task Modeling" 46 | }, 47 | { 48 | "id": "multiobjective_optimization", 49 | "title": "Multi-Objective Optimization" 50 | }, 51 | { 52 | "id": "saasbo", 53 | "title": "High-Dimensional Bayesian Optimization with Sparse 
Axis-Aligned Subspaces (SAASBO)" 54 | }, 55 | { 56 | "id": "saasbo_nehvi", 57 | "title": "Fully Bayesian, High-Dimensional, Multi-Objective Optimization" 58 | }, 59 | { 60 | "id": "sebo", 61 | "title": "Sparsity Exploration Bayesian Optimization (SEBO)" 62 | }, 63 | { 64 | "dir": "early_stopping", 65 | "id": "early_stopping", 66 | "title": "Trial-Level Early Stopping" 67 | }, 68 | { 69 | "id": "gss", 70 | "title": "Global Stopping (Experiment-Level Early Stopping)" 71 | } 72 | ], 73 | "Field Experiments": [ 74 | { 75 | "id": "factorial", 76 | "title": "Bandit Optimization" 77 | }, 78 | { 79 | "dir": "human_in_the_loop", 80 | "id": "human_in_the_loop", 81 | "title": "Human-in-the-Loop Optimization" 82 | } 83 | ], 84 | "Integrating External Strategies": [ 85 | { 86 | "id": "external_generation_node", 87 | "title": "RandomForest with ExternalGenerationNode" 88 | } 89 | ] 90 | } 91 | -------------------------------------------------------------------------------- /ax/storage/tests/test_botorch_modular_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # pyre-strict 7 | 8 | from ax.models.torch.botorch_modular.acquisition import Acquisition 9 | from ax.storage.botorch_modular_registry import ( 10 | ACQUISITION_FUNCTION_REGISTRY, 11 | ACQUISITION_REGISTRY, 12 | MODEL_REGISTRY, 13 | register_acquisition, 14 | register_acquisition_function, 15 | register_model, 16 | REVERSE_ACQUISITION_FUNCTION_REGISTRY, 17 | REVERSE_ACQUISITION_REGISTRY, 18 | REVERSE_MODEL_REGISTRY, 19 | ) 20 | from ax.utils.common.testutils import TestCase 21 | from botorch.acquisition.acquisition import AcquisitionFunction 22 | from botorch.models.model import Model 23 | 24 | # Trivial subclasses that exist only as registration targets for the tests below. 25 | class NewModel(Model): 26 | pass 27 | 28 | 29 | class NewAcquisition(Acquisition): 30 | pass 31 | 32 | 33 | class NewAcquisitionFunction(AcquisitionFunction): 34 | pass 35 | 36 | # NOTE(review): register_* mutates module-level registries and nothing here undoes it, so each assertNotIn precondition presumably holds only once per process. 37 | class RegisterNewClassTest(TestCase): 38 | def test_register_model(self) -> None: 39 | self.assertNotIn(NewModel, MODEL_REGISTRY) 40 | self.assertNotIn(NewModel, REVERSE_MODEL_REGISTRY.values()) 41 | register_model(NewModel) 42 | self.assertIn(NewModel, MODEL_REGISTRY) 43 | self.assertIn(NewModel, REVERSE_MODEL_REGISTRY.values()) 44 | 45 | def test_register_acquisition(self) -> None: 46 | self.assertNotIn(NewAcquisition, ACQUISITION_REGISTRY) 47 | self.assertNotIn(NewAcquisition, REVERSE_ACQUISITION_REGISTRY.values()) 48 | register_acquisition(NewAcquisition) 49 | self.assertIn(NewAcquisition, ACQUISITION_REGISTRY) 50 | self.assertIn(NewAcquisition, REVERSE_ACQUISITION_REGISTRY.values()) 51 | 52 | def test_register_acquisition_function(self) -> None: 53 | self.assertNotIn(NewAcquisitionFunction, ACQUISITION_FUNCTION_REGISTRY) 54 | self.assertNotIn( 55 | NewAcquisitionFunction, REVERSE_ACQUISITION_FUNCTION_REGISTRY.values() 56 | ) 57 | register_acquisition_function(NewAcquisitionFunction) 58 | self.assertIn(NewAcquisitionFunction, ACQUISITION_FUNCTION_REGISTRY) 59 | self.assertIn( 60 | NewAcquisitionFunction, REVERSE_ACQUISITION_FUNCTION_REGISTRY.values() 61 | ) 62 | 
-------------------------------------------------------------------------------- /sphinx/source/service.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | ax.service 5 | ========== 6 | 7 | .. automodule:: ax.service 8 | .. currentmodule:: ax.service 9 | 10 | 11 | Ax Client 12 | ~~~~~~~~~ 13 | 14 | .. automodule:: ax.service.ax_client 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | 19 | Managed Loop 20 | ~~~~~~~~~~~~ 21 | 22 | .. automodule:: ax.service.managed_loop 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | 27 | Interactive Loop 28 | ~~~~~~~~~~~~~~~~ 29 | 30 | .. automodule:: ax.service.interactive_loop 31 | :members: 32 | :undoc-members: 33 | :show-inheritance: 34 | 35 | 36 | Scheduler 37 | ~~~~~~~~~ 38 | 39 | .. automodule:: ax.service.scheduler 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | .. automodule:: ax.service.utils.scheduler_options 45 | :members: 46 | :undoc-members: 47 | :show-inheritance: 48 | 49 | Utils 50 | ----- 51 | 52 | Analysis 53 | ~~~~~~~~ 54 | 55 | .. automodule:: ax.service.utils.analysis_base 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | 60 | 61 | Best Point Identification 62 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 63 | 64 | .. automodule:: ax.service.utils.best_point_mixin 65 | :members: 66 | :undoc-members: 67 | :show-inheritance: 68 | 69 | 70 | .. automodule:: ax.service.utils.best_point 71 | :members: 72 | :undoc-members: 73 | :show-inheritance: 74 | 75 | 76 | .. automodule:: ax.service.utils.best_point_utils 77 | :members: 78 | :undoc-members: 79 | :show-inheritance: 80 | 81 | 82 | Instantiation 83 | ~~~~~~~~~~~~~ 84 | 85 | .. automodule:: ax.service.utils.instantiation 86 | :members: 87 | :undoc-members: 88 | :show-inheritance: 89 | 90 | 91 | Reporting 92 | ~~~~~~~~~ 93 | 94 | .. 
automodule:: ax.service.utils.report_utils 95 | :members: 96 | :undoc-members: 97 | :show-inheritance: 98 | 99 | 100 | WithDBSettingsBase 101 | ~~~~~~~~~~~~~~~~~~ 102 | 103 | .. automodule:: ax.service.utils.with_db_settings_base 104 | :members: 105 | :undoc-members: 106 | :show-inheritance: 107 | 108 | 109 | EarlyStopping 110 | ~~~~~~~~~~~~~ 111 | 112 | .. automodule:: ax.service.utils.early_stopping 113 | :members: 114 | :undoc-members: 115 | :show-inheritance: 116 | --------------------------------------------------------------------------------