├── website ├── static │ ├── .nojekyll │ ├── CNAME │ ├── img │ │ ├── favicon.png │ │ ├── oss_logo.png │ │ ├── favicon │ │ │ └── favicon.ico │ │ ├── ax_wireframe.svg │ │ ├── database-solid.svg │ │ ├── th-large-solid.svg │ │ ├── dice-solid.svg │ │ ├── ax.svg │ │ ├── ax_lockup.svg │ │ ├── ax_logo_lockup.svg │ │ └── ax_lockup_white.svg │ └── js │ │ ├── mathjax.js │ │ └── plotUtils.js ├── versioned_docs │ └── .gitkeep ├── versioned_sidebars │ └── .gitkeep ├── sidebars.json ├── package.json ├── tutorials.json └── core │ └── Footer.js ├── docs ├── assets │ ├── gp_opt.png │ ├── bo_1d_opt.gif │ ├── mab_animate.gif │ ├── mab_probs.png │ ├── mab_regret.png │ ├── gp_posterior.png │ ├── bandit_allocation.png │ └── example_shrinkage.png ├── algo-overview.md └── why-ax.md ├── ax ├── metrics │ ├── chemistry_data.zip │ ├── tests │ │ ├── __init__.py │ │ ├── test_curve.py │ │ └── test_tensorboard.py │ ├── l2norm.py │ ├── __init__.py │ ├── hartmann6.py │ ├── branin.py │ ├── botorch_test_problem.py │ └── jenatton.py ├── benchmark │ ├── methods │ │ ├── __init__.py │ │ ├── choose_generation_strategy.py │ │ └── saasbo.py │ ├── problems │ │ ├── __init__.py │ │ ├── hpo │ │ │ └── __init__.py │ │ ├── hss │ │ │ └── __init__.py │ │ ├── synthetic │ │ │ ├── __init__.py │ │ │ └── hss │ │ │ │ ├── __init__.py │ │ │ │ └── jenatton.py │ │ ├── baseline_results │ │ │ ├── __init__.py │ │ │ ├── hpo │ │ │ │ ├── __init__.py │ │ │ │ └── torchvision │ │ │ │ │ └── __init__.py │ │ │ └── synthetic │ │ │ │ ├── __init__.py │ │ │ │ ├── hd │ │ │ │ └── __init__.py │ │ │ │ └── hss │ │ │ │ └── __init__.py │ │ └── hd_embedding.py │ ├── __init__.py │ ├── tests │ │ ├── __init__.py │ │ ├── test_torchvision_problem_storage.py │ │ ├── test_problems.py │ │ ├── test_benchmark_method.py │ │ └── test_scored_benchmark.py │ └── benchmark_method.py ├── plot │ ├── __init__.py │ ├── css │ │ ├── __init__.py │ │ └── base.css │ ├── js │ │ ├── __init__.py │ │ ├── common │ │ │ ├── __init__.py │ │ │ ├── css.js │ │ │ ├── plotly_requires.js 
│ │ │ ├── plotly_offline.js │ │ │ └── plotly_online.js │ │ └── generic_plotly.js │ ├── tests │ │ ├── __init__.py │ │ ├── test_parallel_coordinates.py │ │ ├── test_helper.py │ │ ├── test_diagnostic.py │ │ ├── test_slices.py │ │ ├── test_traces.py │ │ ├── test_fitted_scatter.py │ │ ├── test_contours.py │ │ └── test_feature_importances.py │ └── marginal_effects.py ├── utils │ ├── __init__.py │ ├── stats │ │ ├── __init__.py │ │ └── tests │ │ │ └── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── tests │ │ │ ├── __init__.py │ │ │ ├── test_serialization.py │ │ │ ├── test_docutils.py │ │ │ ├── test_equality.py │ │ │ ├── test_typeutils.py │ │ │ └── test_logger.py │ │ ├── base.py │ │ ├── timeutils.py │ │ └── docutils.py │ ├── notebook │ │ ├── __init__.py │ │ └── plotting.py │ ├── report │ │ ├── __init__.py │ │ ├── tests │ │ │ ├── __init__.py │ │ │ └── test_render.py │ │ └── resources │ │ │ ├── __init__.py │ │ │ ├── simple_template.html │ │ │ ├── sufficient_statistic.html │ │ │ └── base_template.html │ ├── testing │ │ ├── __init__.py │ │ ├── metrics │ │ │ ├── __init__.py │ │ │ └── backend_simulator_map.py │ │ ├── tests │ │ │ └── __init__.py │ │ ├── doctest.py │ │ ├── test_init_files.py │ │ ├── pyre_strict.py │ │ ├── torch_stubs.py │ │ └── unittest_conventions.py │ ├── tutorials │ │ └── __init__.py │ ├── flake8_plugins │ │ └── __init__.py │ └── measurement │ │ ├── __init__.py │ │ └── tests │ │ └── __init__.py ├── core │ ├── tests │ │ ├── __init__.py │ │ ├── test_map_metric.py │ │ ├── test_metric.py │ │ ├── test_types.py │ │ ├── test_runner.py │ │ ├── test_parameter_distribution.py │ │ ├── test_risk_measures.py │ │ └── test_arm.py │ ├── map_metric.py │ └── __init__.py ├── exceptions │ ├── __init__.py │ ├── model.py │ ├── constants.py │ ├── storage.py │ ├── generation_strategy.py │ └── data_provider.py ├── models │ ├── discrete │ │ ├── __init__.py │ │ └── eb_thompson.py │ ├── random │ │ ├── __init__.py │ │ └── uniform.py │ ├── tests │ │ ├── __init__.py │ │ ├── test_base.py 
│ │ ├── test_alebo_initializer.py │ │ ├── test_rembo_initializer.py │ │ ├── test_randomforest.py │ │ ├── test_discrete.py │ │ ├── test_torch.py │ │ └── test_full_factorial.py │ ├── torch │ │ ├── __init__.py │ │ ├── tests │ │ │ └── __init__.py │ │ ├── botorch_modular │ │ │ └── __init__.py │ │ └── frontier_utils.py │ ├── __init__.py │ └── types.py ├── runners │ ├── tests │ │ └── __init__.py │ ├── __init__.py │ └── synthetic.py ├── service │ ├── tests │ │ ├── __init__.py │ │ ├── test_early_stopping.py │ │ └── test_best_point.py │ ├── utils │ │ ├── __init__.py │ │ └── early_stopping.py │ └── __init__.py ├── early_stopping │ ├── tests │ │ └── __init__.py │ ├── __init__.py │ └── strategies │ │ ├── __init__.py │ │ └── logical.py ├── global_stopping │ ├── tests │ │ └── __init__.py │ ├── __init__.py │ └── strategies │ │ ├── __init__.py │ │ └── base.py ├── modelbridge │ ├── tests │ │ ├── __init__.py │ │ ├── test_centered_unit_x_transform.py │ │ ├── test_transform_utils.py │ │ ├── test_rounding.py │ │ ├── test_base_transform.py │ │ └── test_cap_parameter_transform.py │ ├── strategies │ │ └── __init__.py │ ├── transforms │ │ ├── __init__.py │ │ ├── centered_unit_x.py │ │ └── cap_parameter.py │ └── __init__.py ├── storage │ ├── json_store │ │ ├── tests │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── load.py │ │ └── save.py │ ├── sqa_store │ │ ├── tests │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── delete.py │ │ ├── structs.py │ │ ├── timestamp.py │ │ ├── reduced_state.py │ │ └── sqa_enum.py │ ├── __init__.py │ └── utils.py └── __init__.py ├── CHANGELOG.md ├── pyproject.toml ├── sphinx └── source │ ├── ax.rst │ ├── global_stopping.rst │ ├── index.rst │ ├── runners.rst │ ├── early_stopping.rst │ ├── exceptions.rst │ ├── service.rst │ └── metrics.rst ├── pytest.ini ├── scripts ├── import_ax.py ├── docker_install.sh ├── build_ax.sh ├── wheels_build.ps1 ├── patch_site_config.py └── insert_api_refs.py ├── .flake8 └── LICENSE /website/static/.nojekyll: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /website/static/CNAME: -------------------------------------------------------------------------------- 1 | ax.dev 2 | -------------------------------------------------------------------------------- /website/versioned_docs/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /website/versioned_sidebars/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/assets/gp_opt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/gp_opt.png -------------------------------------------------------------------------------- /docs/assets/bo_1d_opt.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/bo_1d_opt.gif -------------------------------------------------------------------------------- /docs/assets/mab_animate.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/mab_animate.gif -------------------------------------------------------------------------------- /docs/assets/mab_probs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/mab_probs.png -------------------------------------------------------------------------------- /docs/assets/mab_regret.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/mab_regret.png -------------------------------------------------------------------------------- /ax/metrics/chemistry_data.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basnijholt/Ax/main/ax/metrics/chemistry_data.zip -------------------------------------------------------------------------------- /docs/assets/gp_posterior.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/gp_posterior.png -------------------------------------------------------------------------------- /website/static/img/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basnijholt/Ax/main/website/static/img/favicon.png -------------------------------------------------------------------------------- /website/static/img/oss_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basnijholt/Ax/main/website/static/img/oss_logo.png -------------------------------------------------------------------------------- /docs/assets/bandit_allocation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/bandit_allocation.png -------------------------------------------------------------------------------- /docs/assets/example_shrinkage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/example_shrinkage.png -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | Ax uses GitHub tags for managing releases. 
See changelog [here](https://github.com/facebook/Ax/releases). 2 | -------------------------------------------------------------------------------- /website/static/img/favicon/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basnijholt/Ax/main/website/static/img/favicon/favicon.ico -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=34.4", "wheel", "setuptools_scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.usort] 6 | first_party_detection = false 7 | -------------------------------------------------------------------------------- /docs/algo-overview.md: -------------------------------------------------------------------------------- 1 | --- 2 | id: algo-overview 3 | title: Overview 4 | --- 5 | 6 | Ax supports: 7 | * Bandit optimization 8 | * Empirical Bayes with Thompson sampling 9 | * Bayesian optimization 10 | -------------------------------------------------------------------------------- /ax/benchmark/methods/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /ax/benchmark/problems/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | -------------------------------------------------------------------------------- /sphinx/source/ax.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | ax 5 | =================================== 6 | 7 | .. automodule:: ax 8 | :members: 9 | :noindex: 10 | 11 | .. currentmodule:: ax 12 | -------------------------------------------------------------------------------- /ax/benchmark/problems/hpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /ax/benchmark/problems/hss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /ax/benchmark/problems/synthetic/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /ax/benchmark/problems/synthetic/hss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | -------------------------------------------------------------------------------- /ax/plot/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/benchmark/problems/baseline_results/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /ax/core/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/exceptions/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/plot/css/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/plot/js/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/plot/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /ax/utils/stats/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/benchmark/problems/baseline_results/hpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /ax/benchmark/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/metrics/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/models/discrete/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/models/random/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/models/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/models/torch/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/plot/js/common/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /ax/runners/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/service/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/service/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/common/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/notebook/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/report/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/testing/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/tutorials/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/benchmark/problems/baseline_results/synthetic/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | -------------------------------------------------------------------------------- /ax/benchmark/problems/baseline_results/synthetic/hd/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /ax/early_stopping/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/global_stopping/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/modelbridge/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/models/torch/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/common/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/flake8_plugins/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/measurement/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/report/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /ax/utils/stats/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/testing/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/testing/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/benchmark/problems/baseline_results/hpo/torchvision/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /ax/benchmark/problems/baseline_results/synthetic/hss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /ax/modelbridge/strategies/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/modelbridge/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/storage/json_store/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/storage/sqa_store/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /ax/utils/measurement/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/utils/report/resources/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /ax/models/torch/botorch_modular/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /website/sidebars.json: -------------------------------------------------------------------------------- 1 | { 2 | "docs": { 3 | "Introduction": ["why-ax"], 4 | "Getting Started": ["installation", "api", "glossary"], 5 | "Algorithms": ["bayesopt", "banditopt"], 6 | "Components": ["core", "trial-evaluation", "data", "models", "storage"] 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /ax/utils/report/resources/simple_template.html: -------------------------------------------------------------------------------- 1 | 4 | {% extends "base_template.html" %} 5 | {% block content %} 6 | {% for element in html_elements %} 7 | {{element}} 8 | {% endfor %} 9 | {% endblock %} 10 | -------------------------------------------------------------------------------- /ax/early_stopping/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from ax.early_stopping import strategies 9 | 10 | __all__ = ["strategies"] 11 | -------------------------------------------------------------------------------- /ax/global_stopping/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | from ax.global_stopping import strategies 9 | 10 | __all__ = ["strategies"] 11 | -------------------------------------------------------------------------------- /ax/plot/js/generic_plotly.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | Plotly.newPlot( 9 | {{id}}, 10 | {{data}}, 11 | {{layout}}, 12 | {"showLink": false} 13 | ); 14 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | # Configuration for pytest. 2 | [pytest] 3 | filterwarnings = 4 | # Filter out parameters and sklearn deprecation warnings. 5 | ignore::DeprecationWarning:.*paramz.* 6 | ignore::DeprecationWarning:.*sklearn* 7 | # Filter out numpy non-integer indices warning. 8 | ignore::DeprecationWarning:.*using a non-integer array as obj in delete* 9 | -------------------------------------------------------------------------------- /ax/service/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ax.service.managed_loop import OptimizationLoop, optimize 8 | 9 | 10 | __all__ = ["OptimizationLoop", "optimize"] 11 | -------------------------------------------------------------------------------- /ax/storage/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ax.storage.json_store import load as json_load, save as json_save 8 | 9 | 10 | __all__ = ["json_save", "json_load"] 11 | -------------------------------------------------------------------------------- /website/static/img/ax_wireframe.svg: -------------------------------------------------------------------------------- 1 | 06_OneColor_White -------------------------------------------------------------------------------- /ax/plot/js/common/css.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | var css = document.createElement('style'); 9 | css.type = 'text/css'; 10 | css.innerHTML = "{{css}}"; 11 | document.getElementsByTagName("head")[0].appendChild(css); 12 | -------------------------------------------------------------------------------- /scripts/import_ax.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import ax 8 | from ax.service.ax_client import AxClient 9 | 10 | 11 | if __name__ == "__main__": 12 | assert ax is not None 13 | assert AxClient is not None 14 | -------------------------------------------------------------------------------- /ax/plot/js/common/plotly_requires.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | require(['plotly'], function(Plotly) { 9 | window.PLOTLYENV = window.PLOTLYENV || {}; 10 | window.PLOTLYENV.BASE_URL = 'https://plot.ly'; 11 | {{script}} 12 | }); 13 | -------------------------------------------------------------------------------- /ax/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # flake8: noqa F401 8 | from ax.models.random.sobol import SobolGenerator 9 | from ax.models.torch.botorch import BotorchModel 10 | 11 | 12 | __all__ = ["SobolGenerator", "BotorchModel"] 13 | -------------------------------------------------------------------------------- /ax/storage/json_store/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from ax.storage.json_store.load import load_experiment as json_load 8 | from ax.storage.json_store.save import save_experiment as json_save 9 | 10 | 11 | __all__ = ["json_load", "json_save"] 12 | -------------------------------------------------------------------------------- /website/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "scripts": { 3 | "examples": "docusaurus-examples", 4 | "start": "docusaurus-start", 5 | "build": "docusaurus-build", 6 | "publish-gh-pages": "docusaurus-publish", 7 | "write-translations": "docusaurus-write-translations", 8 | "version": "docusaurus-version", 9 | "rename-version": "docusaurus-rename-version" 10 | }, 11 | "devDependencies": { 12 | "docusaurus": "^1.7.2" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | 3 | # E704: the linter doesn't parse types properly 4 | # T499, T484: silence mypy since using pyre for typechecking 5 | # W503: black and flake8 disagree on how to place operators 6 | # E231: black and flake8 disagree on whitespace after ',' 7 | # E203: black and flake8 disagree on whitespace before ':' 8 | ignore = T484, T499, W503, E704, E231, E203 9 | 10 | # Black really wants lines to be 88 chars... 11 | max-line-length = 88 12 | -------------------------------------------------------------------------------- /ax/runners/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # flake8: noqa F401 8 | from ax.runners.simulated_backend import SimulatedBackendRunner 9 | from ax.runners.synthetic import SyntheticRunner 10 | 11 | 12 | __all__ = ["SimulatedBackendRunner", "SyntheticRunner"] 13 | -------------------------------------------------------------------------------- /ax/plot/js/common/plotly_offline.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | if (!window.Plotly) { 9 | define('plotly', function(require, exports, module) { 10 | {{library}} 11 | }); 12 | require(['plotly'], function(Plotly) { 13 | window.Plotly = Plotly; 14 | }); 15 | } 16 | -------------------------------------------------------------------------------- /ax/metrics/l2norm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | from ax.metrics.noisy_function import NoisyFunctionMetric 9 | 10 | 11 | class L2NormMetric(NoisyFunctionMetric): 12 | def f(self, x: np.ndarray) -> float: 13 | return np.sqrt((x**2).sum()) 14 | -------------------------------------------------------------------------------- /ax/plot/js/common/plotly_online.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | requirejs.config({ 9 | paths: { 10 | plotly: ['https://cdn.plot.ly/plotly-latest.min'], 11 | }, 12 | }); 13 | if (!window.Plotly) { 14 | require(['plotly'], function(plotly) { 15 | window.Plotly = plotly; 16 | }); 17 | } 18 | -------------------------------------------------------------------------------- /ax/global_stopping/strategies/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy 8 | from ax.global_stopping.strategies.improvement import ImprovementGlobalStoppingStrategy 9 | 10 | 11 | __all__ = [ 12 | "BaseGlobalStoppingStrategy", 13 | "ImprovementGlobalStoppingStrategy", 14 | ] 15 | -------------------------------------------------------------------------------- /ax/models/types.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Dict, Union 8 | 9 | from ax.core.optimization_config import OptimizationConfig 10 | from botorch.acquisition import AcquisitionFunction 11 | 12 | TConfig = Dict[ 13 | str, 14 | Union[ 15 | int, float, str, AcquisitionFunction, Dict[str, Any], OptimizationConfig, None 16 | ], 17 | ] 18 | -------------------------------------------------------------------------------- /ax/modelbridge/transforms/centered_unit_x.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from ax.modelbridge.transforms.unit_x import UnitX 9 | 10 | 11 | class CenteredUnitX(UnitX): 12 | """Map X to [-1, 1]^d for RangeParameter of type float and not log scale. 13 | 14 | Transform is done in-place. 15 | """ 16 | 17 | target_lb: float = -1.0 18 | target_range: float = 2.0 19 | -------------------------------------------------------------------------------- /ax/exceptions/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ax.exceptions.core import AxError 8 | 9 | 10 | class ModelError(AxError): 11 | """Raised when an error occurs during modeling.""" 12 | 13 | pass 14 | 15 | 16 | class CVNotSupportedError(AxError): 17 | """Raised when cross validation is applied to a model which doesn't 18 | support it. 19 | """ 20 | 21 | pass 22 | -------------------------------------------------------------------------------- /ax/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # flake8: noqa F401 8 | from ax.metrics.branin import BraninMetric 9 | from ax.metrics.chemistry import ChemistryMetric 10 | from ax.metrics.factorial import FactorialMetric 11 | from ax.metrics.sklearn import SklearnMetric 12 | 13 | __all__ = [ 14 | "BraninMetric", 15 | "ChemistryMetric", 16 | "FactorialMetric", 17 | "SklearnMetric", 18 | ] 19 | -------------------------------------------------------------------------------- /ax/models/torch/frontier_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ax.models.torch.botorch_moo_defaults import ( 8 | get_default_frontier_evaluator, 9 | get_weighted_mc_objective_and_objective_thresholds, 10 | TFrontierEvaluator, 11 | ) 12 | 13 | 14 | __all__ = [ 15 | "get_weighted_mc_objective_and_objective_thresholds", 16 | "get_default_frontier_evaluator", 17 | "TFrontierEvaluator", 18 | ] 19 | -------------------------------------------------------------------------------- /ax/metrics/tests/test_curve.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | from ax.metrics.curve import AbstractCurveMetric 9 | from ax.utils.common.testutils import TestCase 10 | 11 | 12 | class AbstractCurveMetricTest(TestCase): 13 | def testAbstractCurveMetric(self): 14 | self.assertTrue(AbstractCurveMetric.is_available_while_running()) 15 | with self.assertRaises(TypeError): 16 | AbstractCurveMetric("foo", "bar") 17 | -------------------------------------------------------------------------------- /ax/storage/sqa_store/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # necessary to import this file so SQLAlchemy knows about the event listeners 8 | # see https://fburl.com/8mn7yjt2 9 | from ax.storage.sqa_store import validation 10 | from ax.storage.sqa_store.load import load_experiment as sqa_load 11 | from ax.storage.sqa_store.save import save_experiment as sqa_save 12 | 13 | 14 | __all__ = ["sqa_load", "sqa_save"] 15 | 16 | del validation 17 | -------------------------------------------------------------------------------- /scripts/docker_install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | VERSION=$1 8 | 9 | # Build Linux Python3.7 10 | # Execute from Ax root directory. 11 | cd ../ || exit 12 | /opt/python/cp37-cp37m/bin/pip3.7 install numpy 13 | /opt/python/cp37-cp37m/bin/python3.7 setup.py bdist_wheel 14 | 15 | # Convert to manylinux 16 | cd dist || exit 17 | auditwheel repair ax_platform-"$VERSION"-cp37-cp37m-linux_x86_64.whl 18 | rm ./* 19 | mv wheelhouse/* . 
20 | rm -rf wheelhouse 21 | -------------------------------------------------------------------------------- /ax/modelbridge/tests/test_centered_unit_x_transform.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ax.modelbridge.tests.test_unit_x_transform import UnitXTransformTest 8 | from ax.modelbridge.transforms.centered_unit_x import CenteredUnitX 9 | 10 | 11 | class CenteredUnitXTransformTest(UnitXTransformTest): 12 | 13 | transform_class = CenteredUnitX 14 | expected_c_dicts = [{"x": -0.5, "y": 0.5}, {"x": -0.5, "a": 1.0}] 15 | expected_c_bounds = [0.0, 1.5] 16 | -------------------------------------------------------------------------------- /sphinx/source/global_stopping.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | ax.global_stopping 5 | =================================== 6 | 7 | .. automodule:: ax.global_stopping 8 | .. currentmodule:: ax.global_stopping 9 | 10 | Strategies 11 | ---------- 12 | 13 | Base Strategies 14 | ~~~~~~~~~~~~~~~ 15 | 16 | .. automodule:: ax.global_stopping.strategies.base 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | 21 | `ImprovementGlobalStoppingStrategy` 22 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 23 | 24 | .. automodule:: ax.global_stopping.strategies.improvement 25 | :members: 26 | :undoc-members: 27 | :show-inheritance: 28 | -------------------------------------------------------------------------------- /sphinx/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Ax documentation index file, created by 2 | sphinx-quickstart on Sat Mar 2 00:03:32 2019. 
3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | API Reference 7 | =============================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | 12 | ax 13 | benchmark 14 | core 15 | exceptions 16 | metrics 17 | modelbridge 18 | models 19 | plot 20 | runners 21 | service 22 | storage 23 | utils 24 | 25 | 26 | Indices and tables 27 | ================== 28 | 29 | * :ref:`genindex` 30 | * :ref:`modindex` 31 | * :ref:`search` 32 | -------------------------------------------------------------------------------- /ax/storage/sqa_store/delete.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from ax.storage.sqa_store.db import session_scope 7 | from ax.storage.sqa_store.sqa_classes import SQAExperiment 8 | 9 | 10 | def delete_experiment(exp_name: str) -> None: 11 | """Delete experiment by name. 12 | 13 | Args: 14 | exp_name: Name of the experiment to delete. 15 | """ 16 | with session_scope() as session: 17 | exp = session.query(SQAExperiment).filter_by(name=exp_name).one_or_none() 18 | session.delete(exp) 19 | session.flush() 20 | -------------------------------------------------------------------------------- /scripts/build_ax.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # ManyLinux build 8 | # Mount top-level Ax directory as Ax-main. 
9 | docker container run --mount type=bind,source="$(pwd)/../",target=/Ax-main -it quay.io/pypa/manylinux2010_x86_64 10 | 11 | # MANUAL STEP FOR NOW 12 | # Now, in Docker container, cd Ax-main and MANUALLY RUN ./docker_install.sh 13 | 14 | # LOCAL BUILD 15 | # Requires Python 3.7 installed locally, and on path 16 | cd .. 17 | pip3.7 install numpy 18 | python3.7 setup.py bdist_wheel 19 | 20 | # Final PyPI Upload 21 | twine upload dist/* 22 | -------------------------------------------------------------------------------- /ax/metrics/hartmann6.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | from ax.metrics.noisy_function import NoisyFunctionMetric 9 | from ax.utils.common.typeutils import checked_cast 10 | from ax.utils.measurement.synthetic_functions import aug_hartmann6, hartmann6 11 | 12 | 13 | class Hartmann6Metric(NoisyFunctionMetric): 14 | def f(self, x: np.ndarray) -> float: 15 | return checked_cast(float, hartmann6(x)) 16 | 17 | 18 | class AugmentedHartmann6Metric(NoisyFunctionMetric): 19 | def f(self, x: np.ndarray) -> float: 20 | return checked_cast(float, aug_hartmann6(x)) 21 | -------------------------------------------------------------------------------- /ax/utils/report/resources/sufficient_statistic.html: -------------------------------------------------------------------------------- 1 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | {% for cell in cells %} 14 | {% if cell.html %} 15 |
16 |

{{cell.caption}}

17 |
18 |
19 | {{cell.html}} 20 |
21 | {% endif %} 22 | {% endfor %} 23 |
24 | 25 | 26 | -------------------------------------------------------------------------------- /ax/utils/testing/doctest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import doctest 8 | import unittest 9 | 10 | from ax.utils.common import testutils 11 | from ax.utils.testing.manifest import ModuleInfo, populate_test_class 12 | 13 | 14 | def run_doctests(t: unittest.TestCase, m: ModuleInfo) -> None: 15 | results = doctest.testmod(m.module, optionflags=doctest.ELLIPSIS) 16 | assert results.failed == 0 17 | 18 | 19 | @populate_test_class(run_doctests) 20 | class TestDocTests(testutils.TestCase): 21 | """ 22 | Run all the doctests in the main library. 23 | 24 | This is a support file for `ae_unittest`. 25 | """ 26 | -------------------------------------------------------------------------------- /website/static/js/mathjax.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | window.MathJax = { 9 | tex2jax: { 10 | inlineMath: [['$', '$'], ['\\(', '\\)']], 11 | displayMath: [['$$', '$$'], ['\\[', '\\]']], 12 | processEscapes: true, 13 | processEnvironments: true, 14 | }, 15 | // Center justify equations in code and markdown cells. Note that this 16 | // doesn't work with Plotly though, hence the !important declaration 17 | // below. 
18 | displayAlign: 'center', 19 | 'HTML-CSS': { 20 | styles: { 21 | '.MathJax_Display': {margin: 0, 'text-align': 'center !important'}, 22 | }, 23 | linebreaks: {automatic: true}, 24 | }, 25 | }; 26 | -------------------------------------------------------------------------------- /ax/modelbridge/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # flake8: noqa F401 8 | from ax.modelbridge import transforms 9 | from ax.modelbridge.base import ModelBridge 10 | from ax.modelbridge.factory import ( 11 | get_factorial, 12 | get_GPEI, 13 | get_sobol, 14 | get_thompson, 15 | get_uniform, 16 | Models, 17 | ) 18 | from ax.modelbridge.torch import TorchModelBridge 19 | 20 | 21 | __all__ = [ 22 | "ModelBridge", 23 | "Models", 24 | "TorchModelBridge", 25 | "get_factorial", 26 | "get_GPEI", 27 | "get_GPKG", 28 | "get_sobol", 29 | "get_thompson", 30 | "get_uniform", 31 | "transforms", 32 | ] 33 | -------------------------------------------------------------------------------- /website/static/img/database-solid.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 8 | 11 | 12 | -------------------------------------------------------------------------------- /ax/models/tests/test_base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from ax.models.base import Model 8 | from ax.utils.common.testutils import TestCase 9 | 10 | 11 | class BaseModelTest(TestCase): 12 | def test_base_model(self): 13 | model = Model() 14 | raw_state = {"foo": "bar", "two": 3.0} 15 | self.assertEqual(model.serialize_state(raw_state), raw_state) 16 | self.assertEqual(model.deserialize_state(raw_state), raw_state) 17 | self.assertEqual(model._get_state(), {}) 18 | with self.assertRaisesRegex( 19 | NotImplementedError, "Feature importance not available" 20 | ): 21 | model.feature_importances() 22 | -------------------------------------------------------------------------------- /ax/storage/sqa_store/structs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Callable, NamedTuple, Optional 8 | 9 | from ax.storage.sqa_store.decoder import Decoder 10 | from ax.storage.sqa_store.encoder import Encoder 11 | from ax.storage.sqa_store.sqa_config import SQAConfig 12 | 13 | 14 | class DBSettings(NamedTuple): 15 | """ 16 | Defines behavior for loading/saving experiment to/from db. 17 | Either creator or url must be specified as a way to connect to the SQL db. 18 | """ 19 | 20 | creator: Optional[Callable] = None 21 | decoder: Decoder = Decoder(config=SQAConfig()) 22 | encoder: Encoder = Encoder(config=SQAConfig()) 23 | url: Optional[str] = None 24 | -------------------------------------------------------------------------------- /sphinx/source/runners.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | ax.runners 5 | =================================== 6 | 7 | .. automodule:: ax.runners 8 | .. 
currentmodule:: ax.runners 9 | 10 | BoTorch Test Problem 11 | ~~~~~~~~~~~~~~~~~~~~ 12 | 13 | .. automodule:: ax.runners.botorch_test_problem 14 | :members: 15 | :undoc-members: 16 | :show-inheritance: 17 | 18 | 19 | Synthetic Runner 20 | ~~~~~~~~~~~~~~~~ 21 | 22 | .. automodule:: ax.runners.synthetic 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | 27 | Simulated Backend Runner 28 | ~~~~~~~~~~~~~~~~~~~~~~~~ 29 | 30 | .. automodule:: ax.runners.simulated_backend 31 | :members: 32 | :undoc-members: 33 | :show-inheritance: 34 | 35 | TorchX Runner 36 | ~~~~~~~~~~~~~~~~ 37 | 38 | .. automodule:: ax.runners.torchx 39 | :members: 40 | :undoc-members: 41 | :show-inheritance: 42 | -------------------------------------------------------------------------------- /ax/utils/testing/test_init_files.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | from glob import glob 9 | 10 | from ax.utils.common.testutils import TestCase 11 | 12 | 13 | class InitTest(TestCase): 14 | def testInitFiles(self) -> None: 15 | """__init__.py files are necessary when not using buck targets""" 16 | for root, _dirs, files in os.walk("./ax/ax", topdown=False): 17 | if len(glob(f"{root}/**/*.py", recursive=True)) > 0: 18 | with self.subTest(root): 19 | self.assertTrue( 20 | "__init__.py" in files, 21 | "directory " + root + " does not contain a .__init__.py file", 22 | ) 23 | -------------------------------------------------------------------------------- /ax/utils/testing/pyre_strict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import unittest 8 | 9 | from ax.utils.common import testutils 10 | from ax.utils.testing.manifest import ModuleInfo, populate_test_class 11 | 12 | 13 | def test_pyre_strict(t: unittest.TestCase, m: ModuleInfo) -> None: 14 | with open(m.file) as fd: 15 | for line in fd: 16 | if line == "# pyre-strict\n" or line == "# no-strict-types\n": 17 | return 18 | raise Exception(f"{m.path}'s header should contain '# pyre-strict'") 19 | 20 | 21 | @populate_test_class(test_pyre_strict) 22 | class TestPyreStrict(testutils.TestCase): 23 | """ 24 | Test that all the files are marked pyre strict. 25 | """ 26 | 27 | pass 28 | -------------------------------------------------------------------------------- /ax/plot/tests/test_parallel_coordinates.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | from ax.plot.base import AxPlotConfig, AxPlotTypes 8 | from ax.plot.parallel_coordinates import plot_parallel_coordinates 9 | from ax.utils.common.testutils import TestCase 10 | from ax.utils.testing.core_stubs import get_branin_experiment 11 | 12 | 13 | class ParallelCoordinatesTest(TestCase): 14 | def testParallelCoordinates(self): 15 | exp = get_branin_experiment(with_batch=True) 16 | exp.trials[0].run() 17 | 18 | # Assert that each type of plot can be constructed successfully 19 | plot = plot_parallel_coordinates(experiment=exp) 20 | 21 | self.assertIsInstance(plot, AxPlotConfig) 22 | self.assertEqual(plot.plot_type, AxPlotTypes.GENERIC) 23 | -------------------------------------------------------------------------------- /ax/utils/common/tests/test_serialization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import NamedTuple 8 | 9 | from ax.utils.common.serialization import named_tuple_to_dict 10 | from ax.utils.common.testutils import TestCase 11 | 12 | 13 | class TestSerializationUtils(TestCase): 14 | def test_named_tuple_to_dict(self): 15 | class Foo(NamedTuple): 16 | x: int 17 | y: str 18 | 19 | foo = Foo(x=5, y="g") 20 | self.assertEqual(named_tuple_to_dict(foo), {"x": 5, "y": "g"}) 21 | 22 | bar = {"x": 5, "foo": foo, "y": [(1, True), foo]} 23 | self.assertEqual( 24 | named_tuple_to_dict(bar), 25 | {"x": 5, "foo": {"x": 5, "y": "g"}, "y": [(1, True), {"x": 5, "y": "g"}]}, 26 | ) 27 | -------------------------------------------------------------------------------- /ax/metrics/branin.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | from ax.metrics.noisy_function import NoisyFunctionMetric 9 | from ax.utils.common.typeutils import checked_cast 10 | from ax.utils.measurement.synthetic_functions import aug_branin, branin 11 | 12 | 13 | class BraninMetric(NoisyFunctionMetric): 14 | def f(self, x: np.ndarray) -> float: 15 | x1, x2 = x 16 | return checked_cast(float, branin(x1=x1, x2=x2)) 17 | 18 | 19 | class NegativeBraninMetric(BraninMetric): 20 | def f(self, x: np.ndarray) -> float: 21 | fpos = super().f(x) 22 | return -fpos 23 | 24 | 25 | class AugmentedBraninMetric(NoisyFunctionMetric): 26 | def f(self, x: np.ndarray) -> float: 27 | return checked_cast(float, aug_branin(x)) 28 | -------------------------------------------------------------------------------- /website/static/img/th-large-solid.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 8 | 13 | 14 | -------------------------------------------------------------------------------- /ax/benchmark/tests/test_torchvision_problem_storage.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionBenchmarkProblem 7 | from ax.storage.json_store.decoder import object_from_json 8 | from ax.storage.json_store.encoder import object_to_json 9 | from ax.utils.common.testutils import TestCase 10 | 11 | 12 | class TestProblems(TestCase): 13 | def test_encode_decode(self): 14 | original_object = PyTorchCNNTorchvisionBenchmarkProblem.from_dataset_name( 15 | name="MNIST" 16 | ) 17 | 18 | json_object = object_to_json( 19 | original_object, 20 | ) 21 | converted_object = object_from_json( 22 | json_object, 23 | ) 24 | 25 | self.assertEqual(original_object, converted_object) 26 | -------------------------------------------------------------------------------- /ax/exceptions/constants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | TS_MIN_WEIGHT_ERROR = """ 8 | No arms generated by Thompson Sampling had weight > min_weight. 9 | The minimum weight required is {min_weight:2.4}, and the 10 | maximum weight of any arm generated is {max_weight:2.4}. 11 | """ 12 | 13 | TS_NO_FEASIBLE_ARMS_ERROR = """ 14 | Less than 1% of samples have a feasible arm. 15 | Check your outcome constraints. 16 | """ 17 | 18 | CHOLESKY_ERROR_ANNOTATION = ( 19 | "Cholesky errors typically occur when the same or very similar " 20 | "arms are suggested repeatedly. This can mean the model has " 21 | "already converged and you should avoid running further trials. 
" 22 | "It will also help to convert integer or categorical parameters " 23 | "to float ranges where reasonable.\nOriginal error: " 24 | ) 25 | -------------------------------------------------------------------------------- /scripts/wheels_build.ps1: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Python path to use for build, flag to upload file 7 | param( 8 | [string]$pypath = $("python3"), 9 | [switch]$upload = $false 10 | ) 11 | 12 | Set-Alias -Name mypy -Value $pypath 13 | Write-Host "Py Version:" 14 | mypy --version 15 | 16 | # Jump into the Ax repo folder 17 | pushd $PSScriptRoot\.. 18 | 19 | # Install or upgrade all the dependecies 20 | mypy -m pip install botorch jinja2 pandas scipy simplejson sklearn plotly numpy twine wheel 21 | mypy -m pip install --upgrade botorch jinja2 pandas scipy simplejson sklearn plotly numpy twine wheel 22 | 23 | # Let's build 24 | mypy ./setup.py bdist_wheel 25 | # Validate the build 26 | twine check dist/* 27 | 28 | # Final PyPI Upload 29 | If ($upload) { 30 | echo "Uploading" 31 | twine upload dist/* 32 | } 33 | 34 | # Done! 35 | popd 36 | -------------------------------------------------------------------------------- /ax/benchmark/tests/test_problems.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from ax.benchmark.benchmark_result import AggregatedBenchmarkResult 7 | from ax.benchmark.problems.registry import ( 8 | BENCHMARK_PROBLEM_REGISTRY, 9 | get_problem_and_baseline, 10 | ) 11 | from ax.utils.common.testutils import TestCase 12 | 13 | 14 | class TestProblems(TestCase): 15 | def test_load_baselines(self): 16 | 17 | # Make sure the json parsing succeeded 18 | for name in BENCHMARK_PROBLEM_REGISTRY.keys(): 19 | if "MNIST" in name: 20 | continue # Skip these as they cause the test to take a long time 21 | 22 | problem, baseline = get_problem_and_baseline(problem_name=name) 23 | 24 | self.assertTrue(isinstance(baseline, AggregatedBenchmarkResult)) 25 | self.assertIn(problem.name, baseline.name) 26 | -------------------------------------------------------------------------------- /ax/modelbridge/tests/test_transform_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from ax.modelbridge.transforms.utils import ClosestLookupDict 8 | from ax.utils.common.testutils import TestCase 9 | 10 | 11 | class TransformUtilsTest(TestCase): 12 | def test_closest_lookup_dict(self): 13 | # test empty lookup 14 | d = ClosestLookupDict() 15 | with self.assertRaises(RuntimeError): 16 | d[0] 17 | # basic test 18 | keys = (1.0, 2, 4) 19 | vals = ("a", "b", "c") 20 | d = ClosestLookupDict(zip(keys, vals)) 21 | for k, v in zip(keys, vals): 22 | self.assertEqual(d[k], v) 23 | self.assertEqual(d[2.5], "b") 24 | self.assertEqual(d[0], "a") 25 | self.assertEqual(d[6], "c") 26 | with self.assertRaises(ValueError): 27 | d["str_key"] = 3 28 | -------------------------------------------------------------------------------- /ax/metrics/tests/test_tensorboard.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from unittest import mock 8 | 9 | import numpy as np 10 | import pandas as pd 11 | from ax.metrics.tensorboard import TensorboardCurveMetric 12 | from ax.utils.common.testutils import TestCase 13 | 14 | 15 | class TensorboardCurveMetricTest(TestCase): 16 | def test_GetCurvesFromIds(self): 17 | def mock_get_tb_from_posix(path): 18 | return pd.Series([int(path)] * 2) 19 | 20 | mock_path = "ax.metrics.tensorboard.get_tb_from_posix" 21 | with mock.patch(mock_path, side_effect=mock_get_tb_from_posix): 22 | out = TensorboardCurveMetric.get_curves_from_ids(["1", "2"]) 23 | self.assertEqual(len(out), 2) 24 | self.assertTrue(np.array_equal(out["1"].values, np.array([1, 1]))) 25 | self.assertTrue(np.array_equal(out["2"].values, np.array([2, 2]))) 26 | -------------------------------------------------------------------------------- /ax/benchmark/methods/choose_generation_strategy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from ax.benchmark.benchmark_method import BenchmarkMethod 7 | from ax.benchmark.benchmark_problem import BenchmarkProblem 8 | from ax.modelbridge.dispatch_utils import choose_generation_strategy 9 | from ax.service.scheduler import SchedulerOptions 10 | 11 | 12 | def get_choose_generation_strategy_method( 13 | problem: BenchmarkProblem, num_trials: int = 30 14 | ) -> BenchmarkMethod: 15 | generation_strategy = choose_generation_strategy( 16 | search_space=problem.search_space, 17 | optimization_config=problem.optimization_config, 18 | num_trials=num_trials, 19 | ) 20 | 21 | return BenchmarkMethod( 22 | name=f"ChooseGenerationStrategy::{problem.name}", 23 | generation_strategy=generation_strategy, 24 | scheduler_options=SchedulerOptions(total_trials=num_trials), 25 | ) 26 | -------------------------------------------------------------------------------- /ax/core/map_metric.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from __future__ import annotations 8 | 9 | from typing import Type 10 | 11 | from ax.core.map_data import MapData 12 | from ax.core.metric import Metric 13 | 14 | 15 | class MapMetric(Metric): 16 | """Base class for representing metrics that return `MapData`. 17 | 18 | The `fetch_trial_data` method is the essential method to override when 19 | subclassing, which specifies how to retrieve a Metric, for a given trial. 20 | 21 | A MapMetric must return a MapData object, which requires (at minimum) the following: 22 | https://ax.dev/api/_modules/ax/core/data.html#Data.required_columns 23 | 24 | Attributes: 25 | lower_is_better: Flag for metrics which should be minimized. 26 | properties: Properties specific to a particular metric. 
27 | """ 28 | 29 | data_constructor: Type[MapData] = MapData 30 | -------------------------------------------------------------------------------- /ax/modelbridge/tests/test_rounding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | from ax.modelbridge.transforms.rounding import ( 9 | randomized_onehot_round, 10 | strict_onehot_round, 11 | ) 12 | from ax.utils.common.testutils import TestCase 13 | 14 | 15 | class RoundingTest(TestCase): 16 | def setUp(self): 17 | pass 18 | 19 | def testOneHotRound(self): 20 | self.assertTrue( 21 | np.allclose( 22 | strict_onehot_round(np.array([0.1, 0.5, 0.3])), np.array([0, 1, 0]) 23 | ) 24 | ) 25 | # One item should be set to one at random. 26 | self.assertEqual( 27 | np.count_nonzero( 28 | np.isclose( 29 | randomized_onehot_round(np.array([0.0, 0.0, 0.0])), 30 | np.array([1, 1, 1]), 31 | ) 32 | ), 33 | 1, 34 | ) 35 | -------------------------------------------------------------------------------- /ax/early_stopping/strategies/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from ax.early_stopping.strategies.base import ( 8 | BaseEarlyStoppingStrategy, 9 | EarlyStoppingTrainingData, 10 | ModelBasedEarlyStoppingStrategy, 11 | ) 12 | from ax.early_stopping.strategies.logical import ( 13 | AndEarlyStoppingStrategy, 14 | LogicalEarlyStoppingStrategy, 15 | OrEarlyStoppingStrategy, 16 | ) 17 | from ax.early_stopping.strategies.percentile import PercentileEarlyStoppingStrategy 18 | from ax.early_stopping.strategies.threshold import ThresholdEarlyStoppingStrategy 19 | 20 | 21 | __all__ = [ 22 | "BaseEarlyStoppingStrategy", 23 | "EarlyStoppingTrainingData", 24 | "ModelBasedEarlyStoppingStrategy", 25 | "PercentileEarlyStoppingStrategy", 26 | "ThresholdEarlyStoppingStrategy", 27 | "AndEarlyStoppingStrategy", 28 | "OrEarlyStoppingStrategy", 29 | "LogicalEarlyStoppingStrategy", 30 | ] 31 | -------------------------------------------------------------------------------- /ax/storage/sqa_store/timestamp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import datetime 8 | from typing import Optional 9 | 10 | from sqlalchemy.engine.interfaces import Dialect 11 | from sqlalchemy.types import Integer, TypeDecorator 12 | 13 | 14 | class IntTimestamp(TypeDecorator): 15 | impl = Integer 16 | cache_ok = True 17 | 18 | # pyre-fixme[15]: `process_bind_param` overrides method defined in 19 | # `TypeDecorator` inconsistently. 
20 | def process_bind_param( 21 | self, value: Optional[datetime.datetime], dialect: Dialect 22 | ) -> Optional[int]: 23 | if value is None: 24 | return None # pragma: no cover 25 | else: 26 | return int(value.timestamp()) 27 | 28 | def process_result_value( 29 | self, value: Optional[int], dialect: Dialect 30 | ) -> Optional[datetime.datetime]: 31 | return None if value is None else datetime.datetime.fromtimestamp(value) 32 | -------------------------------------------------------------------------------- /ax/modelbridge/tests/test_base_transform.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from unittest.mock import MagicMock 8 | 9 | from ax.modelbridge.transforms.base import Transform 10 | from ax.utils.common.testutils import TestCase 11 | 12 | 13 | class TransformsTest(TestCase): 14 | def testIdentityTransform(self): 15 | # Test that the identity transform does not mutate anything 16 | t = Transform(MagicMock(), MagicMock(), MagicMock()) 17 | x = MagicMock() 18 | ys = [] 19 | ys.append(t.transform_search_space(x)) 20 | ys.append(t.transform_optimization_config(x, x, x)) 21 | ys.append(t.transform_observation_features(x)) 22 | ys.append(t.transform_observation_data(x, x)) 23 | ys.append(t.untransform_observation_features(x)) 24 | ys.append(t.untransform_observation_data(x, x)) 25 | self.assertEqual(len(x.mock_calls), 0) 26 | for y in ys: 27 | self.assertEqual(y, x) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ax/utils/common/tests/test_docutils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from ax.utils.common.docutils import copy_doc 8 | from ax.utils.common.testutils import TestCase 9 | 10 | 11 | def has_doc(): 12 | """I have a docstring""" 13 | 14 | 15 | def has_no_doc(): 16 | pass 17 | 18 | 19 | class TestDocUtils(TestCase): 20 | def test_transfer_doc(self): 21 | @copy_doc(has_doc) 22 | def inherits_doc(): 23 | pass 24 | 25 | self.assertEqual(inherits_doc.__doc__, "I have a docstring") 26 | 27 | def test_fail_when_already_has_doc(self): 28 | with self.assertRaises(ValueError): 29 | 30 | @copy_doc(has_doc) 31 | def inherits_doc(): 32 | """I already have a doc string""" 33 | pass 34 | 35 | def test_fail_when_no_doc_to_copy(self): 36 | with self.assertRaises(ValueError): 37 | 38 | @copy_doc(has_no_doc) 39 | def f(): 40 | pass 41 | -------------------------------------------------------------------------------- /ax/storage/json_store/load.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | from typing import Any, Callable, Dict, Type 9 | 10 | from ax.core.experiment import Experiment 11 | from ax.storage.json_store.decoder import object_from_json 12 | from ax.storage.json_store.registry import ( 13 | CORE_CLASS_DECODER_REGISTRY, 14 | CORE_DECODER_REGISTRY, 15 | ) 16 | 17 | 18 | def load_experiment( 19 | filepath: str, 20 | decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY, 21 | class_decoder_registry: Dict[ 22 | str, Callable[[Dict[str, Any]], Any] 23 | ] = CORE_CLASS_DECODER_REGISTRY, 24 | ) -> Experiment: 25 | """Load experiment from file. 26 | 27 | 1) Read file. 28 | 2) Convert dictionary to Ax experiment instance. 
29 | """ 30 | with open(filepath, "r") as file: 31 | json_experiment = json.loads(file.read()) 32 | return object_from_json( 33 | json_experiment, decoder_registry, class_decoder_registry 34 | ) 35 | -------------------------------------------------------------------------------- /ax/models/tests/test_alebo_initializer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | from ax.models.random.alebo_initializer import ALEBOInitializer 9 | from ax.utils.common.testutils import TestCase 10 | 11 | 12 | class ALEBOSobolTest(TestCase): 13 | def testALEBOSobolModel(self): 14 | B = np.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]]) 15 | Q = np.linalg.pinv(B) @ B 16 | # Test setting attributes 17 | m = ALEBOInitializer(B=B) 18 | self.assertTrue(np.allclose(Q, m.Q)) 19 | 20 | # Test gen 21 | Z, w = m.gen(5, bounds=[(-1.0, 1.0)] * 3) 22 | self.assertEqual(Z.shape, (5, 3)) 23 | self.assertTrue(Z.min() >= -1.0) 24 | self.assertTrue(Z.max() <= 1.0) 25 | # Verify that it is in the subspace 26 | self.assertTrue(np.allclose(Q @ Z.transpose(), Z.transpose())) 27 | 28 | m = ALEBOInitializer(B=B, nsamp=1) 29 | with self.assertRaises(ValueError): 30 | m.gen(2, bounds=[(-1.0, 1.0)] * 3) 31 | -------------------------------------------------------------------------------- /ax/benchmark/benchmark_method.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from dataclasses import dataclass 7 | 8 | from ax.exceptions.core import UserInputError 9 | from ax.modelbridge.generation_strategy import GenerationStrategy 10 | from ax.service.utils.scheduler_options import SchedulerOptions 11 | from ax.utils.common.base import Base 12 | 13 | 14 | @dataclass(frozen=True) 15 | class BenchmarkMethod(Base): 16 | """Benchmark method, represented in terms of Ax generation strategy (which tells us 17 | which models to use when) and scheduler options (which tell us extra execution 18 | information like maximum parallelism, early stopping configuration, etc.) 19 | """ 20 | 21 | name: str 22 | generation_strategy: GenerationStrategy 23 | scheduler_options: SchedulerOptions 24 | 25 | def __post_init__(self) -> None: 26 | if self.scheduler_options.total_trials is None: 27 | raise UserInputError( 28 | "SchedulerOptions.total_trials may not be None in BenchmarkMethod." 29 | ) 30 | -------------------------------------------------------------------------------- /sphinx/source/early_stopping.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | ax.early_stopping 5 | =================================== 6 | 7 | .. automodule:: ax.early_stopping 8 | .. currentmodule:: ax.early_stopping 9 | 10 | Strategies 11 | ---------- 12 | 13 | Base Strategies 14 | ~~~~~~~~~~~~~~~ 15 | 16 | .. automodule:: ax.early_stopping.strategies.base 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | 21 | Logical Strategies 22 | ~~~~~~~~~~~~~~~~~~ 23 | 24 | .. automodule:: ax.early_stopping.strategies.logical 25 | :members: 26 | :undoc-members: 27 | :show-inheritance: 28 | 29 | `PercentileEarlyStoppingStrategy` 30 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 31 | 32 | .. 
automodule:: ax.early_stopping.strategies.percentile 33 | :members: 34 | :undoc-members: 35 | :show-inheritance: 36 | 37 | `ThresholdEarlyStoppingStrategy` 38 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 39 | 40 | .. automodule:: ax.early_stopping.strategies.threshold 41 | :members: 42 | :undoc-members: 43 | :show-inheritance: 44 | 45 | Utils 46 | ----- 47 | 48 | .. automodule:: ax.early_stopping.utils 49 | :members: 50 | :undoc-members: 51 | :show-inheritance: 52 | -------------------------------------------------------------------------------- /sphinx/source/exceptions.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | ax.exceptions 5 | =================================== 6 | 7 | .. automodule:: ax.exceptions 8 | .. currentmodule:: ax.exceptions 9 | 10 | 11 | Constants 12 | ~~~~~~~~~~~~~~~~~~ 13 | 14 | .. automodule:: ax.exceptions.constants 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | 19 | 20 | Core 21 | ~~~~~~~~~~~~~~~~~~ 22 | 23 | .. automodule:: ax.exceptions.core 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | Data 29 | ~~~~~~~~~~~~~~~~~~ 30 | 31 | .. automodule:: ax.exceptions.data_provider 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | Generation Strategy 37 | ~~~~~~~~~~~~~~~~~~~~ 38 | 39 | .. automodule:: ax.exceptions.generation_strategy 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | Model 45 | ~~~~~~~~~~~~~~~~~~ 46 | 47 | .. automodule:: ax.exceptions.model 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | Storage 53 | ~~~~~~~~~~~~~~~~~~ 54 | 55 | .. 
automodule:: ax.exceptions.storage 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | -------------------------------------------------------------------------------- /ax/models/tests/test_rembo_initializer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | from ax.models.random.rembo_initializer import REMBOInitializer 9 | from ax.utils.common.testutils import TestCase 10 | 11 | 12 | class REMBOInitializerTest(TestCase): 13 | def testREMBOInitializerModel(self): 14 | A = np.vstack((np.eye(2, 2), -(np.eye(2, 2)))) 15 | # Test setting attributes 16 | m = REMBOInitializer(A=A, bounds_d=[(-2, 2)] * 2) 17 | self.assertTrue(np.allclose(A, m.A)) 18 | self.assertEqual(m.bounds_d, [(-2, 2), (-2, 2)]) 19 | 20 | # Test project up 21 | Z = m.project_up(5 * np.random.rand(3, 2)) 22 | self.assertEqual(Z.shape, (3, 4)) 23 | self.assertTrue(Z.min() >= -1.0) 24 | self.assertTrue(Z.max() <= 1.0) 25 | 26 | # Test gen 27 | Z, w = m.gen(3, bounds=[(-1.0, 1.0)] * 4) 28 | self.assertEqual(Z.shape, (3, 4)) 29 | self.assertTrue(Z.min() >= -1.0) 30 | self.assertTrue(Z.max() <= 1.0) 31 | -------------------------------------------------------------------------------- /ax/plot/css/base.css: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | .hovertext { 9 | opacity: 0.85; 10 | } 11 | 12 | .plot-menu div { 13 | margin-top: 5px; 14 | margin-bottom: 5px; 15 | } 16 | 17 | .plot-menu select { 18 | height: 30px; 19 | padding: 6px 10px; 20 | background-color: #fff; 21 | border: 1px solid #D1D1D1; 22 | border-radius: 4px; 23 | box-shadow: none; 24 | box-sizing: border-box; 25 | margin-left: 15px; 26 | margin-bottom: 0px; 27 | } 28 | 29 | .plot-menu select:focus { 30 | border: 1px solid #33C3F0; 31 | outline: 0; 32 | } 33 | 34 | .plot-menu button { 35 | display: inline-block; 36 | height: 30px; 37 | padding: 0 10px; 38 | color: #555; 39 | text-align: center; 40 | font-size: 12px; 41 | font-weight: 600; 42 | text-transform: uppercase; 43 | text-decoration: none; 44 | white-space: nowrap; 45 | background-color: transparent; 46 | border-radius: 4px; 47 | border: 1px solid #bbb; 48 | cursor: pointer; 49 | box-sizing: border-box; 50 | background-color: #fff; 51 | margin-left: 15px; 52 | } 53 | -------------------------------------------------------------------------------- /ax/utils/report/tests/test_render.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from ax.utils.common.testutils import TestCase 8 | from ax.utils.report.render import ( 9 | h2_html, 10 | h3_html, 11 | link_html, 12 | list_item_html, 13 | p_html, 14 | render_report_elements, 15 | table_cell_html, 16 | table_heading_cell_html, 17 | table_html, 18 | table_row_html, 19 | unordered_list_html, 20 | ) 21 | 22 | 23 | class RenderTest(TestCase): 24 | def testRenderReportElements(self): 25 | elements = [ 26 | p_html("foobar"), 27 | h2_html("foobar"), 28 | h3_html("foobar"), 29 | list_item_html("foobar"), 30 | unordered_list_html(["foo", "bar"]), 31 | link_html("foo", "bar"), 32 | table_cell_html("foobar"), 33 | table_cell_html("foobar", width="100px"), 34 | table_heading_cell_html("foobar"), 35 | table_row_html(["foo", "bar"]), 36 | table_html(["foo", "bar"]), 37 | ] 38 | render_report_elements("test", elements) 39 | -------------------------------------------------------------------------------- /ax/exceptions/storage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | from ax.exceptions.core import AxError 9 | 10 | 11 | class JSONDecodeError(AxError): 12 | """Raised when an error occurs during JSON decoding.""" 13 | 14 | pass 15 | 16 | 17 | class JSONEncodeError(AxError): 18 | """Raised when an error occurs during JSON encoding.""" 19 | 20 | pass 21 | 22 | 23 | class SQADecodeError(AxError): 24 | """Raised when an error occurs during SQA decoding.""" 25 | 26 | pass 27 | 28 | 29 | class SQAEncodeError(AxError): 30 | """Raised when an error occurs during SQA encoding.""" 31 | 32 | pass 33 | 34 | 35 | class ImmutabilityError(AxError): 36 | """Raised when an attempt is made to update an immutable object.""" 37 | 38 | pass 39 | 40 | 41 | class IncorrectDBConfigurationError(AxError): 42 | """Raised when an attempt is made to save and load an object, but 43 | the current engine and session factory are set up incorrectly to 44 | process the call (e.g. current session factory will connect to a 45 | wrong database for the call). 46 | """ 47 | 48 | pass 49 | -------------------------------------------------------------------------------- /ax/core/tests/test_map_metric.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
class SyntheticRunner(Runner):
    """Class for synthetic or dummy runner.

    Currently acts as a shell runner, only creating a name.
    """

    def __init__(self, dummy_metadata: Optional[str] = None) -> None:
        """Store optional metadata that ``run`` attaches to its result.

        Args:
            dummy_metadata: If truthy, returned by ``run`` under the
                ``"dummy_metadata"`` key.
        """
        self.dummy_metadata = dummy_metadata

    def run(self, trial: BaseTrial) -> Dict[str, Any]:
        """Build the run metadata for ``trial`` without deploying anything.

        Args:
            trial: Trial to "run".

        Returns:
            A dict with a ``"name"`` entry (``<experiment name>_<index>`` when
            the experiment has a name, otherwise just the trial index as a
            string) and, if configured, a ``"dummy_metadata"`` entry.
        """
        if trial.experiment.has_name:
            deployed_name = trial.experiment.name + "_" + str(trial.index)
        else:
            deployed_name = str(trial.index)

        run_metadata = {"name": deployed_name}
        # Add dummy metadata if needed for testing
        if self.dummy_metadata:
            run_metadata["dummy_metadata"] = self.dummy_metadata
        return run_metadata

    def poll_trial_status(
        self, trials: Iterable[BaseTrial]
    ) -> Dict[TrialStatus, Set[int]]:
        """Report every given trial as completed.

        Args:
            trials: Trials whose status is being polled.

        Returns:
            A single-entry dict mapping ``TrialStatus.COMPLETED`` to the set
            of indices of all ``trials``.
        """
        completed_indices = {trial.index for trial in trials}
        return {TrialStatus.COMPLETED: completed_indices}
class HelperTest(TestCase):
    """Tests for the plotting helpers ``extend_range`` and
    ``arm_name_to_sort_key``."""

    def test_extend_range(self):
        """An inverted range is rejected; valid ranges are widened."""
        with self.assertRaises(ValueError):
            extend_range(lower=1, upper=-1)

        cases = [
            ({"lower": -1, "upper": 1}, (-1.2, 1.2)),
            ({"lower": -1, "upper": 0, "percent": 30}, (-1.3, 0.3)),
            ({"lower": 0, "upper": 1, "percent": 50}, (-0.5, 1.5)),
        ]
        for kwargs, expected_range in cases:
            self.assertEqual(extend_range(**kwargs), expected_range)

    def test_arm_name_to_sort_key(self):
        """Arm names sort by the key in reverse order as expected."""
        cases = [
            (
                ["0_0", "1_10", "1_2", "10_0", "control"],
                ["control", "0_0", "1_2", "1_10", "10_0"],
            ),
            (
                ["0_0", "0", "1_10", "3_2_x", "3_x", "1_2", "control"],
                ["control", "3_x", "3_2_x", "0", "0_0", "1_2", "1_10"],
            ),
        ]
        for arm_names, expected in cases:
            ordered = sorted(arm_names, key=arm_name_to_sort_key, reverse=True)
            self.assertEqual(ordered, expected)
class DomainType(enum.Enum):
    """Class for enumerating domain types."""

    # Members carry integer values 0..2; the names describe the kind of
    # domain (fixed value, range, or choice).
    FIXED: int = 0
    RANGE: int = 1
    CHOICE: int = 2


class MetricIntent(enum.Enum):
    """Class for enumerating metric use types."""

    # Each member's value is the lower-case string form of its name.
    OBJECTIVE: str = "objective"
    MULTI_OBJECTIVE: str = "multi_objective"
    SCALARIZED_OBJECTIVE: str = "scalarized_objective"
    # Additional objective is not yet supported in Ax open-source.
    ADDITIONAL_OBJECTIVE: str = "additional_objective"
    OUTCOME_CONSTRAINT: str = "outcome_constraint"
    SCALARIZED_OUTCOME_CONSTRAINT: str = "scalarized_outcome_constraint"
    OBJECTIVE_THRESHOLD: str = "objective_threshold"
    TRACKING: str = "tracking"


class ParameterConstraintType(enum.Enum):
    """Class for enumerating parameter constraint types.

    Linear constraint is base type whereas other constraint types are
    special types of linear constraints.
    """

    # Members carry integer values 0..2.
    LINEAR: int = 0
    ORDER: int = 1
    SUM: int = 2
class DiagnosticTest(TestCase):
    """Checks that the cross-validation diagnostic plots can be built."""

    @fast_botorch_optimize
    def test_cross_validation(self):
        """Cross-validate a BoTorch model and build both plot flavours."""
        experiment = get_branin_experiment(with_batch=True)
        experiment.trials[0].run()
        # Model bridge kwargs
        model = Models.BOTORCH(experiment=experiment, data=experiment.fetch_data())
        cv_results = cross_validate(model)

        # Assert that each type of plot can be constructed successfully
        plotly_figure = interact_cross_validation_plotly(cv_results)
        self.assertIsInstance(plotly_figure, go.Figure)
        ax_plot = interact_cross_validation(cv_results)
        self.assertIsInstance(ax_plot, AxPlotConfig)
15 | {% if headfoot %} 16 |
17 |

18 | {% if trial_index is none %} 19 | Report for {{experiment_name}} 20 | {% else %} 21 | Report for {{ batch_noun }} #{{ trial_index }} of 22 | {{experiment_name}} 23 | {% endif %} 24 |

25 |

Powered by Ax

26 |
27 | {% endif %} 28 |
29 | {% block content %} 30 | {% endblock %} 31 |
32 | {% if headfoot %} 33 | 34 | {% endif %} 35 |
class UniformGenerator(RandomModel):
    """This class specifies a uniform random generation algorithm.

    As a uniform generator does not make use of a model, it does not implement
    the fit or predict methods.

    Attributes:
        seed: An optional seed value for the underlying PRNG.

    """

    def __init__(self, deduplicate: bool = False, seed: Optional[int] = None) -> None:
        """Initialize the parent model and a seeded NumPy random state.

        Args:
            deduplicate: Forwarded unchanged to ``RandomModel``.
            seed: Forwarded to ``RandomModel`` and used to seed the
                ``numpy.random.RandomState`` that samples are drawn from.
        """
        super().__init__(deduplicate=deduplicate, seed=seed)
        self._rs = np.random.RandomState(seed=seed)

    def _gen_samples(self, n: int, tunable_d: int) -> np.ndarray:
        """Generate samples from the scipy uniform distribution.

        Args:
            n: Number of samples to generate.
            tunable_d: Dimension of samples to generate.

        Returns:
            samples: An (n x d) array of random points.

        """
        sample_shape = (n, tunable_d)
        # Draw from this instance's own random state.
        return uniform.rvs(size=sample_shape, random_state=self._rs)  # pyre-ignore
def embed_higher_dimension(
    problem: BenchmarkProblem, total_dimensionality: int
) -> BenchmarkProblem:
    """Return a copy of ``problem`` padded with dummy float parameters.

    The new search space contains every parameter of the original problem
    followed by ``embedding_dummy_<i>`` range parameters on ``[0, 1]`` until
    it has ``total_dimensionality`` parameters. Parameter constraints are
    carried over, and the problem name gets a ``_<total_dimensionality>d``
    suffix.

    Args:
        problem: The problem to embed.
        total_dimensionality: Number of parameters the returned problem's
            search space should have.

    Returns:
        A new instance of ``problem``'s class built from its dataclass fields
        with the name and search space replaced.

    Raises:
        ValueError: If ``total_dimensionality`` is smaller than the number of
            parameters ``problem`` already has.
    """
    num_existing_dimensions = len(problem.search_space.parameters)
    num_dummy_dimensions = total_dimensionality - num_existing_dimensions
    if num_dummy_dimensions < 0:
        # Without this check ``range`` would be empty and the result would be
        # labelled with a dimensionality it does not actually have.
        raise ValueError(
            f"total_dimensionality ({total_dimensionality}) must be at least the "
            f"number of existing parameters ({num_existing_dimensions})."
        )

    search_space = SearchSpace(
        parameters=[
            *problem.search_space.parameters.values(),
            *[
                RangeParameter(
                    name=f"embedding_dummy_{i}",
                    parameter_type=ParameterType.FLOAT,
                    lower=0,
                    upper=1,
                )
                for i in range(num_dummy_dimensions)
            ],
        ],
        parameter_constraints=problem.search_space.parameter_constraints,
    )

    problem_kwargs = asdict(problem)
    problem_kwargs["name"] = f"{problem_kwargs['name']}_{total_dimensionality}d"
    problem_kwargs["search_space"] = search_space

    return problem.__class__(**problem_kwargs)
def should_stop_trials_early(
    early_stopping_strategy: Optional[BaseEarlyStoppingStrategy],
    trial_indices: Set[int],
    experiment: Experiment,
) -> Dict[int, Optional[str]]:
    """Evaluate whether to early-stop running trials.

    Args:
        early_stopping_strategy: A ``BaseEarlyStoppingStrategy`` that determines
            whether a trial should be stopped given the state of an experiment.
            When ``None``, no trial is stopped.
        trial_indices: Indices of trials to consider for early stopping.
        experiment: The experiment containing the trials.

    Returns:
        A dictionary mapping trial indices that should be early stopped to
        (optional) messages with the associated reason.
    """
    if early_stopping_strategy is not None:
        # The strategy is known to be present here; delegate the decision.
        return early_stopping_strategy.should_stop_trials_early(
            trial_indices=trial_indices, experiment=experiment
        )
    return {}
def init_notebook_plotting(offline=False):
    """Initialize plotting in notebooks, either in online or offline mode.

    Args:
        offline: Forwarded to ``_js_requires`` when building the JavaScript
            that is injected into the notebook cell.
    """
    requires_js = _wrap_js(_js_requires(offline=offline))
    display({"text/html": requires_js}, raw=True)
    logger.info("Injecting Plotly library into cell. Do not overwrite or delete cell.")
    init_notebook_mode()


def render(plot_config: AxPlotConfig, inject_helpers=False) -> None:
    """Render plot config.

    Args:
        plot_config: The plot to display. Generic plots go straight to
            ``iplot``; HTML plots must already carry a ``"text/html"`` entry
            in their data; every other type is converted to HTML first.
        inject_helpers: Forwarded to ``plot_config_to_html`` for plots that
            need conversion.
    """
    plot_type = plot_config.plot_type

    if plot_type == AxPlotTypes.GENERIC:
        iplot(plot_config.data)
        return

    if plot_type == AxPlotTypes.HTML:
        assert "text/html" in plot_config.data
        display(plot_config.data, raw=True)
        return

    html = plot_config_to_html(plot_config, inject_helpers=inject_helpers)
    display({"text/html": html}, raw=True)
class Base:
    """Base class for core Ax classes. Provides an equality check and `db_id`
    property for SQA storage.
    """

    # Class-level default; the ``db_id`` setter overrides it per instance.
    _db_id: Optional[int] = None

    @property
    def db_id(self) -> Optional[int]:
        """The stored database id, or ``None`` if none has been set."""
        return self._db_id

    @db_id.setter
    def db_id(self, db_id: int) -> None:
        self._db_id = db_id

    @equality_typechecker
    def __eq__(self, other: Base) -> bool:
        """Compare the instance attribute dicts of ``self`` and ``other``."""
        own_attributes = self.__dict__
        other_attributes = other.__dict__
        return object_attribute_dicts_equal(
            one_dict=own_attributes, other_dict=other_attributes
        )


class SortableBase(Base, metaclass=abc.ABCMeta):
    """Extension to the base class that also provides an inequality check."""

    @property
    @abc.abstractmethod
    def _unique_id(self) -> str:
        """Returns an identification string that can be used to uniquely
        identify this instance from others attached to the same experiment.
        """

    def __lt__(self, other: SortableBase) -> bool:
        """Order instances by comparing their ``_unique_id`` strings."""
        own_id = self._unique_id
        return own_id < other._unique_id
class RandomForestTest(TestCase):
    """Fit, predict and cross-validate a small ``RandomForest`` model."""

    def testRFModel(self):
        """Two outcomes, two features: check model count and output shapes."""
        num_outcomes = 2
        datasets = []
        for _ in range(num_outcomes):
            datasets.append(
                FixedNoiseDataset(
                    X=torch.rand(10, 2), Y=torch.rand(10, 1), Yvar=torch.rand(10, 1)
                )
            )

        forest = RandomForest(num_trees=5)
        digest = SearchSpaceDigest(
            feature_names=["x1", "x2"],
            bounds=[(0, 1)] * 2,
        )
        forest.fit(
            datasets=datasets,
            metric_names=["y1", "y2"],
            search_space_digest=digest,
        )
        # One fitted model per outcome, each with the requested tree count.
        self.assertEqual(len(forest.models), 2)
        self.assertEqual(len(forest.models[0].estimators_), 5)

        mean, covariance = forest.predict(torch.rand(5, 2))
        self.assertEqual(mean.shape, torch.Size((5, 2)))
        self.assertEqual(covariance.shape, torch.Size((5, 2, 2)))

        mean, covariance = forest.cross_validate(
            datasets=datasets, X_test=torch.rand(3, 2)
        )
        self.assertEqual(mean.shape, torch.Size((3, 2)))
        self.assertEqual(covariance.shape, torch.Size((3, 2, 2)))
class TestEarlyStoppingUtils(TestCase):
    """Testing the early stopping utilities functionality that is not tested in
    main `AxClient` testing suite (`TestServiceAPI`)."""

    def setUp(self):
        self.branin_experiment = get_branin_experiment()

    def test_should_stop_trials_early(self):
        """The strategy's decision is returned unchanged."""
        expected = {
            1: "Stopped due to testing.",
            3: "Stopped due to testing.",
        }
        actual = early_stopping_utils.should_stop_trials_early(
            early_stopping_strategy=DummyEarlyStoppingStrategy(expected),
            # ``should_stop_trials_early`` declares ``trial_indices`` as a
            # ``Set[int]``, so pass a set rather than a list.
            trial_indices={1, 2, 3},
            experiment=self.branin_experiment,
        )
        self.assertEqual(actual, expected)

    def test_should_stop_trials_early_no_strategy(self):
        """Without a strategy no trial is stopped."""
        actual = early_stopping_utils.should_stop_trials_early(
            early_stopping_strategy=None,
            trial_indices={1, 2, 3},
            experiment=self.branin_experiment,
        )
        expected = {}
        self.assertEqual(actual, expected)
METRIC_STRING = "Metric('m1')"


class MetricTest(TestCase):
    """Tests for construction, equality, cloning and ordering of ``Metric``."""

    def setUp(self):
        pass

    def testInit(self):
        """The string form of a metric shows its class and name."""
        m1 = Metric(name="m1", lower_is_better=False)
        self.assertEqual(str(m1), METRIC_STRING)

    def testEq(self):
        """Metrics are equal only when ``lower_is_better`` also matches."""
        reference = Metric(name="m1", lower_is_better=False)
        same = Metric(name="m1", lower_is_better=False)
        self.assertEqual(reference, same)

        flipped = Metric(name="m1", lower_is_better=True)
        self.assertNotEqual(reference, flipped)

    def testClone(self):
        """A clone compares equal to its source for each metric kind."""
        plain = Metric(name="m1", lower_is_better=False)
        self.assertEqual(plain, plain.clone())

        branin = get_branin_metric(name="branin")
        self.assertEqual(branin, branin.clone())

        factorial = get_factorial_metric(name="factorial")
        self.assertEqual(factorial, factorial.clone())

    def testSortable(self):
        """Metric "m1" sorts before metric "m2"."""
        first = Metric(name="m1", lower_is_better=False)
        second = Metric(name="m2", lower_is_better=False)
        self.assertTrue(first < second)
class SlicesTest(TestCase):
    """Checks that every slice plot variant can be built."""

    @fast_botorch_optimize
    def testSlices(self):
        """Build static and interactive slice plots in both output formats."""
        experiment = get_branin_experiment(with_batch=True)
        experiment.trials[0].run()
        # Model bridge kwargs
        model = Models.BOTORCH(experiment=experiment, data=experiment.fetch_data())

        # Assert that each type of plot can be constructed successfully
        figure = plot_slice_plotly(
            model, model.parameters[0], list(model.metric_names)[0]
        )
        self.assertIsInstance(figure, go.Figure)

        figure = interact_slice_plotly(model)
        self.assertIsInstance(figure, go.Figure)

        config = plot_slice(model, model.parameters[0], list(model.metric_names)[0])
        self.assertIsInstance(config, AxPlotConfig)

        config = interact_slice(model)
        self.assertIsInstance(config, AxPlotConfig)
def get_optimizer_kwargs() -> Dict[str, int]:
    """Return a small fixed set of optimizer settings for tests."""
    settings = {}
    settings["num_restarts"] = 2
    settings["raw_samples"] = 2
    settings["maxiter"] = 2
    settings["batch_limit"] = 1
    return settings


def get_torch_test_data(
    dtype=torch.float,
    cuda: bool = False,
    constant_noise: bool = True,
    task_features=None,
    offset: float = 0.0,
):
    """Build a tiny fixed dataset of torch tensors for tests.

    Args:
        dtype: Dtype of every returned tensor.
        cuda: Place tensors on the ``"cuda"`` device instead of ``"cpu"``.
        constant_noise: If true, every observation variance is 1.0;
            otherwise the variances are ``0.0 + offset`` and ``2.0 + offset``.
        task_features: Returned as given; ``None`` becomes an empty list.
        offset: Value added to every feature, outcome and bound.

    Returns:
        The tuple ``(Xs, Ys, Yvars, bounds, task_features, feature_names,
        metric_names)``, where the first three are single-element lists of
        tensors.
    """
    device = torch.device("cuda" if cuda else "cpu")
    tkwargs = {"device": device, "dtype": dtype}

    base_rows = ((1.0, 2.0, 3.0), (2.0, 3.0, 4.0))
    shifted_rows = [[value + offset for value in row] for row in base_rows]
    Xs = [torch.tensor(shifted_rows, **tkwargs)]
    Ys = [torch.tensor([[3.0 + offset], [4.0 + offset]], **tkwargs)]

    if constant_noise:
        noise_values = [[1.0], [1.0]]
    else:
        noise_values = [[0.0 + offset], [2.0 + offset]]
    Yvars = [torch.tensor(noise_values, **tkwargs)]

    base_bounds = ((0.0, 1.0), (1.0, 4.0), (2.0, 5.0))
    bounds = [(lower + offset, upper + offset) for lower, upper in base_bounds]

    if task_features is None:
        task_features = []
    feature_names = ["x1", "x2", "x3"]
    metric_names = ["y", "r"]
    return Xs, Ys, Yvars, bounds, task_features, feature_names, metric_names
def to_ds(ts: datetime) -> str:
    """Convert a `datetime` to a DS string."""
    return datetime.strftime(ts, DS_FRMT)


def to_ts(ds: str) -> datetime:
    """Convert a DS string to a `datetime`."""
    return datetime.strptime(ds, DS_FRMT)


def _ts_to_pandas(ts: int) -> pd.Timestamp:
    """Convert int timestamp into pandas timestamp."""
    return pd.Timestamp(datetime.fromtimestamp(ts))


def _pandas_ts_to_int(ts: pd.Timestamp) -> int:
    """Convert pandas timestamp into a POSIX timestamp (inverse of
    `_ts_to_pandas`).

    NOTE: the value comes from `datetime.timestamp()`, which is a `float`;
    it is returned as-is rather than truncated.
    """
    # pyre-fixme[7]: Expected `int` but got `float`.
    return ts.to_pydatetime().timestamp()


def current_timestamp_in_millis() -> int:
    """Grab current timestamp in milliseconds as an int."""
    return int(round(time() * 1000))


def timestamps_in_range(
    start: datetime, end: datetime, delta: timedelta
) -> Generator[datetime, None, None]:
    """Generator of timestamps in range [start, end], at intervals
    delta.

    Args:
        start: First timestamp to yield.
        end: Inclusive upper bound; nothing is yielded if ``start > end``.
        delta: Step between consecutive timestamps; must be positive.

    Raises:
        ValueError: On first iteration, if ``delta`` is zero or negative
            (the loop would otherwise never reach ``end``).
    """
    if delta <= timedelta(0):
        raise ValueError(f"`delta` must be a positive timedelta, got {delta!r}.")
    curr = start
    while curr <= end:
        yield curr
        curr += delta
class TypesTest(TestCase):
    """Tests for ``merge_model_predict``."""

    def setUp(self):
        self.num_arms = 2
        means = {"m1": [0.0, 0.5], "m2": [0.1, 0.6]}
        covariances = {
            row: {"m1": [0.0, 0.0], "m2": [0.0, 0.0]} for row in ("m1", "m2")
        }
        self.predict = (means, covariances)

    def testMergeModelPredict(self):
        """Appending one prediction per metric grows each list to three."""
        mu_append = {"m1": [0.6], "m2": [0.7]}
        cov_append = {row: {"m1": [0.0], "m2": [0.0]} for row in ("m1", "m2")}
        merged = merge_model_predict(self.predict, (mu_append, cov_append))
        merged_means = merged[0]
        self.assertEqual(len(merged_means["m1"]), 3)

    def testMergeModelPredictFail(self):
        """A metric missing from either the means or covariances is rejected."""
        full_mu = {"m1": [0.6], "m2": [0.7]}
        full_cov = {row: {"m1": [0.0], "m2": [0.0]} for row in ("m1", "m2")}

        # Means lack "m2".
        partial_mu = {"m1": [0.6]}
        with self.assertRaises(ValueError):
            merge_model_predict(self.predict, (partial_mu, full_cov))

        # Covariances lack the "m2" row.
        partial_cov = {"m1": {"m1": [0.0], "m2": [0.0]}}
        with self.assertRaises(ValueError):
            merge_model_predict(self.predict, (full_mu, partial_cov))
class TracesTest(TestCase):
    """Checks that optimization trace plots can be built."""

    @fast_botorch_optimize
    def testTraces(self):
        """Build the trace plot in both Plotly and Ax config form."""
        experiment = get_branin_experiment(with_batch=True)
        experiment.trials[0].run()
        # Model bridge kwargs
        model = Models.BOTORCH(experiment=experiment, data=experiment.fetch_data())

        # Assert that each type of plot can be constructed successfully
        figure = optimization_trace_single_method_plotly(
            np.array([[1, 2, 3], [4, 5, 6]]),
            list(model.metric_names)[0],
            optimization_direction="minimize",
        )
        self.assertIsInstance(figure, go.Figure)

        config = optimization_trace_single_method(
            np.array([[1, 2, 3], [4, 5, 6]]),
            list(model.metric_names)[0],
            optimization_direction="minimize",
        )
        self.assertIsInstance(config, AxPlotConfig)
def patch_config(
    config_file: str, base_url: str = None, disable_algolia: bool = True
) -> None:
    """Rewrite settings in a Docusaurus ``siteConfig.js`` file in place.

    Args:
        config_file: Path of the configuration file to read and overwrite.
        base_url: If not ``None``, every ``baseUrl = '/';`` is replaced by
            ``baseUrl = '<base_url>';``. The value is inserted literally.
        disable_algolia: If ``True``, ``const includeAlgolia = true;`` is
            replaced by ``const includeAlgolia = false;``.
    """
    # Use a context manager so the read handle is closed before the file is
    # reopened for writing.
    with open(config_file, "r") as infile:
        config = infile.read()

    # Both search strings are plain text, so ``str.replace`` matches exactly
    # what the previous regex matched, without interpreting backslashes in
    # ``base_url`` as regex replacement escapes.
    if base_url is not None:
        config = config.replace(
            "baseUrl = '/';", "baseUrl = '{}';".format(base_url)
        )
    if disable_algolia is True:
        config = config.replace(
            "const includeAlgolia = true;", "const includeAlgolia = false;"
        )

    with open(config_file, "w") as outfile:
        outfile.write(config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Patch Docusaurus siteConfig.js file when building site."
    )
    parser.add_argument(
        "-f",
        "--config_file",
        metavar="path",
        required=True,
        help="Path to configuration file.",
    )
    parser.add_argument(
        "-b",
        "--base_url",
        type=str,
        required=False,
        help="Value for baseUrl.",
        default=None,
    )
    parser.add_argument(
        "--disable_algolia",
        required=False,
        action="store_true",
        help="Disable algolia.",
    )
    args = parser.parse_args()
    patch_config(args.config_file, args.base_url, args.disable_algolia)
from datetime import datetime

import pandas as pd
from ax.utils.common.equality import (
    dataframe_equals,
    datetime_equals,
    equality_typechecker,
    same_elements,
)
from ax.utils.common.testutils import TestCase


class EqualityTest(TestCase):
    """Tests for the equality helpers in ``ax.utils.common.equality``."""

    def testEqualityTypechecker(self):
        # Wrapped comparison must reject operands of different types even
        # when plain `==` would consider them equal (5 == 5.0).
        @equality_typechecker
        def typed_eq(left, right):
            return left == right

        self.assertFalse(typed_eq(5, 5.0))
        self.assertTrue(typed_eq(5, 5))

    def testListsEquals(self):
        cases = (
            ([0], [0, 1], False),
            ([1, 0], [0, 2], False),
            ([1, 0], [0, 1], True),
        )
        for first, second, expected in cases:
            check = self.assertTrue if expected else self.assertFalse
            check(same_elements(first, second))

    def testDatetimeEquals(self):
        timestamp = datetime.now()
        self.assertTrue(datetime_equals(None, None))
        self.assertFalse(datetime_equals(None, timestamp))
        self.assertTrue(datetime_equals(timestamp, timestamp))

    def testDataframeEquals(self):
        # Same record with keys in a different order vs. a differing value.
        base_df = pd.DataFrame.from_records([{"x": 100, "y": 200}])
        reordered_df = pd.DataFrame.from_records([{"y": 200, "x": 100}])
        different_df = pd.DataFrame.from_records([{"x": 100, "y": 300}])

        self.assertTrue(dataframe_equals(pd.DataFrame(), pd.DataFrame()))
        self.assertTrue(dataframe_equals(base_df, reordered_df))
        self.assertFalse(dataframe_equals(base_df, different_df))
// helper functions used across multiple plots

/** Build a CSS `rgb(...)` string from an array of channel values. */
function rgb(rgb_array) {
  return 'rgb(' + rgb_array.join() + ')';
}

/** Return a reversed shallow copy of `arr`; `arr` itself is not modified. */
function copy_and_reverse(arr) {
  const copy = arr.slice();
  copy.reverse();
  return copy;
}

/**
 * Return `[min, max]` of `grid`; both ends are passed through `Math.log10`
 * when `is_log` is truthy.
 */
function axis_range(grid, is_log) {
  return is_log ?
    [Math.log10(Math.min(...grid)), Math.log10(Math.max(...grid))]:
    [Math.min(...grid), Math.max(...grid)];
}

/**
 * Optionally express means `f` and standard errors `sd` relative to the
 * status quo arm, in percent.
 *
 * When `rel` is exactly `true`, the status quo mean and standard error for
 * `metric` are read from `arm_data['in_sample'][arm_data['status_quo_name']]`
 * and every `(f[i], sd[i])` pair goes through `relativize`, scaled by 100.
 * Otherwise `f` and `sd` are returned as passed in.
 *
 * Returns a two-element array `[f_final, sd_final]`.
 */
function relativize_data(f, sd, rel, arm_data, metric) {
  // if relative, extract status quo & compute ratio
  const f_final = rel === true ? [] : f;
  const sd_final = rel === true ? []: sd;

  if (rel === true) {
    const f_sq = (
      arm_data['in_sample'][arm_data['status_quo_name']]['y'][metric]
    );
    const sd_sq = (
      arm_data['in_sample'][arm_data['status_quo_name']]['se'][metric]
    );

    for (let i = 0; i < f.length; i++) {
      // Declared with `const` so the result stays local to the loop body
      // instead of becoming an implicit global.
      const res = relativize(f[i], sd[i], f_sq, sd_sq);
      f_final.push(100 * res[0]);
      sd_final.push(100 * res[1]);
    }
  }

  return [f_final, sd_final];
}

/**
 * Relative effect of a test mean `m_t` (standard error `sem_t`) versus a
 * control mean `m_c` (standard error `sem_c`).
 *
 * Returns `[r_hat, sd]`, where `r_hat` is `(m_t - m_c) / |m_c|` minus the
 * correction term `sem_c^2 * m_t / |m_c|^3`, and `sd` is the square root of
 * `(sem_t^2 + (m_t / m_c * sem_c)^2) / m_c^2`.
 */
function relativize(m_t, sem_t, m_c, sem_c) {
  // Locals are declared explicitly so they do not leak onto the global
  // object (and so the function also works in strict mode).
  const r_hat = (
    (m_t - m_c) / Math.abs(m_c) -
    Math.pow(sem_c, 2) * m_t / Math.pow(Math.abs(m_c), 3)
  );
  const variance = (
    (Math.pow(sem_t, 2) + Math.pow((m_t / m_c * sem_c), 2)) /
    Math.pow(m_c, 2)
  );
  return [r_hat, Math.sqrt(variance)];
}
import typing  # noqa F401, this is to enable type-checking

from ax.exceptions.core import AxError, OptimizationComplete


class MaxParallelismReachedException(AxError):
    """Raised when a generation step already has its maximum number of trials
    running in parallel (as set via `GenerationStep.max_parallelism`).
    Callers should wait until more trials are completed with data before
    generating new trials.
    """

    def __init__(self, step_index: int, model_name: str, num_running: int) -> None:
        message = (
            f"Maximum parallelism for generation step #{step_index} ({model_name})"
            f" has been reached: {num_running} trials are currently 'running'. Some "
            "trials need to be completed before more trials can be generated. See "
            "https://ax.dev/docs/bayesopt.html to understand why limited parallelism "
            "improves performance of Bayesian optimization."
        )
        super().__init__(message)


class GenerationStrategyCompleted(OptimizationComplete):
    """Signals that the generation strategy has been completed."""


class GenerationStrategyRepeatedPoints(GenerationStrategyCompleted):
    """Signals that the generation strategy keeps suggesting points that were
    already sampled.
    """
from typing import Any, Iterable


class DataProviderError(Exception):
    """Base Exception for Ax DataProviders.

    The type of the data provider must be included.
    The raw error is stored in the data_provider_error section,
    and an Ax-friendly message is stored as the actual error message.

    Args:
        message: Ax-friendly description of the failure.
        data_provider: Name of the data provider that raised the error.
        data_provider_error: The raw error reported by the data provider.
    """

    def __init__(
        self, message: str, data_provider: str, data_provider_error: Any
    ) -> None:
        # Forward the arguments to `Exception` so that `args` is populated;
        # pickling and copying rebuild the exception by calling the class
        # with `args`, which fails when `args` is left empty.
        super().__init__(message, data_provider, data_provider_error)
        self.message = message
        self.data_provider = data_provider
        self.data_provider_error = data_provider_error

    def __str__(self) -> str:
        return (
            "{message}. \n Error thrown by: {dp} data provider \n"
            + "Native {dp} data provider error: {dp_error}"
        ).format(
            dp=self.data_provider,
            message=self.message,
            dp_error=self.data_provider_error,
        )


class MissingDataError(Exception):
    """Raised when no data could be found for one or more trials.

    Args:
        missing_trial_indexes: Indexes of the trials without data; they are
            listed in the error message.
    """

    def __init__(self, missing_trial_indexes: Iterable[int]) -> None:
        # Materialize once: the iterable may be a one-shot generator, and the
        # list is also stored in `args` so the exception can be re-created
        # (e.g. when unpickled).
        indexes = list(missing_trial_indexes)
        super().__init__(indexes)
        missing_trial_str = ", ".join(str(index) for index in indexes)
        self.message: str = (
            f"Unable to find data for the following trials: {missing_trial_str} "
            "consider updating the data fetching kwargs or manually fetching "
            "data via `refetch_data()`"
        )

    def __str__(self) -> str:
        return self.message
import plotly.graph_objects as go
from ax.modelbridge.registry import Models
from ax.plot.base import AxPlotConfig
from ax.plot.scatter import interact_fitted, interact_fitted_plotly
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_experiment
from ax.utils.testing.mock import fast_botorch_optimize


class FittedScatterTest(TestCase):
    """Checks that the fitted-scatter plots build and label their points."""

    @fast_botorch_optimize
    def test_fitted_scatter(self):
        exp = get_branin_experiment(with_str_choice_param=True, with_batch=True)
        exp.trials[0].run()
        model = Models.BOTORCH(
            # Model bridge kwargs
            experiment=exp,
            data=exp.fetch_data(),
        )
        # Assert that each type of plot can be constructed successfully
        plot = interact_fitted_plotly(model=model, rel=False)
        self.assertIsInstance(plot, go.Figure)
        plot = interact_fitted(model=model, rel=False)
        self.assertIsInstance(plot, AxPlotConfig)

        # Make sure all parameters and metrics are displayed in tooltips
        tooltips = list(exp.parameters.keys()) + list(exp.metrics.keys())
        for d in plot.data["data"]:
            # Only check scatter plots hoverovers
            if d["type"] != "scatter":
                continue
            for text in d["text"]:
                for tt in tooltips:
                    # `assertIn` names the missing tooltip and the text it was
                    # expected in when the check fails.
                    self.assertIn(tt, text)
from typing import Any, List

from ax.core.base_trial import BaseTrial
from ax.core.map_data import MapData
from ax.metrics.noisy_function_map import NoisyFunctionMapMetric


class BackendSimulatorTimestampMapMetric(NoisyFunctionMapMetric):
    """Map metric that reads trial timing from an underlying
    ``BackendSimulator`` and returns data keyed by timestamp."""

    def fetch_trial_data(
        self, trial: BaseTrial, noisy: bool = True, **kwargs: Any
    ) -> MapData:
        """Fetch data for a single trial.

        The trial's simulated start time and end time are turned into
        timestamps via ``convert_to_timestamps`` and handed to
        ``NoisyFunctionMapMetric.fetch_trial_data`` under the ``timestamp``
        map key.
        """
        simulator = trial.experiment.runner.simulator  # pyre-ignore[16]
        simulated_trial = simulator.get_sim_trial_by_index(trial.index)
        # Without a recorded completion time, fall back to the simulator's
        # current time as the end of the observation window.
        if simulated_trial.sim_completed_time is None:
            end_time = simulator.time
        else:
            end_time = simulated_trial.sim_completed_time
        timestamps = self.convert_to_timestamps(
            start_time=simulated_trial.sim_start_time, end_time=end_time
        )
        return NoisyFunctionMapMetric.fetch_trial_data(
            self,
            trial=trial,
            noisy=noisy,
            **kwargs,
            map_keys=["timestamp"],
            timestamp=timestamps,
        )

    def convert_to_timestamps(self, start_time: float, end_time: float) -> List[float]:
        """Return the intermediate timestamps, between a starting and a
        current time, at which observations exist. Subclasses must override
        this; the base implementation raises ``NotImplementedError``."""
        raise NotImplementedError
"Generation Strategy" 24 | }, 25 | { 26 | "id": "scheduler", 27 | "title": "Scheduler" 28 | }, 29 | { 30 | "id": "modular_botax", 31 | "title": "Modular `BoTorchModel`" 32 | } 33 | ], 34 | "Bayesian Optimization": [ 35 | { 36 | "id": "tune_cnn", 37 | "title": "Hyperparameter Optimization for PyTorch" 38 | }, 39 | { 40 | "id": "raytune_pytorch_cnn", 41 | "title": "Hyperparameter Optimization via Raytune" 42 | }, 43 | { 44 | "id": "multi_task", 45 | "title": "Multi-Task Modeling" 46 | }, 47 | { 48 | "id": "multiobjective_optimization", 49 | "title": "Multi-Objective Optimization" 50 | }, 51 | { 52 | "id": "saasbo", 53 | "title": "High-Dimensional Bayesian Optimization with Sparse Axis-Aligned Subspaces (SAASBO)" 54 | }, 55 | { 56 | "id": "saasbo_nehvi", 57 | "title": "Fully Bayesian, High-Dimensional, Multi-Objective Optimization" 58 | } 59 | ], 60 | "Field Experiments": [ 61 | { 62 | "id": "factorial", 63 | "title": "Bandit Optimization" 64 | }, 65 | { 66 | "dir": "human_in_the_loop", 67 | "id": "human_in_the_loop", 68 | "title": "Human-in-the-Loop Optimization" 69 | } 70 | ] 71 | } 72 | -------------------------------------------------------------------------------- /ax/storage/sqa_store/reduced_state.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
from typing import List

from ax.storage.sqa_store.sqa_classes import SQAGeneratorRun
from sqlalchemy.orm import defaultload, lazyload, strategy_options
from sqlalchemy.orm.attributes import InstrumentedAttribute


GR_LARGE_MODEL_ATTRS: List[InstrumentedAttribute] = [  # pyre-ignore[9]
    SQAGeneratorRun.model_kwargs,
    SQAGeneratorRun.bridge_kwargs,
    SQAGeneratorRun.model_state_after_gen,
    SQAGeneratorRun.gen_metadata,
]


GR_PARAMS_METRICS_COLS = [
    "parameters",
    "parameter_constraints",
    "metrics",
]


def get_query_options_to_defer_immutable_duplicates() -> List[strategy_options.Load]:
    """Build query options that lazy-load the generator-run attributes which
    are duplicated on each trial (search space attributes and metrics).

    Experiments with an immutable search space and optimization configuration
    do not need these attributes loaded.
    """
    deferred_options = []
    for column_name in GR_PARAMS_METRICS_COLS:
        deferred_options.append(lazyload(f"generator_runs.{column_name}"))
    return deferred_options


def get_query_options_to_defer_large_model_cols() -> List[strategy_options.Load]:
    """Build query options that defer the model-state-related columns of
    generator runs.

    These columns can be large and are not needed on every generator run when
    an experiment and generation strategy are loaded in reduced state.
    """
    large_attr_keys = (attr.key for attr in GR_LARGE_MODEL_ATTRS)
    return [defaultload("generator_runs").defer(key) for key in large_attr_keys]
6 | 7 | from ax.core import ( 8 | Arm, 9 | BatchTrial, 10 | ChoiceParameter, 11 | ComparisonOp, 12 | Data, 13 | Experiment, 14 | FixedParameter, 15 | GeneratorRun, 16 | Metric, 17 | MultiObjective, 18 | MultiObjectiveOptimizationConfig, 19 | Objective, 20 | ObjectiveThreshold, 21 | OptimizationConfig, 22 | OrderConstraint, 23 | OutcomeConstraint, 24 | Parameter, 25 | ParameterConstraint, 26 | ParameterType, 27 | RangeParameter, 28 | Runner, 29 | SearchSpace, 30 | SumConstraint, 31 | Trial, 32 | ) 33 | from ax.modelbridge import Models 34 | from ax.service import OptimizationLoop, optimize 35 | from ax.storage import json_load, json_save 36 | 37 | 38 | try: 39 | # pyre-fixme[21]: Could not find a module... to import `ax.version`. 40 | from ax.version import version as __version__ 41 | except Exception: 42 | __version__ = "Unknown" 43 | 44 | __all__ = [ 45 | "Arm", 46 | "BatchTrial", 47 | "ChoiceParameter", 48 | "ComparisonOp", 49 | "Data", 50 | "Experiment", 51 | "FixedParameter", 52 | "GeneratorRun", 53 | "Metric", 54 | "Models", 55 | "MultiObjective", 56 | "MultiObjectiveOptimizationConfig", 57 | "Objective", 58 | "ObjectiveThreshold", 59 | "OptimizationConfig", 60 | "OptimizationLoop", 61 | "OrderConstraint", 62 | "OutcomeConstraint", 63 | "Parameter", 64 | "ParameterConstraint", 65 | "ParameterType", 66 | "RangeParameter", 67 | "Runner", 68 | "SearchSpace", 69 | "SumConstraint", 70 | "Trial", 71 | "optimize", 72 | "json_save", 73 | "json_load", 74 | ] 75 | -------------------------------------------------------------------------------- /ax/benchmark/tests/test_benchmark_method.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.exceptions.core import UserInputError
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.service.scheduler import SchedulerOptions
from ax.utils.common.testutils import TestCase


class TestBenchmarkMethod(TestCase):
    """Tests for constructing a ``BenchmarkMethod``."""

    def test_benchmark_method(self):
        sobol_step = GenerationStep(model=Models.SOBOL, num_trials=10)
        strategy = GenerationStrategy(steps=[sobol_step], name="SOBOL")
        scheduler_options = SchedulerOptions(total_trials=10)

        method = BenchmarkMethod(
            name="Sobol10",
            generation_strategy=strategy,
            scheduler_options=scheduler_options,
        )

        # The method must hand back exactly what it was constructed with.
        self.assertEqual(method.generation_strategy, strategy)
        self.assertEqual(method.scheduler_options, scheduler_options)

    def test_total_trials_none(self):
        sobol_step = GenerationStep(model=Models.SOBOL, num_trials=10)
        strategy = GenerationStrategy(steps=[sobol_step], name="SOBOL")
        # Default options leave `total_trials` unset.
        scheduler_options = SchedulerOptions()
        method_kwargs = {
            "name": "Sobol10",
            "generation_strategy": strategy,
            "scheduler_options": scheduler_options,
        }

        with self.assertRaisesRegex(
            UserInputError, "SchedulerOptions.total_trials may not be None"
        ):
            BenchmarkMethod(**method_kwargs)
automodule:: ax.service.managed_loop 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | 27 | 28 | Scheduler 29 | ~~~~~~~~~~ 30 | 31 | .. automodule:: ax.service.scheduler 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | .. automodule:: ax.service.utils.scheduler_options 37 | :members: 38 | :undoc-members: 39 | :show-inheritance: 40 | 41 | Utils 42 | ------ 43 | 44 | Best Point Identification 45 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 46 | 47 | .. automodule:: ax.service.utils.best_point_mixin 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | 53 | .. automodule:: ax.service.utils.best_point 54 | :members: 55 | :undoc-members: 56 | :show-inheritance: 57 | 58 | 59 | Instantiation 60 | ~~~~~~~~~~~~~~ 61 | 62 | .. automodule:: ax.service.utils.instantiation 63 | :members: 64 | :undoc-members: 65 | :show-inheritance: 66 | 67 | 68 | Reporting 69 | ~~~~~~~~~~~~~~ 70 | 71 | .. automodule:: ax.service.utils.report_utils 72 | :members: 73 | :undoc-members: 74 | :show-inheritance: 75 | 76 | 77 | WithDBSettingsBase 78 | ~~~~~~~~~~~~~~~~~~ 79 | 80 | .. automodule:: ax.service.utils.with_db_settings_base 81 | :members: 82 | :undoc-members: 83 | :show-inheritance: 84 | 85 | 86 | EarlyStopping 87 | ~~~~~~~~~~~~~ 88 | 89 | .. automodule:: ax.service.utils.early_stopping 90 | :members: 91 | :undoc-members: 92 | :show-inheritance: 93 | -------------------------------------------------------------------------------- /ax/benchmark/methods/saasbo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.service.scheduler import SchedulerOptions


def _sobol_then_model_method(name: str, model: Models) -> BenchmarkMethod:
    """Build a 30-trial benchmark method: 5 Sobol trials, then ``model``."""
    sobol_step = GenerationStep(
        model=Models.SOBOL, num_trials=5, min_trials_observed=3
    )
    model_step = GenerationStep(
        model=model,
        num_trials=-1,
        max_parallelism=1,
    )
    generation_strategy = GenerationStrategy(
        name=name, steps=[sobol_step, model_step]
    )

    scheduler_options = SchedulerOptions(total_trials=30)

    return BenchmarkMethod(
        name=generation_strategy.name,
        generation_strategy=generation_strategy,
        scheduler_options=scheduler_options,
    )


def get_saasbo_default() -> BenchmarkMethod:
    """Return the default benchmark method using ``Models.FULLYBAYESIAN``
    after a Sobol initialization step."""
    return _sobol_then_model_method(
        name="SOBOL+FULLYBAYESIAN::default", model=Models.FULLYBAYESIAN
    )


def get_saasbo_moo_default() -> BenchmarkMethod:
    """Return the default benchmark method using ``Models.FULLYBAYESIANMOO``
    after a Sobol initialization step."""
    return _sobol_then_model_method(
        name="SOBOL+FULLYBAYESIANMOO::default", model=Models.FULLYBAYESIANMOO
    )
from typing import Any, Callable, TypeVar


_T = TypeVar("_T")


# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
# pyre-ignore[34]: T77127616
def copy_doc(src: Callable[..., Any]) -> Callable[[_T], _T]:
    """A decorator that copies the docstring of another object.

    Since ``sphinx`` actually loads the python modules to grab the docstrings
    this works with both ``sphinx`` and the ``help`` function.

    .. code:: python

      class Cat(Mammal):

        @property
        @copy_doc(Mammal.is_feline)
        def is_feline(self) -> bool:
          ...

    Args:
        src: Object whose ``__doc__`` is copied.

    Returns:
        A decorator that sets ``__doc__`` on the decorated object and returns
        that same object.

    Raises:
        ValueError: If ``src`` has no docstring, or (when the decorator is
            applied) if the decorated object already has one.
    """
    # It would be tempting to try to get the doc through the class the method
    # is bound to (via __self__) but decorators are called before __self__ is
    # assigned.
    # One other solution would be to use a decorator on classes that would fill
    # all the missing docstrings but we want to be able to detect syntactically
    # when docstrings are copied to keep things nice and simple

    if src.__doc__ is None:
        # Not every object exposes `__qualname__`; fall back to `repr` so the
        # intended ValueError is raised instead of an AttributeError.
        src_name = getattr(src, "__qualname__", repr(src))
        raise ValueError(f"{src_name} has no docstring to copy")

    def decorator(dst: _T) -> _T:
        if dst.__doc__ is not None:
            dst_name = getattr(dst, "__qualname__", repr(dst))
            raise ValueError(f"{dst_name} already has a docstring")
        dst.__doc__ = src.__doc__
        return dst

    return decorator
import numpy as np
from ax.models.discrete_base import DiscreteModel
from ax.utils.common.testutils import TestCase


class DiscreteModelTest(TestCase):
    """Exercises the default behaviour of the ``DiscreteModel`` base class."""

    def setUp(self):
        # Intentionally a no-op override.
        return None

    def test_discrete_model_get_state(self):
        base_model = DiscreteModel()
        self.assertEqual(base_model._get_state(), {})

    def test_discrete_model_feature_importances(self):
        base_model = DiscreteModel()
        self.assertRaises(NotImplementedError, base_model.feature_importances)

    def testDiscreteModelFit(self):
        # `fit` on the base class is expected to accept the data without error.
        fit_kwargs = {
            "Xs": [[[0]]],
            "Ys": [[0]],
            "Yvars": [[1]],
            "parameter_values": [[0, 1]],
            "outcome_names": [],
        }
        DiscreteModel().fit(**fit_kwargs)

    def testdiscreteModelPredict(self):
        base_model = DiscreteModel()
        self.assertRaises(NotImplementedError, base_model.predict, [[0]])

    def testdiscreteModelGen(self):
        base_model = DiscreteModel()
        with self.assertRaises(NotImplementedError):
            base_model.gen(
                n=1,
                parameter_values=[[0, 1]],
                objective_weights=np.array([1]),
            )

    def testdiscreteModelCrossValidate(self):
        base_model = DiscreteModel()
        with self.assertRaises(NotImplementedError):
            base_model.cross_validate(
                Xs_train=[[[0]]],
                Ys_train=[[1]],
                Yvars_train=[[1]],
                X_test=[[1]],
            )
import logging
from typing import List, Tuple

import numpy as np
from ax.models.discrete.thompson import ThompsonSampler
from ax.utils.common.logger import get_logger
from ax.utils.stats.statstools import positive_part_james_stein


logger: logging.Logger = get_logger(__name__)


class EmpiricalBayesThompsonSampler(ThompsonSampler):
    """Generator for Thompson sampling using Empirical Bayes estimates.

    The generator applies positive-part James-Stein Estimator to the data
    passed in via `fit` and then performs Thompson Sampling.
    """

    def _fit_Ys_and_Yvars(
        self, Ys: List[List[float]], Yvars: List[List[float]], outcome_names: List[str]
    ) -> Tuple[List[List[float]], List[List[float]]]:
        """Apply shrinkage to each outcome's means and variances.

        Args:
            Ys: One list of means per outcome.
            Yvars: One list of variances per outcome, aligned with ``Ys``.
            outcome_names: Not used by this implementation.

        Returns:
            The shrunk means and the shrunk variances, in the same order as
            the inputs.
        """
        newYs = []
        newYvars = []
        for i, (Y, Yvar) in enumerate(zip(Ys, Yvars)):
            newY, newYvar = self._apply_shrinkage(Y, Yvar, i)
            newYs.append(newY)
            newYvars.append(newYvar)
        return newYs, newYvars

    def _apply_shrinkage(
        self, Y: List[float], Yvar: List[float], outcome: int
    ) -> Tuple[List[float], List[float]]:
        """Shrink one outcome's estimates with ``positive_part_james_stein``.

        Args:
            Y: Means for the outcome.
            Yvar: Variances for the outcome.
            outcome: Index of the outcome, used only in the warning message.

        Returns:
            The resulting means and variances as plain lists. If the
            estimator raises ``ValueError``, a warning is logged and the
            unshrunk estimates are returned instead.
        """
        npY = np.array(Y)
        # The estimator works on standard errors, so convert from variances.
        npYsem = np.sqrt(np.array(Yvar))
        try:
            npY, npYsem = positive_part_james_stein(means=npY, sems=npYsem)
        except ValueError as e:
            logger.warning(
                str(e) + f" Raw (unshrunk) estimates used for outcome: {outcome}"
            )
        # Convert the standard errors back to variances for the caller.
        return npY.tolist(), (npYsem**2).tolist()
from typing import Any, Optional

import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.data import Data
from ax.core.metric import Metric


class BotorchTestProblemMetric(Metric):
    """A Metric for retrieving information from a BotorchTestProblemRunner.
    A BotorchTestProblemRunner will attach the result of a call to
    BaseTestProblem.forward per Arm on a given trial, and this Metric will extract the
    proper value from the resulting tensor given its index.
    """

    def __init__(self, name: str, noise_sd: float, index: Optional[int] = None) -> None:
        """Store the metric name, the noise level reported as ``sem``, and the
        optional index used to pick one value out of each arm's result."""
        super().__init__(name=name)
        self.noise_sd = noise_sd
        self.index = index

    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> Data:
        """Build one data row per arm from ``trial.run_metadata["Ys"]``.

        When ``self.index`` is set, each arm's entry is indexed with it;
        otherwise the entry is used as is.
        """
        arm_names = list(trial.arms_by_name)
        # run_metadata["Ys"] can be either a list of results or a single float
        if self.index is not None:
            mean = [trial.run_metadata["Ys"][name][self.index] for name in arm_names]
        else:
            mean = [trial.run_metadata["Ys"][name] for name in arm_names]
        df = pd.DataFrame(
            {
                "arm_name": arm_names,
                "metric_name": self.name,
                "mean": mean,
                # If no noise_std is returned then Botorch evaluated the true function
                "sem": self.noise_sd,
                "trial_index": trial.index,
            }
        )

        return Data(df=df)
import importlib
import pathlib
import sys
import unittest

import __test_modules__
from ax.utils.common import testutils


def get_all_subclasses(cls):
    """Recursively get all the subclasses of cls.

    Yields each direct subclass, immediately followed by that subclass's own
    subclasses.
    """
    for x in cls.__subclasses__():  # subclasses only contains direct descendants
        yield x
        yield from get_all_subclasses(x)


class TestUnittestConventions(testutils.TestCase):
    def test_uses_ae_unittest(self):
        """Check that all of our tests are inheriting from our own base class

        Our base class does a bit more (like making sure we don't use any of python's
        deprecated `assert` functions) so we want to enforce its usage everywhere.
        """
        test_modules = set(__test_modules__.TEST_MODULES)
        # Make sure everything is loaded
        for m in test_modules:
            importlib.import_module(m)
        # Keep only the `unittest.TestCase` subclasses defined in our test modules.
        test_cases = [
            cls
            for cls in get_all_subclasses(unittest.TestCase)
            if cls.__module__ in test_modules
        ]
        base = testutils.TestCase
        for t in test_cases:
            # One sub-test per class so every offender is reported.
            with self.subTest(t.__name__):
                if not issubclass(t, base):
                    # Report the file relative to the directory that holds
                    # `__test_modules__`.
                    abs_path = pathlib.Path(sys.modules[t.__module__].__file__)
                    root = pathlib.Path(__test_modules__.__file__).parent
                    filename = abs_path.relative_to(root)
                    self.fail(
                        f"in {filename}: {t.__qualname__} should inherit from "
                        f"{base.__module__}.{base.__name__}"
                    )
from unittest import mock

from ax.core.base_trial import BaseTrial
from ax.core.runner import Runner
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_batch_trial, get_trial


class DummyRunner(Runner):
    """Minimal runner whose run metadata only echoes the trial index."""

    def run(self, trial: BaseTrial):
        metadatum = f"value_for_trial_{trial.index}"
        return {"metadatum": metadatum}


class RunnerTest(TestCase):
    """Tests the default behaviour inherited from the ``Runner`` base class."""

    def setUp(self):
        self.dummy_runner = DummyRunner()
        self.trials = [get_trial(), get_batch_trial()]

    def test_base_runner_staging_required(self):
        self.assertFalse(self.dummy_runner.staging_required)

    def test_base_runner_stop(self):
        self.assertRaises(
            NotImplementedError,
            self.dummy_runner.stop,
            trial=mock.Mock(),
            reason="",
        )

    def test_base_runner_clone(self):
        cloned = self.dummy_runner.clone()
        self.assertIsInstance(cloned, DummyRunner)
        self.assertEqual(cloned, self.dummy_runner)

    def test_base_runner_run_multiple(self):
        # `run_multiple` is expected to key each trial's metadata by its index.
        expected = {}
        for trial in self.trials:
            expected[trial.index] = {"metadatum": f"value_for_trial_{trial.index}"}
        metadata = self.dummy_runner.run_multiple(trials=self.trials)
        self.assertEqual(metadata, expected)
        self.assertEqual({}, self.dummy_runner.run_multiple(trials=[]))

    def test_base_runner_poll_trial_status(self):
        self.assertRaises(
            NotImplementedError,
            self.dummy_runner.poll_trial_status,
            trials=self.trials,
        )

    def test_poll_available_capacity(self):
        capacity = self.dummy_runner.poll_available_capacity()
        self.assertEqual(capacity, -1)
class JenattonMetric(Metric):
    """Metric that evaluates the piecewise-quadratic function ``_f`` on every
    arm of a trial and reports the value as a noiseless observation.
    """

    def __init__(
        self,
        name: str = "jenatton",
    ) -> None:
        """
        Args:
            name: Name of the metric; forwarded to ``Metric``.
        """
        super().__init__(name=name)

    @staticmethod
    def _f(
        x1: Optional[int] = None,
        x2: Optional[int] = None,
        x3: Optional[int] = None,
        x4: Optional[float] = None,
        x5: Optional[float] = None,
        x6: Optional[float] = None,
        x7: Optional[float] = None,
        r8: Optional[float] = None,
        r9: Optional[float] = None,
    ) -> float:
        """Evaluate the function for one parameterization.

        ``x1`` picks the top-level branch and ``x2`` / ``x3`` the leaf:

        * ``x1 == 0`` and ``x2 == 0``: ``x4 ** 2 + 0.1 + r8``
        * ``x1 == 0`` and ``x2 != 0``: ``x5 ** 2 + 0.2 + r8``
        * ``x1 != 0`` and ``x3 == 0``: ``x6 ** 2 + 0.3 + r9``
        * ``x1 != 0`` and ``x3 != 0``: ``x7 ** 2 + 0.4 + r9``

        Only the parameters on the selected path are read, so the others may
        be left as ``None``. A missing ``x1``, ``x2`` or ``x3`` compares
        unequal to 0 and therefore selects the "non-zero" branch.

        Returns:
            The function value on the selected leaf.

        Raises:
            Whatever ``not_none`` raises when a parameter needed on the
            selected path is ``None``.
        """
        if x1 == 0:
            if x2 == 0:
                return not_none(x4) ** 2 + 0.1 + not_none(r8)
            else:
                return not_none(x5) ** 2 + 0.2 + not_none(r8)
        else:
            if x3 == 0:
                return not_none(x6) ** 2 + 0.3 + not_none(r9)
            else:
                return not_none(x7) ** 2 + 0.4 + not_none(r9)

    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> Data:
        """Evaluate ``_f`` on each arm of ``trial``.

        Args:
            trial: Trial whose arms are evaluated; each arm's parameters are
                passed to ``_f`` as keyword arguments.
            kwargs: Accepted and ignored.

        Returns:
            ``Data`` with one row per arm: the arm name, this metric's name,
            the function value as ``mean``, a ``sem`` of 0 and the trial index.
        """
        # Read the mapping once so names and values come from the same snapshot.
        arms_by_name = trial.arms_by_name
        # pyre-ignore [6]
        mean = [self._f(**arm.parameters) for arm in arms_by_name.values()]
        df = pd.DataFrame(
            {
                "arm_name": list(arms_by_name),
                "metric_name": self.name,
                "mean": mean,
                "sem": 0,
                "trial_index": trial.index,
            }
        )

        return Data(df=df)
class ParameterDistributionTest(TestCase):
    """Checks construction, string form and error handling of ``ParameterDistribution``."""

    def test_parameter_distribution(self):
        normal_kwargs = {"loc": 0.0, "scale": 1.0}
        dist = ParameterDistribution(
            parameters=["x1"],
            distribution_class="norm",
            distribution_parameters=normal_kwargs,
            multiplicative=True,
        )
        self.assertTrue(dist.multiplicative)
        frozen = dist.distribution
        self.assertEqual(dist.parameters, ["x1"])
        # The resolved distribution is a frozen scipy normal carrying the
        # keyword arguments given above.
        self.assertIsInstance(frozen, rv_frozen)
        self.assertIsInstance(frozen.dist, norm_gen)
        frozen_kwargs = frozen.kwds
        self.assertEqual(frozen_kwargs["loc"], 0.0)
        self.assertEqual(frozen_kwargs["scale"], 1.0)

        # Test repr.
        repr_fields = [
            "parameters=['x1']",
            "distribution_class=norm",
            "distribution_parameters={'loc': 0.0, 'scale': 1.0}",
            "multiplicative=True",
        ]
        expected_repr = "ParameterDistribution(" + ", ".join(repr_fields) + ")"
        self.assertEqual(str(dist), expected_repr)

        # Test weird distribution name: construction succeeds, and the error
        # only surfaces when the distribution is resolved.
        unknown = ParameterDistribution(
            parameters=["x1"],
            distribution_class="dummy_dist",
            distribution_parameters={},
        )
        self.assertFalse(unknown.multiplicative)
        with self.assertRaises(UserInputError):
            unknown.distribution
class BaseGlobalStoppingStrategy(ABC, Base):
    """Interface for strategies used to stop the optimization.

    Note that this is different from the `BaseEarlyStoppingStrategy`,
    the functionality of which is to decide whether a trial with partial
    results available during evaluation should be stopped before
    fully completing. In global early stopping, the decision is about
    whether or not to stop the overall optimization altogether (e.g. because
    the expected marginal gains of running additional evaluations do not
    justify the cost of running these trials).
    """

    def __init__(self, min_trials: int) -> None:
        """Initialize a base stopping strategy.

        Args:
            min_trials: Minimum number of trials before the stopping strategy
                kicks in. This class only stores the value; acting on it is
                left to subclasses.
        """
        self.min_trials = min_trials

    # NOTE(review): annotating ``**kwargs`` as ``Dict[str, Any]`` types every
    # individual keyword value as a dict; ``Any`` may have been intended.
    # Confirm against the subclasses before changing it.
    @abstractmethod
    def should_stop_optimization(
        self,
        experiment: Experiment,
        **kwargs: Dict[str, Any],
    ) -> Tuple[bool, str]:
        """Decide whether to stop optimization.

        Typical examples include stopping the optimization loop when the objective
        appears to not improve anymore.

        Args:
            experiment: Experiment that contains the trials and other contextual data.
            kwargs: Additional keyword arguments; their meaning is defined by
                the concrete strategy.

        Returns:
            A Tuple with a boolean determining whether the optimization should stop,
            and a str declaring the reason for stopping.
        """
        pass  # pragma: nocover
18 | exp = get_experiment_with_observations( 19 | observations=[[11], [10], [9], [15], [5]], minimize=True 20 | ) 21 | self.assertEqual(get_trace(exp), [11, 10, 9, 9, 5]) 22 | # Same experiment with maximize via new optimization config. 23 | opt_conf = not_none(exp.optimization_config).clone() 24 | opt_conf.objective.minimize = False 25 | self.assertEqual(get_trace(exp, opt_conf), [11, 11, 11, 15, 15]) 26 | 27 | # Scalarized. 28 | exp = get_experiment_with_observations( 29 | observations=[[1, 1], [2, 2], [3, 3]], 30 | scalarized=True, 31 | ) 32 | self.assertEqual(get_trace(exp), [2, 4, 6]) 33 | 34 | # Multi objective. 35 | exp = get_experiment_with_observations( 36 | observations=[[1, 1], [1, 2], [3, 3], [2, 4], [2, 1]], 37 | ) 38 | self.assertEqual(get_trace(exp), [1, 2, 9, 11, 11]) 39 | 40 | # W/ constraints. 41 | exp = get_experiment_with_observations( 42 | observations=[[1, 1, 1], [1, 2, -1], [3, 3, -1], [2, 4, 1], [2, 1, 1]], 43 | constrained=True, 44 | ) 45 | self.assertEqual(get_trace(exp), [1, 1, 1, 8, 8]) 46 | 47 | # W/ first objective being minimized. 48 | exp = get_experiment_with_observations( 49 | observations=[[1, 1], [-1, 2], [3, 3], [-2, 4], [2, 1]], minimize=True 50 | ) 51 | self.assertEqual(get_trace(exp), [0, 2, 2, 8, 8]) 52 | -------------------------------------------------------------------------------- /website/static/img/ax_lockup.svg: -------------------------------------------------------------------------------- 1 | 01_FullColor -------------------------------------------------------------------------------- /ax/core/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
# Names exported by ``from ax.core import *``. Every entry must be bound in
# this module: a star-import raises ``AttributeError`` for any listed name the
# module does not define. ``"SimpleExperiment"`` was listed here although
# nothing in this module imports or defines it, so it has been removed.
__all__ = [
    "Arm",
    "BatchTrial",
    "ChoiceParameter",
    "ComparisonOp",
    "Data",
    "Experiment",
    "FixedParameter",
    "GeneratorRun",
    "Metric",
    "MultiObjective",
    "MultiObjectiveOptimizationConfig",
    "Objective",
    "ObjectiveThreshold",
    "OptimizationConfig",
    "OrderConstraint",
    "OutcomeConstraint",
    "Parameter",
    "ParameterConstraint",
    "ParameterDistribution",
    "ParameterType",
    "RangeParameter",
    "RiskMeasure",
    "Runner",
    "SearchSpace",
    "SumConstraint",
    "Trial",
]
class CapParameter(Transform):
    """Lower the upper bound of selected range parameters.

    The configuration maps a parameter name to the new upper bound for that
    parameter: ``{parameter_name -> new_upper_range_value}``.

    Only the search space is transformed.
    """

    def __init__(
        self,
        search_space: SearchSpace,
        observation_features: List[ObservationFeatures],
        observation_data: List[ObservationData],
        modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
        config: Optional[TConfig] = None,
    ) -> None:
        self.config = config or {}
        # Only parameters that are both in the search space and named in the
        # config are transformed.
        capped_names = set()
        for name in search_space.parameters:
            if name in self.config:
                capped_names.add(name)
        self.transform_parameters = capped_names

    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Apply the configured upper bounds in place and return the search space.

        Raises:
            NotImplementedError: If a configured parameter is not a
                ``RangeParameter``.
        """
        for name, parameter in search_space.parameters.items():
            if name not in self.transform_parameters:
                continue
            if not isinstance(parameter, RangeParameter):
                raise NotImplementedError(
                    "Can only cap range parameters currently."
                )
            new_upper = self.config.get(name)
            range_parameter = checked_cast(RangeParameter, parameter)
            range_parameter.update_range(upper=new_upper)
        return search_space
12 | 13 | # Unique capabilities 14 | - **Support for noisy functions**. Results of A/B tests and simulations with reinforcement learning agents often exhibit high amounts of noise. Ax supports [state-of-the-art algorithms](https://research.facebook.com/blog/2018/09/efficient-tuning-of-online-systems-using-bayesian-optimization/) which work better than traditional Bayesian optimization in high-noise settings. 15 | - **Customization**. Ax's developer API makes it easy to integrate custom data modeling and decision algorithms. This allows developers to build their own custom optimization services with minimal overhead. 16 | - **Multi-modal experimentation**. Ax has first-class support for running and combining data from different types of experiments, such as "offline" simulation data and "online" data from real-world experiments. 17 | - **Multi-objective optimization**. Ax supports multi-objective and constrained optimization which are common to real-world problems, like improving load time without increasing data use. 18 | -------------------------------------------------------------------------------- /ax/plot/tests/test_contours.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
class ContoursTest(TestCase):
    """Smoke test for the contour plotting helpers in ``ax.plot.contour``."""

    @fast_botorch_optimize
    def testContours(self):
        """Build each contour plot variant, then check its type and tooltips."""
        exp = get_branin_experiment(with_str_choice_param=True, with_batch=True)
        exp.trials[0].run()
        model = Models.BOTORCH(
            # Model bridge kwargs
            experiment=exp,
            data=exp.fetch_data(),
        )
        # Every plot below is drawn for the same metric, so pick it once.
        metric_name = list(model.metric_names)[0]

        # Assert that each type of plot can be constructed successfully
        plot = plot_contour_plotly(
            model, model.parameters[0], model.parameters[1], metric_name
        )
        self.assertIsInstance(plot, go.Figure)
        plot = interact_contour_plotly(model, metric_name)
        self.assertIsInstance(plot, go.Figure)
        plot = interact_contour(model, metric_name)
        self.assertIsInstance(plot, AxPlotConfig)
        plot = plot_contour(
            model, model.parameters[0], model.parameters[1], metric_name
        )
        self.assertIsInstance(plot, AxPlotConfig)

        # Make sure all parameters and metrics are displayed in tooltips.
        # ``plot`` is the ``plot_contour`` result built last above.
        tooltips = list(exp.parameters.keys()) + list(exp.metrics.keys())
        for d in plot.data["data"]:
            # Only check scatter plots hoverovers
            if d["type"] != "scatter":
                continue
            for text in d["text"]:
                for tt in tooltips:
                    # assertIn reports the missing name and the text on failure.
                    self.assertIn(tt, text)
class TestTypeUtils(TestCase):
    """Checks for the casting helpers imported from ``ax.utils.common.typeutils``."""

    def test_not_none(self):
        # A non-None value is handed back; None is rejected.
        result = not_none("not_none")
        self.assertEqual(result, "not_none")
        with self.assertRaises(ValueError):
            not_none(None)

    def test_checked_cast(self):
        as_float = checked_cast(float, 2.0)
        self.assertEqual(as_float, 2.0)
        # An int is not accepted where a float is requested.
        with self.assertRaises(ValueError):
            checked_cast(float, 2)

    def test_checked_cast_complex(self):
        int_to_str = Dict[int, str]
        matching = {1: "one"}
        self.assertEqual(checked_cast_complex(int_to_str, matching), {1: "one"})
        # Keys and values swapped relative to the requested type.
        with self.assertRaises(ValueError):
            checked_cast_complex(int_to_str, {"one": 1})

    def test_checked_cast_list(self):
        all_floats = [1.0, 2.0]
        self.assertEqual(checked_cast_list(float, all_floats), [1.0, 2.0])
        # A single element of the wrong type is enough to fail the cast.
        with self.assertRaises(ValueError):
            checked_cast_list(float, [1.0, 2])

    def test_checked_cast_optional(self):
        passed_through = checked_cast_optional(float, None)
        self.assertEqual(passed_through, None)
        with self.assertRaises(ValueError):
            checked_cast_optional(float, 2)

    def test_checked_cast_dict(self):
        matching = {"some": 1}
        self.assertEqual(checked_cast_dict(str, int, matching), {"some": 1})
        # First a value of the wrong type, then a key of the wrong type.
        for mismatched in ({"some": 1.0}, {1: 1}):
            with self.assertRaises(ValueError):
                checked_cast_dict(str, int, mismatched)

    def test_numpy_type_to_python_type(self):
        cases = ((np.int64(2), int), (np.float64(2), float))
        for numpy_value, python_type in cases:
            converted = numpy_type_to_python_type(numpy_value)
            self.assertEqual(type(converted), python_type)
class TestRiskMeasure(TestCase):
    """Checks construction, string form, cloning and validation of ``RiskMeasure``."""

    def test_risk_measure(self):
        rm = RiskMeasure(
            risk_measure="VaR",
            options={"alpha": 0.8, "n_w": 5},
        )
        self.assertEqual(rm.risk_measure, "VaR")
        self.assertEqual(rm.options, {"alpha": 0.8, "n_w": 5})
        # ``module`` resolves the name to a ``VaR`` instance built from the options.
        rm_module = rm.module
        self.assertIsInstance(rm_module, VaR)
        self.assertEqual(rm_module.alpha, 0.8)
        self.assertEqual(rm_module.n_w, 5)
        self.assertFalse(rm.is_multi_output)

        # Test repr.
        expected_repr = (
            "RiskMeasure(risk_measure=VaR, options={'alpha': 0.8, 'n_w': 5})"
        )
        self.assertEqual(str(rm), expected_repr)

        # Test clone.
        rm_clone = rm.clone()
        self.assertEqual(str(rm), str(rm_clone))

        # Test unknown risk measure.
        with self.assertRaisesRegex(UserInputError, "constructing"):
            RiskMeasure(
                risk_measure="VVar",
                options={},
            )
        # Test invalid options.
        with self.assertRaisesRegex(UserInputError, "constructing"):
            RiskMeasure(
                risk_measure="VaR",
                options={"alpha": 5, "n_w": 5},
            )

    def test_custom_risk_measure(self):
        # Test using user-defined risk measures.

        class CustomRM(VaR):
            pass

        # The registry is a module-level dict shared by every test in the
        # process, so remove the temporary entry again when this test ends,
        # whether it passes or fails.
        self.addCleanup(RISK_MEASURE_NAME_TO_CLASS.pop, "custom", None)
        RISK_MEASURE_NAME_TO_CLASS["custom"] = CustomRM

        rm = RiskMeasure(
            risk_measure="custom",
            options={"alpha": 0.8, "n_w": 5},
        )
        self.assertEqual(rm.risk_measure, "custom")
        self.assertIsInstance(rm.module, CustomRM)
class AndEarlyStoppingStrategy(LogicalEarlyStoppingStrategy):
    """Combine two strategies, keeping only the trials that both of them return."""

    def should_stop_trials_early(
        self,
        trial_indices: Set[int],
        experiment: Experiment,
        **kwargs: Dict[str, Any],
    ) -> Dict[int, Optional[str]]:
        """Query both wrapped strategies and intersect their results.

        Returns:
            A dict with an entry for every trial index present in both
            results, in the order of the left strategy's result. Each value is
            the left and right values joined as ``"<left>, <right>"``.
        """
        left_result = self.left.should_stop_trials_early(
            trial_indices=trial_indices, experiment=experiment, **kwargs
        )
        right_result = self.right.should_stop_trials_early(
            trial_indices=trial_indices, experiment=experiment, **kwargs
        )

        combined = {}
        for trial_index, left_reason in left_result.items():
            if trial_index not in right_result:
                continue
            right_reason = right_result[trial_index]
            combined[trial_index] = f"{left_reason}, {right_reason}"
        return combined
class LoggerTest(TestCase):
    """Checks for ``get_logger`` and ``build_file_handler``."""

    def setUp(self):
        self.warning_string = "Test warning"

    def _attach_file_handler(self, logger, path):
        """Attach a file handler for ``path`` to ``logger`` for this test only.

        The handler is detached and closed when the test ends, so it neither
        stays on the logger nor keeps the temporary file open.
        """
        handler = build_file_handler(path)
        logger.addHandler(handler)
        # Cleanups run last-in first-out: detach first, then close.
        self.addCleanup(handler.close)
        self.addCleanup(logger.removeHandler, handler)

    def testLogger(self):
        logger = get_logger(BASE_LOGGER_NAME + ".testLogger")
        # Verify it doesn't crash
        logger.warning(self.warning_string)
        # Patch it, verify we actually called it. The context manager undoes
        # the patch even if the assertion fails; otherwise the mock could leak
        # into other tests in some environments (like pytest), since it is set
        # on the python logger directly.
        with patch.object(logger, "warning") as mock_warning:
            logger.warning(self.warning_string)
            mock_warning.assert_called_once_with(self.warning_string)

    def testLoggerWithFile(self):
        # The temporary file is closed and removed when the ``with`` exits.
        with NamedTemporaryFile() as tf:
            logger = get_logger(BASE_LOGGER_NAME + ".testLoggerWithFile")
            self._attach_file_handler(logger, tf.name)
            logger.info(self.warning_string)
            output = str(tf.read())
            self.assertIn(BASE_LOGGER_NAME, output)
            self.assertIn(self.warning_string, output)

    def testLoggerOutputNameWithFile(self):
        with NamedTemporaryFile() as tf:
            logger = get_logger(BASE_LOGGER_NAME + ".testLoggerOutputNameWithFile")
            self._attach_file_handler(logger, tf.name)
            # The adapter passes ``output_name`` along with each record.
            logger = logging.LoggerAdapter(logger, {"output_name": "my_output_name"})
            logger.warning(self.warning_string)
            output = str(tf.read())
            self.assertIn("my_output_name", output)
            self.assertIn(self.warning_string, output)
class TorchModelTest(TestCase):
    """Checks what the base ``TorchModel`` does when no method is overridden."""

    def setUp(self):
        features = torch.zeros(1)
        targets = torch.zeros(1)
        noise = torch.ones(1)
        self.dataset = FixedNoiseDataset(X=features, Y=targets, Yvar=noise)

    def testTorchModelFit(self):
        # The base ``fit`` is expected to accept the call without raising.
        base_model = TorchModel()
        digest = SearchSpaceDigest(
            feature_names=["x1"],
            bounds=[(0, 1)],
        )
        base_model.fit(
            datasets=[self.dataset],
            metric_names=["y"],
            search_space_digest=digest,
        )

    def testTorchModelPredict(self):
        base_model = TorchModel()
        with self.assertRaises(NotImplementedError):
            base_model.predict(torch.zeros(1))

    def testTorchModelGen(self):
        base_model = TorchModel()
        with self.assertRaises(NotImplementedError):
            base_model.gen(n=1, bounds=[(0, 1)], objective_weights=torch.ones(1))

    def testNumpyTorchBestPoint(self):
        # The base model reports no best point.
        base_model = TorchModel()
        best = base_model.best_point(
            bounds=[(0, 1)], objective_weights=torch.ones(1)
        )
        self.assertIsNone(best)

    def testTorchModelCrossValidate(self):
        base_model = TorchModel()
        with self.assertRaises(NotImplementedError):
            empty_digest = SearchSpaceDigest(feature_names=[], bounds=[])
            base_model.cross_validate(
                datasets=[self.dataset],
                metric_names=["y"],
                X_test=torch.ones(1),
                search_space_digest=empty_digest,
            )

    def testTorchModelUpdate(self):
        base_model = TorchModel()
        with self.assertRaises(NotImplementedError):
            empty_digest = SearchSpaceDigest(feature_names=[], bounds=[])
            base_model.update(
                datasets=[self.dataset],
                metric_names=["y"],
                search_space_digest=empty_digest,
            )
testFullFactorialFixedFeatures(self): 42 | generator = FullFactorialGenerator(max_cardinality=5, check_cardinality=True) 43 | parameter_values = [[1, 2], ["foo", "bar"]] 44 | generated_points, weights, _ = generator.gen( 45 | n=-1, 46 | parameter_values=parameter_values, 47 | objective_weights=np.ones(1), 48 | fixed_features={1: "foo"}, 49 | ) 50 | expected_points = [[1, "foo"], [2, "foo"]] 51 | self.assertEqual(generated_points, expected_points) 52 | self.assertEqual(weights, [1 for _ in range(len(expected_points))]) 53 | -------------------------------------------------------------------------------- /ax/benchmark/tests/test_scored_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from ax.benchmark.scored_benchmark import ( 7 | scored_benchmark_full_run, 8 | scored_benchmark_replication, 9 | scored_benchmark_test, 10 | ) 11 | from ax.utils.common.testutils import TestCase 12 | from ax.utils.testing.benchmark_stubs import ( 13 | get_aggregated_benchmark_result, 14 | get_single_objective_benchmark_problem, 15 | get_sobol_gpei_benchmark_method, 16 | ) 17 | from ax.utils.testing.mock import fast_botorch_optimize 18 | 19 | 20 | class TestProblems(TestCase): 21 | @fast_botorch_optimize 22 | def test_scored_benchmark_replication(self) -> None: 23 | scored_result = scored_benchmark_replication( 24 | problem=get_single_objective_benchmark_problem(), 25 | method=get_sobol_gpei_benchmark_method(), 26 | ) 27 | 28 | self.assertEqual(len(scored_result.score_trace), 4) 29 | self.assertTrue( 30 | (scored_result.score_trace < 100).all() 31 | ) # Score should never be over 100 32 | 33 | @fast_botorch_optimize 34 | def test_scored_benchmark_test(self) -> None: 35 | aggregated_scored_result = scored_benchmark_test( 36 | 
problem=get_single_objective_benchmark_problem(), 37 | method=get_sobol_gpei_benchmark_method(), 38 | num_replications=2, 39 | ) 40 | 41 | self.assertEqual(len(aggregated_scored_result.score_trace), 4) 42 | self.assertTrue( 43 | (aggregated_scored_result.score_trace["mean"] < 100).all() 44 | ) # Score should never be over 100 45 | self.assertTrue( 46 | (aggregated_scored_result.score_trace["median"] < 100).all() 47 | ) # Score should never be over 100 48 | 49 | @fast_botorch_optimize 50 | def test_scored_benchmark_full_run(self) -> None: 51 | aggregated_scored_results = scored_benchmark_full_run( 52 | problems_baseline_results=[ 53 | ( 54 | get_single_objective_benchmark_problem(), 55 | get_aggregated_benchmark_result(), 56 | ) 57 | ], 58 | methods=[get_sobol_gpei_benchmark_method()], 59 | num_replications=2, 60 | ) 61 | 62 | self.assertEqual(len(aggregated_scored_results), 1) 63 | -------------------------------------------------------------------------------- /sphinx/source/metrics.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | ax.metrics 5 | =================================== 6 | 7 | .. automodule:: ax.metrics 8 | .. currentmodule:: ax.metrics 9 | 10 | 11 | BoTorch Test Problem 12 | ~~~~~~~~~~~~~~~~~~~~ 13 | 14 | .. automodule:: ax.metrics.botorch_test_problem 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | 19 | Branin 20 | ~~~~~~ 21 | 22 | .. automodule:: ax.metrics.branin 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | 27 | Branin Map 28 | ~~~~~~~~~~ 29 | 30 | .. automodule:: ax.metrics.branin_map 31 | :members: 32 | :undoc-members: 33 | :show-inheritance: 34 | 35 | Chemistry 36 | ~~~~~~~~~ 37 | 38 | .. automodule:: ax.metrics.chemistry 39 | :members: 40 | :undoc-members: 41 | :show-inheritance: 42 | 43 | Curve 44 | ~~~~~~~~~ 45 | 46 | ..
automodule:: ax.metrics.curve 47 | :members: 48 | :undoc-members: 49 | :show-inheritance: 50 | 51 | 52 | Factorial 53 | ~~~~~~~~~ 54 | 55 | .. automodule:: ax.metrics.factorial 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | 60 | Hartmann6 61 | ~~~~~~~~~ 62 | 63 | .. automodule:: ax.metrics.hartmann6 64 | :members: 65 | :undoc-members: 66 | :show-inheritance: 67 | 68 | Jenatton 69 | ~~~~~~~~~ 70 | 71 | .. automodule:: ax.metrics.jenatton 72 | :members: 73 | :undoc-members: 74 | :show-inheritance: 75 | 76 | 77 | L2 Norm 78 | ~~~~~~~ 79 | 80 | .. automodule:: ax.metrics.l2norm 81 | :members: 82 | :undoc-members: 83 | :show-inheritance: 84 | 85 | Noisy Functions 86 | ~~~~~~~~~~~~~~~ 87 | 88 | .. automodule:: ax.metrics.noisy_function 89 | :members: 90 | :undoc-members: 91 | :show-inheritance: 92 | 93 | Noisy Function Map 94 | ~~~~~~~~~~~~~~~~~~~~ 95 | 96 | .. automodule:: ax.metrics.noisy_function_map 97 | :members: 98 | :undoc-members: 99 | :show-inheritance: 100 | 101 | Sklearn 102 | ~~~~~~~ 103 | 104 | .. automodule:: ax.metrics.sklearn 105 | :members: 106 | :undoc-members: 107 | :show-inheritance: 108 | 109 | Tensorboard 110 | ~~~~~~~~~~~ 111 | 112 | .. automodule:: ax.metrics.tensorboard 113 | :members: 114 | :undoc-members: 115 | :show-inheritance: 116 | 117 | 118 | TorchX 119 | ~~~~~~~~~~~ 120 | 121 | .. automodule:: ax.metrics.torchx 122 | :members: 123 | :undoc-members: 124 | :show-inheritance: 125 | -------------------------------------------------------------------------------- /ax/benchmark/problems/synthetic/hss/jenatton.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from ax.benchmark.benchmark_problem import SingleObjectiveBenchmarkProblem 7 | from ax.core.objective import Objective 8 | from ax.core.optimization_config import OptimizationConfig 9 | from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter 10 | from ax.core.search_space import HierarchicalSearchSpace 11 | from ax.metrics.jenatton import JenattonMetric 12 | from ax.runners.synthetic import SyntheticRunner 13 | 14 | 15 | def get_jenatton_benchmark_problem() -> SingleObjectiveBenchmarkProblem: 16 | search_space = HierarchicalSearchSpace( 17 | parameters=[ 18 | ChoiceParameter( 19 | name="x1", 20 | parameter_type=ParameterType.INT, 21 | values=[0, 1], 22 | dependents={0: ["x2", "r8"], 1: ["x3", "r9"]}, 23 | ), 24 | ChoiceParameter( 25 | name="x2", 26 | parameter_type=ParameterType.INT, 27 | values=[0, 1], 28 | dependents={0: ["x4"], 1: ["x5"]}, 29 | ), 30 | ChoiceParameter( 31 | name="x3", 32 | parameter_type=ParameterType.INT, 33 | values=[0, 1], 34 | dependents={0: ["x6"], 1: ["x7"]}, 35 | ), 36 | *[ 37 | RangeParameter( 38 | name=f"x{i}", 39 | parameter_type=ParameterType.FLOAT, 40 | lower=0.0, 41 | upper=1.0, 42 | ) 43 | for i in range(4, 8) 44 | ], 45 | RangeParameter( 46 | name="r8", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0 47 | ), 48 | RangeParameter( 49 | name="r9", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0 50 | ), 51 | ] 52 | ) 53 | 54 | optimization_config = OptimizationConfig( 55 | objective=Objective(metric=JenattonMetric(), minimize=True) 56 | ) 57 | 58 | return SingleObjectiveBenchmarkProblem( 59 | name="Jenatton", 60 | search_space=search_space, 61 | optimization_config=optimization_config, 62 | runner=SyntheticRunner(), 63 | optimal_value=0.1, 64 | ) 65 | -------------------------------------------------------------------------------- /ax/plot/tests/test_feature_importances.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 
Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ax.modelbridge.base import ModelBridge 8 | from ax.modelbridge.registry import Models 9 | from ax.plot.base import AxPlotConfig 10 | from ax.plot.feature_importances import ( 11 | plot_feature_importance_by_feature, 12 | plot_feature_importance_by_feature_plotly, 13 | plot_feature_importance_by_metric, 14 | plot_feature_importance_by_metric_plotly, 15 | plot_relative_feature_importance, 16 | plot_relative_feature_importance_plotly, 17 | ) 18 | from ax.utils.common.testutils import TestCase 19 | from ax.utils.testing.core_stubs import get_branin_experiment 20 | from ax.utils.testing.mock import fast_botorch_optimize 21 | from plotly import graph_objects as go 22 | 23 | DUMMY_CAPTION = "test_caption" 24 | 25 | 26 | def get_modelbridge() -> ModelBridge: 27 | exp = get_branin_experiment(with_batch=True) 28 | exp.trials[0].run() 29 | return Models.BOTORCH( 30 | # Model bridge kwargs 31 | experiment=exp, 32 | data=exp.fetch_data(), 33 | ) 34 | 35 | 36 | class FeatureImportancesTest(TestCase): 37 | @fast_botorch_optimize 38 | def testFeatureImportances(self): 39 | model = get_modelbridge() 40 | # Assert that each type of plot can be constructed successfully 41 | plot = plot_feature_importance_by_feature_plotly(model=model) 42 | self.assertIsInstance(plot, go.Figure) 43 | plot = plot_feature_importance_by_feature_plotly( 44 | model=model, caption=DUMMY_CAPTION 45 | ) 46 | self.assertIsInstance(plot, go.Figure) 47 | self.assertEqual(len(plot.layout.annotations), 1) 48 | self.assertEqual(plot.layout.annotations[0].text, DUMMY_CAPTION) 49 | plot = plot_feature_importance_by_feature(model=model) 50 | self.assertIsInstance(plot, AxPlotConfig) 51 | plot = plot_feature_importance_by_metric_plotly(model=model) 52 | self.assertIsInstance(plot, go.Figure) 53 | plot = 
plot_feature_importance_by_metric(model=model) 54 | self.assertIsInstance(plot, AxPlotConfig) 55 | plot = plot_relative_feature_importance_plotly(model=model) 56 | self.assertIsInstance(plot, go.Figure) 57 | plot = plot_relative_feature_importance(model=model) 58 | self.assertIsInstance(plot, AxPlotConfig) 59 | -------------------------------------------------------------------------------- /scripts/insert_api_refs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import ast 9 | import glob 10 | import re 11 | 12 | 13 | def list_functions(source_glob): 14 | """ 15 | List all of the functions and classes defined 16 | """ 17 | defined = [] 18 | # Iterate through each source file 19 | for sp in glob.glob(source_glob): 20 | module_name = sp[:-3] 21 | module_name = module_name.replace("/", ".") 22 | # Parse the source file into an AST 23 | node = ast.parse(open(sp).read()) 24 | # Extract the names of all functions and classes defined in this file 25 | defined.extend( 26 | (n.name, module_name + "." + n.name) 27 | for n in node.body 28 | if (isinstance(n, ast.FunctionDef) or isinstance(n, ast.ClassDef)) 29 | ) 30 | return defined 31 | 32 | 33 | def replace_backticks(source_path, docs_path): 34 | markdown_glob = docs_path + "/*.md" 35 | source_glob = source_path + "/**/*.py" 36 | methods = list_functions(source_glob) 37 | for f in glob.glob(markdown_glob): 38 | for n, m in methods: 39 | # Match backquoted mentions of the function/class name which are 40 | # not already links 41 | pattern = "(? 
AxPlotConfig: 19 | """ 20 | Calculates and plots the marginal effects -- the effect of changing one 21 | factor away from the randomized distribution of the experiment and fixing it 22 | at a particular level. 23 | 24 | Args: 25 | model: Model to use for estimating effects 26 | metric: The metric for which to plot marginal effects. 27 | 28 | Returns: 29 | AxPlotConfig of the marginal effects 30 | """ 31 | plot_data, _, _ = get_plot_data(model, {}, {metric}) 32 | 33 | arm_dfs = [] 34 | for arm in plot_data.in_sample.values(): 35 | arm_df = pd.DataFrame(arm.parameters, index=[arm.name]) 36 | arm_df["mean"] = arm.y_hat[metric] 37 | arm_df["sem"] = arm.se_hat[metric] 38 | arm_dfs.append(arm_df) 39 | effect_table = marginal_effects(pd.concat(arm_dfs, 0)) 40 | 41 | varnames = effect_table["Name"].unique() 42 | data: List[Any] = [] 43 | for varname in varnames: 44 | var_df = effect_table[effect_table["Name"] == varname] 45 | data += [ 46 | go.Bar( 47 | x=var_df["Level"], 48 | y=var_df["Beta"], 49 | error_y={"type": "data", "array": var_df["SE"]}, 50 | name=varname, 51 | ) 52 | ] 53 | fig = subplots.make_subplots( 54 | cols=len(varnames), 55 | rows=1, 56 | subplot_titles=list(varnames), 57 | print_grid=False, 58 | shared_yaxes=True, 59 | ) 60 | for idx, item in enumerate(data): 61 | fig.append_trace(item, 1, idx + 1) 62 | fig.layout.showlegend = False 63 | # fig.layout.margin = go.layout.Margin(l=2, r=2) 64 | fig.layout.title = "Marginal Effects by Factor" 65 | fig.layout.yaxis = { 66 | "title": "% better than experiment average", 67 | "hoverformat": ".{}f".format(DECIMALS), 68 | } 69 | return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC) 70 | -------------------------------------------------------------------------------- /ax/modelbridge/tests/test_cap_parameter_transform.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter 8 | from ax.core.search_space import SearchSpace 9 | from ax.exceptions.core import UnsupportedError 10 | from ax.modelbridge.transforms.cap_parameter import CapParameter 11 | from ax.utils.common.testutils import TestCase 12 | from ax.utils.testing.core_stubs import get_robust_search_space 13 | 14 | 15 | class CapParameterTest(TestCase): 16 | def setUp(self): 17 | self.search_space = SearchSpace( 18 | parameters=[ 19 | RangeParameter( 20 | "a", lower=1, upper=3, parameter_type=ParameterType.FLOAT 21 | ), 22 | ChoiceParameter( 23 | "b", parameter_type=ParameterType.STRING, values=["a", "b", "c"] 24 | ), 25 | ] 26 | ) 27 | 28 | def test_transform_search_space(self): 29 | t = CapParameter( 30 | search_space=self.search_space, 31 | observation_features=[], 32 | observation_data=[], 33 | config={"a": "2"}, 34 | ) 35 | t.transform_search_space(self.search_space) 36 | self.assertEqual(self.search_space.parameters.get("a").upper, 2) 37 | t2 = CapParameter( 38 | search_space=self.search_space, 39 | observation_features=[], 40 | observation_data=[], 41 | config={"b": "2"}, 42 | ) 43 | with self.assertRaises(NotImplementedError): 44 | t2.transform_search_space(self.search_space) 45 | 46 | def test_w_parameter_distributions(self): 47 | rss = get_robust_search_space() 48 | # Transform a non-distributional parameter. 49 | t = CapParameter( 50 | search_space=rss, 51 | observation_features=[], 52 | observation_data=[], 53 | config={"z": "2"}, 54 | ) 55 | t.transform_search_space(rss) 56 | self.assertEqual(rss.parameters.get("z").upper, 2) 57 | # Error with distributional parameter. 
58 | t = CapParameter( 59 | search_space=rss, 60 | observation_features=[], 61 | observation_data=[], 62 | config={"x": "2"}, 63 | ) 64 | with self.assertRaisesRegex(UnsupportedError, "transform is not supported"): 65 | t.transform_search_space(rss) 66 | -------------------------------------------------------------------------------- /ax/core/tests/test_arm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ax.core.arm import Arm 8 | from ax.utils.common.testutils import TestCase 9 | 10 | 11 | class ArmTest(TestCase): 12 | def setUp(self): 13 | pass 14 | 15 | def testInit(self): 16 | arm = Arm(parameters={"y": 0.25, "x": 0.75, "z": 75}) 17 | self.assertEqual(str(arm), "Arm(parameters={'y': 0.25, 'x': 0.75, 'z': 75})") 18 | 19 | arm = Arm(parameters={"y": 0.25, "x": 0.75, "z": 75}, name="status_quo") 20 | self.assertEqual( 21 | str(arm), 22 | "Arm(name='status_quo', parameters={'y': 0.25, 'x': 0.75, 'z': 75})", 23 | ) 24 | 25 | def testNameValidation(self): 26 | arm = Arm(parameters={"y": 0.25, "x": 0.75, "z": 75}) 27 | self.assertFalse(arm.has_name) 28 | with self.assertRaises(ValueError): 29 | arm.name 30 | arm.name = "0_0" 31 | with self.assertRaises(ValueError): 32 | arm.name = "1_0" 33 | 34 | def testNameOrShortSignature(self): 35 | arm = Arm(parameters={"y": 0.25, "x": 0.75, "z": 75}, name="0_0") 36 | self.assertEqual(arm.name_or_short_signature, "0_0") 37 | 38 | arm = Arm(parameters={"y": 0.25, "x": 0.75, "z": 75}) 39 | self.assertEqual(arm.name_or_short_signature, arm.signature[-4:]) 40 | 41 | def testEq(self): 42 | arm1 = Arm(parameters={"y": 0.25, "x": 0.75, "z": 75}) 43 | arm2 = Arm(parameters={"z": 75, "x": 0.75, "y": 0.25}) 44 | self.assertEqual(arm1, arm2) 45 | 46 | arm3 = 
Arm(parameters={"z": 5, "x": 0.75, "y": 0.25}) 47 | self.assertNotEqual(arm1, arm3) 48 | 49 | arm4 = Arm(name="0_0", parameters={"y": 0.25, "x": 0.75, "z": 75}) 50 | arm5 = Arm(name="0_0", parameters={"y": 0.25, "x": 0.75, "z": 75}) 51 | self.assertEqual(arm4, arm5) 52 | 53 | arm6 = Arm(name="0_1", parameters={"y": 0.25, "x": 0.75, "z": 75}) 54 | self.assertNotEqual(arm4, arm6) 55 | 56 | def testClone(self): 57 | arm1 = Arm(parameters={"y": 0.25, "x": 0.75, "z": 75}) 58 | arm2 = arm1.clone() 59 | self.assertFalse(arm1 is arm2) 60 | self.assertEqual(arm1, arm2) 61 | self.assertFalse(arm1.parameters is arm2.parameters) 62 | 63 | def testSortable(self): 64 | arm1 = Arm(parameters={"y": 0.25, "x": 0.75, "z": 75}) 65 | arm2 = Arm(parameters={"z": 0, "x": 0, "y": 0}) 66 | self.assertTrue(arm1 < arm2) 67 | --------------------------------------------------------------------------------