├── website
├── static
│ ├── .nojekyll
│ ├── CNAME
│ ├── img
│ │ ├── favicon.png
│ │ ├── oss_logo.png
│ │ ├── favicon
│ │ │ └── favicon.ico
│ │ ├── ax_wireframe.svg
│ │ ├── database-solid.svg
│ │ ├── th-large-solid.svg
│ │ ├── dice-solid.svg
│ │ ├── ax.svg
│ │ ├── ax_lockup.svg
│ │ ├── ax_logo_lockup.svg
│ │ └── ax_lockup_white.svg
│ └── js
│ │ ├── mathjax.js
│ │ └── plotUtils.js
├── versioned_docs
│ └── .gitkeep
├── versioned_sidebars
│ └── .gitkeep
├── sidebars.json
├── package.json
├── tutorials.json
└── core
│ └── Footer.js
├── docs
├── assets
│ ├── gp_opt.png
│ ├── bo_1d_opt.gif
│ ├── mab_animate.gif
│ ├── mab_probs.png
│ ├── mab_regret.png
│ ├── gp_posterior.png
│ ├── bandit_allocation.png
│ └── example_shrinkage.png
├── algo-overview.md
└── why-ax.md
├── ax
├── metrics
│ ├── chemistry_data.zip
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_curve.py
│ │ └── test_tensorboard.py
│ ├── l2norm.py
│ ├── __init__.py
│ ├── hartmann6.py
│ ├── branin.py
│ ├── botorch_test_problem.py
│ └── jenatton.py
├── benchmark
│ ├── methods
│ │ ├── __init__.py
│ │ ├── choose_generation_strategy.py
│ │ └── saasbo.py
│ ├── problems
│ │ ├── __init__.py
│ │ ├── hpo
│ │ │ └── __init__.py
│ │ ├── hss
│ │ │ └── __init__.py
│ │ ├── synthetic
│ │ │ ├── __init__.py
│ │ │ └── hss
│ │ │ │ ├── __init__.py
│ │ │ │ └── jenatton.py
│ │ ├── baseline_results
│ │ │ ├── __init__.py
│ │ │ ├── hpo
│ │ │ │ ├── __init__.py
│ │ │ │ └── torchvision
│ │ │ │ │ └── __init__.py
│ │ │ └── synthetic
│ │ │ │ ├── __init__.py
│ │ │ │ ├── hd
│ │ │ │ └── __init__.py
│ │ │ │ └── hss
│ │ │ │ └── __init__.py
│ │ └── hd_embedding.py
│ ├── __init__.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_torchvision_problem_storage.py
│ │ ├── test_problems.py
│ │ ├── test_benchmark_method.py
│ │ └── test_scored_benchmark.py
│ └── benchmark_method.py
├── plot
│ ├── __init__.py
│ ├── css
│ │ ├── __init__.py
│ │ └── base.css
│ ├── js
│ │ ├── __init__.py
│ │ ├── common
│ │ │ ├── __init__.py
│ │ │ ├── css.js
│ │ │ ├── plotly_requires.js
│ │ │ ├── plotly_offline.js
│ │ │ └── plotly_online.js
│ │ └── generic_plotly.js
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_parallel_coordinates.py
│ │ ├── test_helper.py
│ │ ├── test_diagnostic.py
│ │ ├── test_slices.py
│ │ ├── test_traces.py
│ │ ├── test_fitted_scatter.py
│ │ ├── test_contours.py
│ │ └── test_feature_importances.py
│ └── marginal_effects.py
├── utils
│ ├── __init__.py
│ ├── stats
│ │ ├── __init__.py
│ │ └── tests
│ │ │ └── __init__.py
│ ├── common
│ │ ├── __init__.py
│ │ ├── tests
│ │ │ ├── __init__.py
│ │ │ ├── test_serialization.py
│ │ │ ├── test_docutils.py
│ │ │ ├── test_equality.py
│ │ │ ├── test_typeutils.py
│ │ │ └── test_logger.py
│ │ ├── base.py
│ │ ├── timeutils.py
│ │ └── docutils.py
│ ├── notebook
│ │ ├── __init__.py
│ │ └── plotting.py
│ ├── report
│ │ ├── __init__.py
│ │ ├── tests
│ │ │ ├── __init__.py
│ │ │ └── test_render.py
│ │ └── resources
│ │ │ ├── __init__.py
│ │ │ ├── simple_template.html
│ │ │ ├── sufficient_statistic.html
│ │ │ └── base_template.html
│ ├── testing
│ │ ├── __init__.py
│ │ ├── metrics
│ │ │ ├── __init__.py
│ │ │ └── backend_simulator_map.py
│ │ ├── tests
│ │ │ └── __init__.py
│ │ ├── doctest.py
│ │ ├── test_init_files.py
│ │ ├── pyre_strict.py
│ │ ├── torch_stubs.py
│ │ └── unittest_conventions.py
│ ├── tutorials
│ │ └── __init__.py
│ ├── flake8_plugins
│ │ └── __init__.py
│ └── measurement
│ │ ├── __init__.py
│ │ └── tests
│ │ └── __init__.py
├── core
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_map_metric.py
│ │ ├── test_metric.py
│ │ ├── test_types.py
│ │ ├── test_runner.py
│ │ ├── test_parameter_distribution.py
│ │ ├── test_risk_measures.py
│ │ └── test_arm.py
│ ├── map_metric.py
│ └── __init__.py
├── exceptions
│ ├── __init__.py
│ ├── model.py
│ ├── constants.py
│ ├── storage.py
│ ├── generation_strategy.py
│ └── data_provider.py
├── models
│ ├── discrete
│ │ ├── __init__.py
│ │ └── eb_thompson.py
│ ├── random
│ │ ├── __init__.py
│ │ └── uniform.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_base.py
│ │ ├── test_alebo_initializer.py
│ │ ├── test_rembo_initializer.py
│ │ ├── test_randomforest.py
│ │ ├── test_discrete.py
│ │ ├── test_torch.py
│ │ └── test_full_factorial.py
│ ├── torch
│ │ ├── __init__.py
│ │ ├── tests
│ │ │ └── __init__.py
│ │ ├── botorch_modular
│ │ │ └── __init__.py
│ │ └── frontier_utils.py
│ ├── __init__.py
│ └── types.py
├── runners
│ ├── tests
│ │ └── __init__.py
│ ├── __init__.py
│ └── synthetic.py
├── service
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_early_stopping.py
│ │ └── test_best_point.py
│ ├── utils
│ │ ├── __init__.py
│ │ └── early_stopping.py
│ └── __init__.py
├── early_stopping
│ ├── tests
│ │ └── __init__.py
│ ├── __init__.py
│ └── strategies
│ │ ├── __init__.py
│ │ └── logical.py
├── global_stopping
│ ├── tests
│ │ └── __init__.py
│ ├── __init__.py
│ └── strategies
│ │ ├── __init__.py
│ │ └── base.py
├── modelbridge
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_centered_unit_x_transform.py
│ │ ├── test_transform_utils.py
│ │ ├── test_rounding.py
│ │ ├── test_base_transform.py
│ │ └── test_cap_parameter_transform.py
│ ├── strategies
│ │ └── __init__.py
│ ├── transforms
│ │ ├── __init__.py
│ │ ├── centered_unit_x.py
│ │ └── cap_parameter.py
│ └── __init__.py
├── storage
│ ├── json_store
│ │ ├── tests
│ │ │ └── __init__.py
│ │ ├── __init__.py
│ │ ├── load.py
│ │ └── save.py
│ ├── sqa_store
│ │ ├── tests
│ │ │ └── __init__.py
│ │ ├── __init__.py
│ │ ├── delete.py
│ │ ├── structs.py
│ │ ├── timestamp.py
│ │ ├── reduced_state.py
│ │ └── sqa_enum.py
│ ├── __init__.py
│ └── utils.py
└── __init__.py
├── CHANGELOG.md
├── pyproject.toml
├── sphinx
└── source
│ ├── ax.rst
│ ├── global_stopping.rst
│ ├── index.rst
│ ├── runners.rst
│ ├── early_stopping.rst
│ ├── exceptions.rst
│ ├── service.rst
│ └── metrics.rst
├── pytest.ini
├── scripts
├── import_ax.py
├── docker_install.sh
├── build_ax.sh
├── wheels_build.ps1
├── patch_site_config.py
└── insert_api_refs.py
├── .flake8
└── LICENSE
/website/static/.nojekyll:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/website/static/CNAME:
--------------------------------------------------------------------------------
1 | ax.dev
2 |
--------------------------------------------------------------------------------
/website/versioned_docs/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/website/versioned_sidebars/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/assets/gp_opt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/gp_opt.png
--------------------------------------------------------------------------------
/docs/assets/bo_1d_opt.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/bo_1d_opt.gif
--------------------------------------------------------------------------------
/docs/assets/mab_animate.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/mab_animate.gif
--------------------------------------------------------------------------------
/docs/assets/mab_probs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/mab_probs.png
--------------------------------------------------------------------------------
/docs/assets/mab_regret.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/mab_regret.png
--------------------------------------------------------------------------------
/ax/metrics/chemistry_data.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basnijholt/Ax/main/ax/metrics/chemistry_data.zip
--------------------------------------------------------------------------------
/docs/assets/gp_posterior.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/gp_posterior.png
--------------------------------------------------------------------------------
/website/static/img/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basnijholt/Ax/main/website/static/img/favicon.png
--------------------------------------------------------------------------------
/website/static/img/oss_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basnijholt/Ax/main/website/static/img/oss_logo.png
--------------------------------------------------------------------------------
/docs/assets/bandit_allocation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/bandit_allocation.png
--------------------------------------------------------------------------------
/docs/assets/example_shrinkage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basnijholt/Ax/main/docs/assets/example_shrinkage.png
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | Ax uses GitHub tags to manage releases. See the [release notes on GitHub](https://github.com/facebook/Ax/releases) for the changelog.
2 |
--------------------------------------------------------------------------------
/website/static/img/favicon/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basnijholt/Ax/main/website/static/img/favicon/favicon.ico
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=34.4", "wheel", "setuptools_scm"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [tool.usort]
6 | first_party_detection = false
7 |
--------------------------------------------------------------------------------
/docs/algo-overview.md:
--------------------------------------------------------------------------------
1 | ---
2 | id: algo-overview
3 | title: Overview
4 | ---
5 |
6 | Ax supports:
7 | * Bandit optimization
8 |   * Empirical Bayes with Thompson sampling
9 | * Bayesian optimization
10 |
--------------------------------------------------------------------------------
/ax/benchmark/methods/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/sphinx/source/ax.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax
5 | ===================================
6 |
7 | .. automodule:: ax
8 | :members:
9 | :noindex:
10 |
11 | .. currentmodule:: ax
12 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/hpo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/hss/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/synthetic/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/synthetic/hss/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/ax/plot/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/benchmark/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/baseline_results/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/ax/core/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/exceptions/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/plot/css/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/plot/js/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/plot/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/stats/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/baseline_results/hpo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/ax/benchmark/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/metrics/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/models/discrete/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/models/random/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/models/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/models/torch/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/plot/js/common/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/runners/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/service/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/service/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/common/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/notebook/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/report/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/testing/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/tutorials/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/baseline_results/synthetic/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/baseline_results/synthetic/hd/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/ax/early_stopping/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/global_stopping/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/modelbridge/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/models/torch/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/common/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/flake8_plugins/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/measurement/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/report/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/stats/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/testing/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/testing/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/baseline_results/hpo/torchvision/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/baseline_results/synthetic/hss/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/ax/modelbridge/strategies/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/modelbridge/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/storage/json_store/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/storage/sqa_store/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/measurement/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/report/resources/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/models/torch/botorch_modular/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/website/sidebars.json:
--------------------------------------------------------------------------------
1 | {
2 | "docs": {
3 | "Introduction": ["why-ax"],
4 | "Getting Started": ["installation", "api", "glossary"],
5 | "Algorithms": ["bayesopt", "banditopt"],
6 | "Components": ["core", "trial-evaluation", "data", "models", "storage"]
7 | }
8 | }
9 |
--------------------------------------------------------------------------------
/ax/utils/report/resources/simple_template.html:
--------------------------------------------------------------------------------
1 |
4 | {% extends "base_template.html" %}
5 | {% block content %}
6 | {% for element in html_elements %}
7 | {{element}}
8 | {% endfor %}
9 | {% endblock %}
10 |
--------------------------------------------------------------------------------
/ax/early_stopping/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | from ax.early_stopping import strategies
9 |
10 | __all__ = ["strategies"]
11 |
--------------------------------------------------------------------------------
/ax/global_stopping/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | from ax.global_stopping import strategies
9 |
10 | __all__ = ["strategies"]
11 |
--------------------------------------------------------------------------------
/ax/plot/js/generic_plotly.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
// Template, not standalone JS: {{id}}, {{data}} and {{layout}} are placeholders
// presumably substituted by the Python plotting code before this runs.
// Draws one Plotly figure, passing {"showLink": false} as the Plotly config.
Plotly.newPlot(
  {{id}},
  {{data}},
  {{layout}},
  {"showLink": false}
);
14 |
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | # Configuration for pytest.
[pytest]
filterwarnings =
    # Filter out paramz and sklearn deprecation warnings (matched by module).
    ignore::DeprecationWarning:.*paramz.*
    ignore::DeprecationWarning:.*sklearn.*
    # Filter out numpy non-integer indices warning. Entries have the form
    # action:message:category:module, so this text belongs in the message
    # field; placed in the module field it never matched numpy's warning.
    ignore:.*using a non-integer array as obj in delete.*:DeprecationWarning
9 |
--------------------------------------------------------------------------------
/ax/service/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.service.managed_loop import OptimizationLoop, optimize
8 |
9 |
10 | __all__ = ["OptimizationLoop", "optimize"]
11 |
--------------------------------------------------------------------------------
/ax/storage/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.storage.json_store import load as json_load, save as json_save
8 |
9 |
10 | __all__ = ["json_save", "json_load"]
11 |
--------------------------------------------------------------------------------
/website/static/img/ax_wireframe.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ax/plot/js/common/css.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
// Template: "{{css}}" is a placeholder presumably replaced with stylesheet text
// before this script runs. The text is injected as a <style> element appended
// to the page's first <head>.
var css = document.createElement('style');
css.type = 'text/css';
css.innerHTML = "{{css}}";
document.getElementsByTagName("head")[0].appendChild(css);
12 |
--------------------------------------------------------------------------------
/scripts/import_ax.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import ax
8 | from ax.service.ax_client import AxClient
9 |
10 |
if __name__ == "__main__":
    # Reaching this point already proves the imports above succeeded; touching
    # both names also keeps linters from flagging them as unused.
    for _imported in (ax, AxClient):
        assert _imported is not None
14 |
--------------------------------------------------------------------------------
/ax/plot/js/common/plotly_requires.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
// Template: {{script}} is a placeholder for plotting code. It is placed inside
// the RequireJS callback so it runs only after the 'plotly' module has loaded
// and is bound to the local name `Plotly`.
require(['plotly'], function(Plotly) {
  window.PLOTLYENV = window.PLOTLYENV || {};
  window.PLOTLYENV.BASE_URL = 'https://plot.ly';
  {{script}}
});
13 |
--------------------------------------------------------------------------------
/ax/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # flake8: noqa F401
8 | from ax.models.random.sobol import SobolGenerator
9 | from ax.models.torch.botorch import BotorchModel
10 |
11 |
12 | __all__ = ["SobolGenerator", "BotorchModel"]
13 |
--------------------------------------------------------------------------------
/ax/storage/json_store/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.storage.json_store.load import load_experiment as json_load
8 | from ax.storage.json_store.save import save_experiment as json_save
9 |
10 |
11 | __all__ = ["json_load", "json_save"]
12 |
--------------------------------------------------------------------------------
/website/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "scripts": {
3 | "examples": "docusaurus-examples",
4 | "start": "docusaurus-start",
5 | "build": "docusaurus-build",
6 | "publish-gh-pages": "docusaurus-publish",
7 | "write-translations": "docusaurus-write-translations",
8 | "version": "docusaurus-version",
9 | "rename-version": "docusaurus-rename-version"
10 | },
11 | "devDependencies": {
12 | "docusaurus": "^1.7.2"
13 | }
14 | }
15 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 |
3 | # E704: the linter doesn't parse types properly
4 | # T499, T484: silence mypy since using pyre for typechecking
5 | # W503: black and flake8 disagree on how to place operators
6 | # E231: black and flake8 disagree on whitespace after ','
7 | # E203: black and flake8 disagree on whitespace before ':'
8 | ignore = T484, T499, W503, E704, E231, E203
9 |
10 | # Black really wants lines to be 88 chars...
11 | max-line-length = 88
12 |
--------------------------------------------------------------------------------
/ax/runners/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # flake8: noqa F401
8 | from ax.runners.simulated_backend import SimulatedBackendRunner
9 | from ax.runners.synthetic import SyntheticRunner
10 |
11 |
12 | __all__ = ["SimulatedBackendRunner", "SyntheticRunner"]
13 |
--------------------------------------------------------------------------------
/ax/plot/js/common/plotly_offline.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
// Only when Plotly is not already on the page: register the inlined library
// source (the {{library}} template placeholder) as the RequireJS module
// 'plotly', then load that module and expose it as window.Plotly.
if (!window.Plotly) {
  define('plotly', function(require, exports, module) {
    {{library}}
  });
  require(['plotly'], function(Plotly) {
    window.Plotly = Plotly;
  });
}
16 |
--------------------------------------------------------------------------------
/ax/metrics/l2norm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | from ax.metrics.noisy_function import NoisyFunctionMetric
9 |
10 |
class L2NormMetric(NoisyFunctionMetric):
    """Noisy function metric whose underlying value is the L2 norm of ``x``."""

    def f(self, x: np.ndarray) -> float:
        """Return sqrt(sum(x_i ** 2)) over all elements of ``x``."""
        squared = x * x
        return np.sqrt(squared.sum())
14 |
--------------------------------------------------------------------------------
/ax/plot/js/common/plotly_online.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
// Map the RequireJS module name 'plotly' to the Plotly CDN build, then, only
// if Plotly is not already present, load it and expose it as window.Plotly.
requirejs.config({
  paths: {
    plotly: ['https://cdn.plot.ly/plotly-latest.min'],
  },
});
if (!window.Plotly) {
  require(['plotly'], function(plotly) {
    window.Plotly = plotly;
  });
}
18 |
--------------------------------------------------------------------------------
/ax/global_stopping/strategies/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy
8 | from ax.global_stopping.strategies.improvement import ImprovementGlobalStoppingStrategy
9 |
10 |
11 | __all__ = [
12 | "BaseGlobalStoppingStrategy",
13 | "ImprovementGlobalStoppingStrategy",
14 | ]
15 |
--------------------------------------------------------------------------------
/ax/models/types.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Any, Dict, Union
8 |
9 | from ax.core.optimization_config import OptimizationConfig
10 | from botorch.acquisition import AcquisitionFunction
11 |
# Type alias for a configuration dict: string keys mapped to a scalar, a
# string, a BoTorch AcquisitionFunction, a nested dict, an Ax
# OptimizationConfig, or None.
TConfig = Dict[
    str,
    Union[
        int,
        float,
        str,
        AcquisitionFunction,
        Dict[str, Any],
        OptimizationConfig,
        None,
    ],
]
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | from ax.modelbridge.transforms.unit_x import UnitX
9 |
10 |
class CenteredUnitX(UnitX):
    """``UnitX`` variant whose target box is [-1, 1]^d.

    Applies to ``RangeParameter``s of type float that are not on a log scale.
    The transform is applied in-place.
    """

    # The target interval starts at -1 ...
    target_lb: float = -1.0
    # ... and spans 2 units, so it ends at +1.
    target_range: float = 2.0
19 |
--------------------------------------------------------------------------------
/ax/exceptions/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.exceptions.core import AxError
8 |
9 |
class ModelError(AxError):
    """Error raised when something goes wrong while modeling."""
14 |
15 |
class CVNotSupportedError(AxError):
    """Error raised when cross validation is requested for a model that does
    not support it.
    """
22 |
--------------------------------------------------------------------------------
/ax/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # flake8: noqa F401
8 | from ax.metrics.branin import BraninMetric
9 | from ax.metrics.chemistry import ChemistryMetric
10 | from ax.metrics.factorial import FactorialMetric
11 | from ax.metrics.sklearn import SklearnMetric
12 |
13 | __all__ = [
14 | "BraninMetric",
15 | "ChemistryMetric",
16 | "FactorialMetric",
17 | "SklearnMetric",
18 | ]
19 |
--------------------------------------------------------------------------------
/ax/models/torch/frontier_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.models.torch.botorch_moo_defaults import (
8 | get_default_frontier_evaluator,
9 | get_weighted_mc_objective_and_objective_thresholds,
10 | TFrontierEvaluator,
11 | )
12 |
13 |
14 | __all__ = [
15 | "get_weighted_mc_objective_and_objective_thresholds",
16 | "get_default_frontier_evaluator",
17 | "TFrontierEvaluator",
18 | ]
19 |
--------------------------------------------------------------------------------
/ax/metrics/tests/test_curve.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | from ax.metrics.curve import AbstractCurveMetric
9 | from ax.utils.common.testutils import TestCase
10 |
11 |
class AbstractCurveMetricTest(TestCase):
    def testAbstractCurveMetric(self):
        # Curve metrics report that they can be fetched while trials run.
        available = AbstractCurveMetric.is_available_while_running()
        self.assertTrue(available)
        # Constructing the class directly is expected to raise TypeError.
        with self.assertRaises(TypeError):
            AbstractCurveMetric("foo", "bar")
17 |
--------------------------------------------------------------------------------
/ax/storage/sqa_store/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # necessary to import this file so SQLAlchemy knows about the event listeners
8 | # see https://fburl.com/8mn7yjt2
9 | from ax.storage.sqa_store import validation
10 | from ax.storage.sqa_store.load import load_experiment as sqa_load
11 | from ax.storage.sqa_store.save import save_experiment as sqa_save
12 |
13 |
14 | __all__ = ["sqa_load", "sqa_save"]
15 |
16 | del validation
17 |
--------------------------------------------------------------------------------
/scripts/docker_install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
# Usage: docker_install.sh VERSION
# VERSION must match the version in the built wheel's filename.
VERSION=$1
if [ -z "$VERSION" ]; then
  echo "usage: $0 VERSION" >&2
  exit 1
fi

# Build Linux Python3.7 wheel.
# Work from the Ax root directory (the parent of the directory holding this
# script), regardless of the directory the script is invoked from.
cd "$(dirname "$0")/.." || exit
/opt/python/cp37-cp37m/bin/pip3.7 install numpy
/opt/python/cp37-cp37m/bin/python3.7 setup.py bdist_wheel

# Convert to manylinux
cd dist || exit
auditwheel repair ax_platform-"$VERSION"-cp37-cp37m-linux_x86_64.whl
# Replace the original wheel with the repaired one written to ./wheelhouse.
# (rm without -r leaves the wheelhouse directory itself in place.)
rm ./*
mv wheelhouse/* .
rm -rf wheelhouse
21 |
--------------------------------------------------------------------------------
/ax/modelbridge/tests/test_centered_unit_x_transform.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.modelbridge.tests.test_unit_x_transform import UnitXTransformTest
8 | from ax.modelbridge.transforms.centered_unit_x import CenteredUnitX
9 |
10 |
class CenteredUnitXTransformTest(UnitXTransformTest):
    """Run the inherited ``UnitXTransformTest`` tests against ``CenteredUnitX``.

    Only the class attributes below differ; presumably the parent tests read
    them to build the transform and check its outputs.
    """

    transform_class = CenteredUnitX
    expected_c_dicts = [
        {"x": -0.5, "y": 0.5},
        {"x": -0.5, "a": 1.0},
    ]
    expected_c_bounds = [0.0, 1.5]
16 |
--------------------------------------------------------------------------------
/sphinx/source/global_stopping.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax.global_stopping
5 | ===================================
6 |
7 | .. automodule:: ax.global_stopping
8 | .. currentmodule:: ax.global_stopping
9 |
10 | Strategies
11 | ----------
12 |
13 | Base Strategies
14 | ~~~~~~~~~~~~~~~
15 |
16 | .. automodule:: ax.global_stopping.strategies.base
17 | :members:
18 | :undoc-members:
19 | :show-inheritance:
20 |
21 | `ImprovementGlobalStoppingStrategy`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
23 |
24 | .. automodule:: ax.global_stopping.strategies.improvement
25 | :members:
26 | :undoc-members:
27 | :show-inheritance:
28 |
--------------------------------------------------------------------------------
/sphinx/source/index.rst:
--------------------------------------------------------------------------------
1 | .. Ax documentation index file, created by
2 | sphinx-quickstart on Sat Mar 2 00:03:32 2019.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | API Reference
7 | ===============================================
8 |
9 | .. toctree::
10 | :maxdepth: 2
11 |
12 | ax
13 | benchmark
14 | core
   exceptions
   global_stopping
   metrics
17 | modelbridge
18 | models
19 | plot
20 | runners
21 | service
22 | storage
23 | utils
24 |
25 |
26 | Indices and tables
27 | ==================
28 |
29 | * :ref:`genindex`
30 | * :ref:`modindex`
31 | * :ref:`search`
32 |
--------------------------------------------------------------------------------
/ax/storage/sqa_store/delete.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from ax.storage.sqa_store.db import session_scope
7 | from ax.storage.sqa_store.sqa_classes import SQAExperiment
8 |
9 |
def delete_experiment(exp_name: str) -> None:
    """Delete the experiment with the given name from the database.

    Args:
        exp_name: Name of the experiment to delete.

    Raises:
        ValueError: If no experiment with that name exists.
    """
    with session_scope() as session:
        exp = session.query(SQAExperiment).filter_by(name=exp_name).one_or_none()
        if exp is None:
            # ``session.delete(None)`` would otherwise fail with an opaque
            # SQLAlchemy "not mapped" error instead of naming the problem.
            raise ValueError(f"No experiment named '{exp_name}' found to delete.")
        session.delete(exp)
        session.flush()
20 |
--------------------------------------------------------------------------------
/scripts/build_ax.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
# Run this script from the scripts/ directory: both the Docker mount below
# and the local build use the parent directory as the Ax root.

# ManyLinux build
# Mount top-level Ax directory as Ax-main.
docker container run --mount type=bind,source="$(pwd)/../",target=/Ax-main -it quay.io/pypa/manylinux2010_x86_64

# MANUAL STEP FOR NOW
# Now, in the Docker container, cd Ax-main/scripts and MANUALLY RUN
#   ./docker_install.sh <VERSION>

# LOCAL BUILD
# Requires Python 3.7 installed locally, and on path
cd .. || exit
pip3.7 install numpy
python3.7 setup.py bdist_wheel

# Final PyPI Upload
twine upload dist/*
22 |
--------------------------------------------------------------------------------
/ax/metrics/hartmann6.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | from ax.metrics.noisy_function import NoisyFunctionMetric
9 | from ax.utils.common.typeutils import checked_cast
10 | from ax.utils.measurement.synthetic_functions import aug_hartmann6, hartmann6
11 |
12 |
class Hartmann6Metric(NoisyFunctionMetric):
    """Noisy function metric whose underlying value is ``hartmann6(x)``."""

    def f(self, x: np.ndarray) -> float:
        """Evaluate ``hartmann6`` at ``x``, returned via ``checked_cast(float, ...)``."""
        value = hartmann6(x)
        return checked_cast(float, value)
16 |
17 |
class AugmentedHartmann6Metric(NoisyFunctionMetric):
    """Noisy function metric whose underlying value is ``aug_hartmann6(x)``."""

    def f(self, x: np.ndarray) -> float:
        """Evaluate ``aug_hartmann6`` at ``x``, returned via ``checked_cast``."""
        value = aug_hartmann6(x)
        return checked_cast(float, value)
21 |
--------------------------------------------------------------------------------
/ax/utils/report/resources/sufficient_statistic.html:
--------------------------------------------------------------------------------
1 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | {% for cell in cells %}
14 | {% if cell.html %}
15 |
16 |
{{cell.caption}}
17 |
18 |
19 | {{cell.html}}
20 |
21 | {% endif %}
22 | {% endfor %}
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/ax/utils/testing/doctest.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import doctest
8 | import unittest
9 |
10 | from ax.utils.common import testutils
11 | from ax.utils.testing.manifest import ModuleInfo, populate_test_class
12 |
13 |
def run_doctests(t: unittest.TestCase, m: ModuleInfo) -> None:
    """Run the doctests of ``m.module`` and fail ``t`` if any example fails.

    Args:
        t: Test case through which the failure is reported.
        m: Module info whose ``module`` is passed to ``doctest.testmod``.
    """
    results = doctest.testmod(m.module, optionflags=doctest.ELLIPSIS)
    # Use the TestCase assertion instead of a bare ``assert``: it is not
    # stripped under ``python -O`` and its message says how many failed.
    t.assertEqual(
        results.failed,
        0,
        f"{results.failed} of {results.attempted} doctest example(s) failed",
    )
17 |
18 |
@populate_test_class(run_doctests)
class TestDocTests(testutils.TestCase):
    """
    Doctest runner for the main library.

    ``populate_test_class`` is given ``run_doctests``; presumably it adds one
    test per library module. This is a support file for `ae_unittest`.
    """
26 |
--------------------------------------------------------------------------------
/website/static/js/mathjax.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
// MathJax configuration for the website (tex2jax preprocessor settings).
window.MathJax = {
  tex2jax: {
    inlineMath: [['$', '$'], ['\\(', '\\)']],
    displayMath: [['$$', '$$'], ['\\[', '\\]']],
    processEscapes: true,
    processEnvironments: true,
  },
  // Center justify equations in code and markdown cells. Note that this
  // doesn't work with Plotly though, hence the !important declaration
  // below.
  displayAlign: 'center',
  'HTML-CSS': {
    styles: {
      '.MathJax_Display': {margin: 0, 'text-align': 'center !important'},
    },
    linebreaks: {automatic: true},
  },
};
26 |
--------------------------------------------------------------------------------
/ax/modelbridge/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # flake8: noqa F401
8 | from ax.modelbridge import transforms
9 | from ax.modelbridge.base import ModelBridge
10 | from ax.modelbridge.factory import (
11 | get_factorial,
12 | get_GPEI,
13 | get_sobol,
14 | get_thompson,
15 | get_uniform,
16 | Models,
17 | )
18 | from ax.modelbridge.torch import TorchModelBridge
19 |
20 |
# Public API of ``ax.modelbridge``. Every name listed here must be bound by
# this module's imports, otherwise ``from ax.modelbridge import *`` raises
# AttributeError (``get_GPKG`` was listed but never imported).
__all__ = [
    "ModelBridge",
    "Models",
    "TorchModelBridge",
    "get_factorial",
    "get_GPEI",
    "get_sobol",
    "get_thompson",
    "get_uniform",
    "transforms",
]
33 |
--------------------------------------------------------------------------------
/website/static/img/database-solid.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
12 |
--------------------------------------------------------------------------------
/ax/models/tests/test_base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.models.base import Model
8 | from ax.utils.common.testutils import TestCase
9 |
10 |
class BaseModelTest(TestCase):
    def test_base_model(self):
        state = {"foo": "bar", "two": 3.0}
        base = Model()
        # Serializing and deserializing state are identity maps on the base.
        self.assertEqual(base.serialize_state(state), state)
        self.assertEqual(base.deserialize_state(state), state)
        # A freshly constructed base model reports an empty state.
        self.assertEqual(base._get_state(), {})
        # The base class raises for feature importances.
        with self.assertRaisesRegex(
            NotImplementedError, "Feature importance not available"
        ):
            base.feature_importances()
22 |
--------------------------------------------------------------------------------
/ax/storage/sqa_store/structs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Callable, NamedTuple, Optional
8 |
9 | from ax.storage.sqa_store.decoder import Decoder
10 | from ax.storage.sqa_store.encoder import Encoder
11 | from ax.storage.sqa_store.sqa_config import SQAConfig
12 |
13 |
class DBSettings(NamedTuple):
    """
    Defines behavior for loading/saving experiment to/from db.
    Either creator or url must be specified as a way to connect to the SQL db.
    NOTE(review): that requirement is not validated by this class; both
    fields default to None.
    """

    # Optional callable used to connect to the SQL db (alternative to `url`).
    creator: Optional[Callable] = None
    # These defaults are evaluated once, when the class is defined, so every
    # DBSettings that doesn't pass its own decoder/encoder shares the same
    # Decoder/Encoder instances.
    decoder: Decoder = Decoder(config=SQAConfig())
    encoder: Encoder = Encoder(config=SQAConfig())
    # Optional database URL (alternative to `creator`).
    url: Optional[str] = None
24 |
--------------------------------------------------------------------------------
/sphinx/source/runners.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax.runners
5 | ===================================
6 |
7 | .. automodule:: ax.runners
8 | .. currentmodule:: ax.runners
9 |
10 | BoTorch Test Problem
~~~~~~~~~~~~~~~~~~~~
12 |
13 | .. automodule:: ax.runners.botorch_test_problem
14 | :members:
15 | :undoc-members:
16 | :show-inheritance:
17 |
18 |
19 | Synthetic Runner
20 | ~~~~~~~~~~~~~~~~
21 |
22 | .. automodule:: ax.runners.synthetic
23 | :members:
24 | :undoc-members:
25 | :show-inheritance:
26 |
27 | Simulated Backend Runner
~~~~~~~~~~~~~~~~~~~~~~~~
29 |
30 | .. automodule:: ax.runners.simulated_backend
31 | :members:
32 | :undoc-members:
33 | :show-inheritance:
34 |
35 | TorchX Runner
36 | ~~~~~~~~~~~~~~~~
37 |
38 | .. automodule:: ax.runners.torchx
39 | :members:
40 | :undoc-members:
41 | :show-inheritance:
42 |
--------------------------------------------------------------------------------
/ax/utils/testing/test_init_files.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | from glob import glob
9 |
10 | from ax.utils.common.testutils import TestCase
11 |
12 |
class InitTest(TestCase):
    def testInitFiles(self) -> None:
        """Check that every directory holding Python sources has an ``__init__.py``.

        __init__.py files are necessary when not using buck targets. The walk
        starts at ``./ax/ax``, relative to the current working directory.
        """
        for root, _dirs, files in os.walk("./ax/ax", topdown=False):
            # Only directories whose subtree contains Python files must be packages.
            if glob(f"{root}/**/*.py", recursive=True):
                with self.subTest(root):
                    self.assertIn(
                        "__init__.py",
                        files,
                        f"directory {root} does not contain an __init__.py file",
                    )
23 |
--------------------------------------------------------------------------------
/ax/utils/testing/pyre_strict.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import unittest
8 |
9 | from ax.utils.common import testutils
10 | from ax.utils.testing.manifest import ModuleInfo, populate_test_class
11 |
12 |
def test_pyre_strict(t: unittest.TestCase, m: ModuleInfo) -> None:
    """Fail ``t`` unless module ``m`` contains a strict-typing header line.

    A line consisting exactly of ``# pyre-strict`` or ``# no-strict-types``
    satisfies the check.

    Args:
        t: Test case used to report the failure.
        m: Module whose source file (``m.file``) is scanned; ``m.path`` is used
            in the failure message.
    """
    accepted = ("# pyre-strict", "# no-strict-types")
    # Source files are UTF-8 (PEP 3120); don't depend on the locale encoding.
    with open(m.file, encoding="utf-8") as fd:
        for line in fd:
            # Strip the newline so a header on the final, unterminated line
            # still matches.
            if line.rstrip("\n") in accepted:
                return
    t.fail(f"{m.path}'s header should contain '# pyre-strict'")
19 |
20 |
@populate_test_class(test_pyre_strict)
class TestPyreStrict(testutils.TestCase):
    """
    Test that all the files are marked pyre strict.
    """

    # The body is intentionally empty; the `populate_test_class` decorator is
    # expected to add the test methods built from `test_pyre_strict`.
    pass
28 |
--------------------------------------------------------------------------------
/ax/plot/tests/test_parallel_coordinates.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | from ax.plot.base import AxPlotConfig, AxPlotTypes
8 | from ax.plot.parallel_coordinates import plot_parallel_coordinates
9 | from ax.utils.common.testutils import TestCase
10 | from ax.utils.testing.core_stubs import get_branin_experiment
11 |
12 |
class ParallelCoordinatesTest(TestCase):
    def testParallelCoordinates(self):
        # Run the first batch trial of a Branin experiment so there is data to plot.
        experiment = get_branin_experiment(with_batch=True)
        first_trial = experiment.trials[0]
        first_trial.run()

        # Building the plot must succeed and yield a generic plot config.
        config = plot_parallel_coordinates(experiment=experiment)
        self.assertIsInstance(config, AxPlotConfig)
        self.assertEqual(config.plot_type, AxPlotTypes.GENERIC)
23 |
--------------------------------------------------------------------------------
/ax/utils/common/tests/test_serialization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import NamedTuple
8 |
9 | from ax.utils.common.serialization import named_tuple_to_dict
10 | from ax.utils.common.testutils import TestCase
11 |
12 |
class TestSerializationUtils(TestCase):
    def test_named_tuple_to_dict(self):
        class Foo(NamedTuple):
            x: int
            y: str

        foo = Foo(x=5, y="g")
        foo_as_dict = {"x": 5, "y": "g"}
        # A bare named tuple becomes a plain dict of its fields.
        self.assertEqual(named_tuple_to_dict(foo), foo_as_dict)

        # Named tuples nested inside dicts and lists are converted too, while
        # ordinary tuples stay tuples.
        nested = {"x": 5, "foo": foo, "y": [(1, True), foo]}
        expected = {"x": 5, "foo": foo_as_dict, "y": [(1, True), foo_as_dict]}
        self.assertEqual(named_tuple_to_dict(nested), expected)
27 |
--------------------------------------------------------------------------------
/ax/metrics/branin.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | from ax.metrics.noisy_function import NoisyFunctionMetric
9 | from ax.utils.common.typeutils import checked_cast
10 | from ax.utils.measurement.synthetic_functions import aug_branin, branin
11 |
12 |
class BraninMetric(NoisyFunctionMetric):
    """Noisy metric whose underlying function is the 2-D Branin function."""

    def f(self, x: np.ndarray) -> float:
        # `x` is unpacked into exactly two coordinates.
        first, second = x
        value = branin(x1=first, x2=second)
        return checked_cast(float, value)
17 |
18 |
class NegativeBraninMetric(BraninMetric):
    """Branin metric with its sign flipped."""

    def f(self, x: np.ndarray) -> float:
        return -super().f(x)
23 |
24 |
class AugmentedBraninMetric(NoisyFunctionMetric):
    """Noisy metric whose underlying function is `aug_branin` on the whole point."""

    def f(self, x: np.ndarray) -> float:
        value = aug_branin(x)
        return checked_cast(float, value)
28 |
--------------------------------------------------------------------------------
/website/static/img/th-large-solid.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
14 |
--------------------------------------------------------------------------------
/ax/benchmark/tests/test_torchvision_problem_storage.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionBenchmarkProblem
7 | from ax.storage.json_store.decoder import object_from_json
8 | from ax.storage.json_store.encoder import object_to_json
9 | from ax.utils.common.testutils import TestCase
10 |
11 |
class TestProblems(TestCase):
    def test_encode_decode(self):
        # An MNIST torchvision problem must survive a JSON round trip unchanged.
        problem = PyTorchCNNTorchvisionBenchmarkProblem.from_dataset_name(
            name="MNIST"
        )
        encoded = object_to_json(problem)
        decoded = object_from_json(encoded)
        self.assertEqual(problem, decoded)
26 |
--------------------------------------------------------------------------------
/ax/exceptions/constants.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
# Message template for when no Thompson Sampling arm exceeds `min_weight`;
# fill via `.format(min_weight=..., max_weight=...)`.
TS_MIN_WEIGHT_ERROR = """
No arms generated by Thompson Sampling had weight > min_weight.
The minimum weight required is {min_weight:2.4}, and the
maximum weight of any arm generated is {max_weight:2.4}.
"""

# Message stating that fewer than 1% of Thompson Sampling samples contained a
# feasible arm.
TS_NO_FEASIBLE_ARMS_ERROR = """
Less than 1% of samples have a feasible arm.
Check your outcome constraints.
"""

# Explanatory prefix for Cholesky errors; it ends with "Original error: " so the
# underlying error text can be appended directly.
CHOLESKY_ERROR_ANNOTATION = (
    "Cholesky errors typically occur when the same or very similar "
    "arms are suggested repeatedly. This can mean the model has "
    "already converged and you should avoid running further trials. "
    "It will also help to convert integer or categorical parameters "
    "to float ranges where reasonable.\nOriginal error: "
)
25 |
--------------------------------------------------------------------------------
/scripts/wheels_build.ps1:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
# Build an Ax wheel with the given Python interpreter and optionally upload it.
#   -pypath: Python executable used for the build (default: python3)
#   -upload: also upload the artifacts in dist/ to PyPI via twine
param(
    [string]$pypath = $("python3"),
    [switch]$upload = $false
)

# Native executables do not honor $ErrorActionPreference, so check their exit
# codes explicitly and stop at the first failure.
function Assert-LastExitCode([string]$step) {
    if ($LASTEXITCODE -ne 0) {
        throw "$step failed with exit code $LASTEXITCODE"
    }
}

Write-Host "Py Version:"
& $pypath --version
Assert-LastExitCode "python --version"

# Jump into the Ax repo folder; always return to the caller's directory.
Push-Location (Join-Path $PSScriptRoot "..")
try {
    # Install or upgrade all the dependencies in a single pass. The PyPI name is
    # "scikit-learn"; the deprecated "sklearn" package refuses to install.
    $deps = @(
        "botorch", "jinja2", "pandas", "scipy", "simplejson",
        "scikit-learn", "plotly", "numpy", "twine", "wheel"
    )
    & $pypath -m pip install --upgrade @deps
    Assert-LastExitCode "pip install"

    # Let's build
    & $pypath ./setup.py bdist_wheel
    Assert-LastExitCode "bdist_wheel"

    # Validate the build using the twine installed for this interpreter.
    & $pypath -m twine check dist/*
    Assert-LastExitCode "twine check"

    # Final PyPI Upload
    if ($upload) {
        Write-Host "Uploading"
        & $pypath -m twine upload dist/*
        Assert-LastExitCode "twine upload"
    }
}
finally {
    # Done!
    Pop-Location
}
36 |
--------------------------------------------------------------------------------
/ax/benchmark/tests/test_problems.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from ax.benchmark.benchmark_result import AggregatedBenchmarkResult
7 | from ax.benchmark.problems.registry import (
8 | BENCHMARK_PROBLEM_REGISTRY,
9 | get_problem_and_baseline,
10 | )
11 | from ax.utils.common.testutils import TestCase
12 |
13 |
class TestProblems(TestCase):
    def test_load_baselines(self):
        # Make sure the JSON parsing of every registered baseline succeeded.
        for name in BENCHMARK_PROBLEM_REGISTRY:
            if "MNIST" in name:
                continue  # Skip these as they cause the test to take a long time

            # subTest reports which problem failed and keeps checking the rest.
            with self.subTest(problem_name=name):
                problem, baseline = get_problem_and_baseline(problem_name=name)

                self.assertIsInstance(baseline, AggregatedBenchmarkResult)
                self.assertIn(problem.name, baseline.name)
26 |
--------------------------------------------------------------------------------
/ax/modelbridge/tests/test_transform_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.modelbridge.transforms.utils import ClosestLookupDict
8 | from ax.utils.common.testutils import TestCase
9 |
10 |
class TransformUtilsTest(TestCase):
    def test_closest_lookup_dict(self):
        # Looking anything up in an empty mapping is an error.
        empty = ClosestLookupDict()
        with self.assertRaises(RuntimeError):
            empty[0]

        keys = (1.0, 2, 4)
        values = ("a", "b", "c")
        lookup = ClosestLookupDict(zip(keys, values))
        # Stored keys return their own values.
        for key, value in zip(keys, values):
            self.assertEqual(lookup[key], value)
        # Other numeric queries resolve to the nearest stored key, including
        # queries below and above the stored range.
        for query, nearest in ((2.5, "b"), (0, "a"), (6, "c")):
            self.assertEqual(lookup[query], nearest)
        # Assigning under a non-numeric key is rejected.
        with self.assertRaises(ValueError):
            lookup["str_key"] = 3
28 |
--------------------------------------------------------------------------------
/ax/metrics/tests/test_tensorboard.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from unittest import mock
8 |
9 | import numpy as np
10 | import pandas as pd
11 | from ax.metrics.tensorboard import TensorboardCurveMetric
12 | from ax.utils.common.testutils import TestCase
13 |
14 |
class TensorboardCurveMetricTest(TestCase):
    def test_GetCurvesFromIds(self):
        def fake_get_tb_from_posix(path):
            # Each id yields a two-point curve whose values equal int(id).
            return pd.Series([int(path)] * 2)

        target = "ax.metrics.tensorboard.get_tb_from_posix"
        with mock.patch(target, side_effect=fake_get_tb_from_posix):
            curves = TensorboardCurveMetric.get_curves_from_ids(["1", "2"])
            self.assertEqual(len(curves), 2)
            for curve_id, expected in (("1", [1, 1]), ("2", [2, 2])):
                self.assertTrue(
                    np.array_equal(curves[curve_id].values, np.array(expected))
                )
26 |
--------------------------------------------------------------------------------
/ax/benchmark/methods/choose_generation_strategy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from ax.benchmark.benchmark_method import BenchmarkMethod
7 | from ax.benchmark.benchmark_problem import BenchmarkProblem
8 | from ax.modelbridge.dispatch_utils import choose_generation_strategy
9 | from ax.service.scheduler import SchedulerOptions
10 |
11 |
def get_choose_generation_strategy_method(
    problem: BenchmarkProblem, num_trials: int = 30
) -> BenchmarkMethod:
    """Wrap the generation strategy Ax picks automatically as a benchmark method.

    Args:
        problem: Problem whose search space and optimization config are handed
            to `choose_generation_strategy`; its name is embedded in the method
            name.
        num_trials: Trial budget given to `choose_generation_strategy` and used
            as the scheduler's `total_trials`.

    Returns:
        A `BenchmarkMethod` named ``ChooseGenerationStrategy::<problem name>``.
    """
    strategy = choose_generation_strategy(
        search_space=problem.search_space,
        optimization_config=problem.optimization_config,
        num_trials=num_trials,
    )
    method_name = f"ChooseGenerationStrategy::{problem.name}"
    options = SchedulerOptions(total_trials=num_trials)
    return BenchmarkMethod(
        name=method_name,
        generation_strategy=strategy,
        scheduler_options=options,
    )
26 |
--------------------------------------------------------------------------------
/ax/core/map_metric.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from __future__ import annotations
8 |
9 | from typing import Type
10 |
11 | from ax.core.map_data import MapData
12 | from ax.core.metric import Metric
13 |
14 |
class MapMetric(Metric):
    """Base class for metrics whose data comes back as `MapData`.

    Subclasses normally override `fetch_trial_data`, which specifies how the
    metric is retrieved for a given trial. The result must be a `MapData`
    object, which needs at minimum the columns listed at
    https://ax.dev/api/_modules/ax/core/data.html#Data.required_columns

    Attributes:
        lower_is_better: Flag for metrics which should be minimized.
        properties: Properties specific to a particular metric.
    """

    # Data container type associated with this metric class (presumably used
    # by the `Metric` base when building fetched data — confirm).
    data_constructor: Type[MapData] = MapData
30 |
--------------------------------------------------------------------------------
/ax/modelbridge/tests/test_rounding.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | from ax.modelbridge.transforms.rounding import (
9 | randomized_onehot_round,
10 | strict_onehot_round,
11 | )
12 | from ax.utils.common.testutils import TestCase
13 |
14 |
class RoundingTest(TestCase):
    # Intentionally no setUp override, so the base TestCase.setUp always runs.

    def testOneHotRound(self):
        # Strict rounding of [0.1, 0.5, 0.3] marks only the largest entry.
        self.assertTrue(
            np.allclose(
                strict_onehot_round(np.array([0.1, 0.5, 0.3])), np.array([0, 1, 0])
            )
        )
        # One item should be set to one at random.
        rounded = randomized_onehot_round(np.array([0.0, 0.0, 0.0]))
        self.assertEqual(
            np.count_nonzero(np.isclose(rounded, np.array([1, 1, 1]))),
            1,
        )
35 |
--------------------------------------------------------------------------------
/ax/early_stopping/strategies/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.early_stopping.strategies.base import (
8 | BaseEarlyStoppingStrategy,
9 | EarlyStoppingTrainingData,
10 | ModelBasedEarlyStoppingStrategy,
11 | )
12 | from ax.early_stopping.strategies.logical import (
13 | AndEarlyStoppingStrategy,
14 | LogicalEarlyStoppingStrategy,
15 | OrEarlyStoppingStrategy,
16 | )
17 | from ax.early_stopping.strategies.percentile import PercentileEarlyStoppingStrategy
18 | from ax.early_stopping.strategies.threshold import ThresholdEarlyStoppingStrategy
19 |
20 |
# Public API of `ax.early_stopping.strategies`: every name here is imported above
# from the base, logical, percentile and threshold submodules.
__all__ = [
    "BaseEarlyStoppingStrategy",
    "EarlyStoppingTrainingData",
    "ModelBasedEarlyStoppingStrategy",
    "PercentileEarlyStoppingStrategy",
    "ThresholdEarlyStoppingStrategy",
    "AndEarlyStoppingStrategy",
    "OrEarlyStoppingStrategy",
    "LogicalEarlyStoppingStrategy",
]
31 |
--------------------------------------------------------------------------------
/ax/storage/sqa_store/timestamp.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import datetime
8 | from typing import Optional
9 |
10 | from sqlalchemy.engine.interfaces import Dialect
11 | from sqlalchemy.types import Integer, TypeDecorator
12 |
13 |
class IntTimestamp(TypeDecorator):
    """Persist ``datetime`` values in an INTEGER column as whole POSIX seconds.

    Writing truncates to whole seconds via ``int()``; reading rebuilds the value
    with ``datetime.fromtimestamp``. ``None`` passes through in both directions.
    """

    impl = Integer
    cache_ok = True

    # pyre-fixme[15]: `process_bind_param` overrides method defined in
    # `TypeDecorator` inconsistently.
    def process_bind_param(
        self, value: Optional[datetime.datetime], dialect: Dialect
    ) -> Optional[int]:
        if value is not None:
            seconds = value.timestamp()
            return int(seconds)
        return None  # pragma: no cover

    def process_result_value(
        self, value: Optional[int], dialect: Dialect
    ) -> Optional[datetime.datetime]:
        if value is None:
            return None
        return datetime.datetime.fromtimestamp(value)
32 |
--------------------------------------------------------------------------------
/ax/modelbridge/tests/test_base_transform.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from unittest.mock import MagicMock
8 |
9 | from ax.modelbridge.transforms.base import Transform
10 | from ax.utils.common.testutils import TestCase
11 |
12 |
class TransformsTest(TestCase):
    def testIdentityTransform(self):
        # The base Transform must hand back every input untouched: no calls on
        # the argument and each result equal to it.
        transform = Transform(MagicMock(), MagicMock(), MagicMock())
        arg = MagicMock()
        outputs = [
            transform.transform_search_space(arg),
            transform.transform_optimization_config(arg, arg, arg),
            transform.transform_observation_features(arg),
            transform.transform_observation_data(arg, arg),
            transform.untransform_observation_features(arg),
            transform.untransform_observation_data(arg, arg),
        ]
        self.assertEqual(len(arg.mock_calls), 0)
        for output in outputs:
            self.assertEqual(output, arg)
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Meta Platforms, Inc. and affiliates.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/ax/utils/common/tests/test_docutils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.utils.common.docutils import copy_doc
8 | from ax.utils.common.testutils import TestCase
9 |
10 |
# Fixture whose docstring `copy_doc` copies in the tests of this module; the
# docstring text is asserted verbatim, so it must not change.
def has_doc():
    """I have a docstring"""
13 |
14 |
# Fixture deliberately lacking a docstring; `copy_doc` is expected to raise
# ValueError when asked to copy from it.
def has_no_doc():
    pass
17 |
18 |
class TestDocUtils(TestCase):
    def test_transfer_doc(self):
        # The decorated function receives the source function's docstring.
        @copy_doc(has_doc)
        def inherits_doc():
            pass

        self.assertEqual(inherits_doc.__doc__, "I have a docstring")

    def test_fail_when_already_has_doc(self):
        # An existing docstring must not be overwritten.
        with self.assertRaises(ValueError):

            @copy_doc(has_doc)
            def inherits_doc():
                """I already have a doc string"""

    def test_fail_when_no_doc_to_copy(self):
        # A source function without a docstring cannot be copied from.
        with self.assertRaises(ValueError):

            @copy_doc(has_no_doc)
            def f():
                pass
41 |
--------------------------------------------------------------------------------
/ax/storage/json_store/load.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import json
8 | from typing import Any, Callable, Dict, Type
9 |
10 | from ax.core.experiment import Experiment
11 | from ax.storage.json_store.decoder import object_from_json
12 | from ax.storage.json_store.registry import (
13 | CORE_CLASS_DECODER_REGISTRY,
14 | CORE_DECODER_REGISTRY,
15 | )
16 |
17 |
def load_experiment(
    filepath: str,
    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
    class_decoder_registry: Dict[
        str, Callable[[Dict[str, Any]], Any]
    ] = CORE_CLASS_DECODER_REGISTRY,
) -> Experiment:
    """Load an experiment from a JSON file.

    1) Read and parse the file.
    2) Convert the resulting dictionary to an Ax experiment instance.

    Args:
        filepath: Path of the JSON file to read.
        decoder_registry: Mapping from serialized type names to classes, passed
            to `object_from_json`.
        class_decoder_registry: Mapping from serialized type names to decoding
            callables, passed to `object_from_json`.

    Returns:
        The object decoded by `object_from_json` (expected to be an Experiment).
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale
    # encoding, which differs e.g. on Windows.
    with open(filepath, "r", encoding="utf-8") as file:
        json_experiment = json.load(file)
    return object_from_json(
        json_experiment, decoder_registry, class_decoder_registry
    )
35 |
--------------------------------------------------------------------------------
/ax/models/tests/test_alebo_initializer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | from ax.models.random.alebo_initializer import ALEBOInitializer
9 | from ax.utils.common.testutils import TestCase
10 |
11 |
class ALEBOSobolTest(TestCase):
    def testALEBOSobolModel(self):
        B = np.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])
        # Orthogonal projector onto the row space of B.
        projector = np.linalg.pinv(B) @ B
        initializer = ALEBOInitializer(B=B)
        self.assertTrue(np.allclose(projector, initializer.Q))

        # Generated points have the right shape, stay inside the box, and are
        # fixed by the projector, i.e. they lie in the subspace.
        points, _weights = initializer.gen(5, bounds=[(-1.0, 1.0)] * 3)
        self.assertEqual(points.shape, (5, 3))
        self.assertTrue(points.min() >= -1.0)
        self.assertTrue(points.max() <= 1.0)
        self.assertTrue(np.allclose(projector @ points.T, points.T))

        # With nsamp=1, asking for two points raises ValueError.
        initializer = ALEBOInitializer(B=B, nsamp=1)
        with self.assertRaises(ValueError):
            initializer.gen(2, bounds=[(-1.0, 1.0)] * 3)
/ax/benchmark/benchmark_method.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from dataclasses import dataclass
7 |
8 | from ax.exceptions.core import UserInputError
9 | from ax.modelbridge.generation_strategy import GenerationStrategy
10 | from ax.service.utils.scheduler_options import SchedulerOptions
11 | from ax.utils.common.base import Base
12 |
13 |
@dataclass(frozen=True)
class BenchmarkMethod(Base):
    """A benchmark method expressed as an Ax generation strategy plus scheduler
    options.

    The generation strategy says which models to use when; the scheduler options
    carry extra execution settings such as maximum parallelism or early-stopping
    configuration. ``scheduler_options.total_trials`` must be set.
    """

    name: str
    generation_strategy: GenerationStrategy
    scheduler_options: SchedulerOptions

    def __post_init__(self) -> None:
        # A benchmark needs a finite trial budget.
        if self.scheduler_options.total_trials is not None:
            return
        raise UserInputError(
            "SchedulerOptions.total_trials may not be None in BenchmarkMethod."
        )
30 |
--------------------------------------------------------------------------------
/sphinx/source/early_stopping.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax.early_stopping
5 | ===================================
6 |
7 | .. automodule:: ax.early_stopping
8 | .. currentmodule:: ax.early_stopping
9 |
10 | Strategies
11 | ----------
12 |
13 | Base Strategies
14 | ~~~~~~~~~~~~~~~
15 |
16 | .. automodule:: ax.early_stopping.strategies.base
17 | :members:
18 | :undoc-members:
19 | :show-inheritance:
20 |
21 | Logical Strategies
22 | ~~~~~~~~~~~~~~~~~~
23 |
24 | .. automodule:: ax.early_stopping.strategies.logical
25 | :members:
26 | :undoc-members:
27 | :show-inheritance:
28 |
29 | `PercentileEarlyStoppingStrategy`
30 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
31 |
32 | .. automodule:: ax.early_stopping.strategies.percentile
33 | :members:
34 | :undoc-members:
35 | :show-inheritance:
36 |
37 | `ThresholdEarlyStoppingStrategy`
38 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
39 |
40 | .. automodule:: ax.early_stopping.strategies.threshold
41 | :members:
42 | :undoc-members:
43 | :show-inheritance:
44 |
45 | Utils
46 | -----
47 |
48 | .. automodule:: ax.early_stopping.utils
49 | :members:
50 | :undoc-members:
51 | :show-inheritance:
52 |
--------------------------------------------------------------------------------
/sphinx/source/exceptions.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax.exceptions
5 | ===================================
6 |
7 | .. automodule:: ax.exceptions
8 | .. currentmodule:: ax.exceptions
9 |
10 |
11 | Constants
12 | ~~~~~~~~~~~~~~~~~~
13 |
14 | .. automodule:: ax.exceptions.constants
15 | :members:
16 | :undoc-members:
17 | :show-inheritance:
18 |
19 |
20 | Core
21 | ~~~~~~~~~~~~~~~~~~
22 |
23 | .. automodule:: ax.exceptions.core
24 | :members:
25 | :undoc-members:
26 | :show-inheritance:
27 |
28 | Data
29 | ~~~~~~~~~~~~~~~~~~
30 |
31 | .. automodule:: ax.exceptions.data_provider
32 | :members:
33 | :undoc-members:
34 | :show-inheritance:
35 |
36 | Generation Strategy
37 | ~~~~~~~~~~~~~~~~~~~~
38 |
39 | .. automodule:: ax.exceptions.generation_strategy
40 | :members:
41 | :undoc-members:
42 | :show-inheritance:
43 |
44 | Model
45 | ~~~~~~~~~~~~~~~~~~
46 |
47 | .. automodule:: ax.exceptions.model
48 | :members:
49 | :undoc-members:
50 | :show-inheritance:
51 |
52 | Storage
53 | ~~~~~~~~~~~~~~~~~~
54 |
55 | .. automodule:: ax.exceptions.storage
56 | :members:
57 | :undoc-members:
58 | :show-inheritance:
59 |
--------------------------------------------------------------------------------
/ax/models/tests/test_rembo_initializer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | from ax.models.random.rembo_initializer import REMBOInitializer
9 | from ax.utils.common.testutils import TestCase
10 |
11 |
class REMBOInitializerTest(TestCase):
    def testREMBOInitializerModel(self):
        def check_in_unit_box(points, shape):
            # Shape matches and every coordinate lies in [-1, 1].
            self.assertEqual(points.shape, shape)
            self.assertTrue(points.min() >= -1.0)
            self.assertTrue(points.max() <= 1.0)

        identity = np.eye(2, 2)
        A = np.vstack((identity, -identity))
        # Constructor stores the embedding and low-dimensional bounds as given.
        initializer = REMBOInitializer(A=A, bounds_d=[(-2, 2)] * 2)
        self.assertTrue(np.allclose(A, initializer.A))
        self.assertEqual(initializer.bounds_d, [(-2, 2), (-2, 2)])

        # Projecting up yields 4-D points inside the unit box.
        check_in_unit_box(initializer.project_up(5 * np.random.rand(3, 2)), (3, 4))

        # Generated points are 4-D and inside the unit box.
        generated, _weights = initializer.gen(3, bounds=[(-1.0, 1.0)] * 4)
        check_in_unit_box(generated, (3, 4))
31 |
--------------------------------------------------------------------------------
/ax/plot/css/base.css:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
.hovertext {
  opacity: 0.85;
}

.plot-menu div {
  margin-top: 5px;
  margin-bottom: 5px;
}

.plot-menu select {
  height: 30px;
  padding: 6px 10px;
  background-color: #fff;
  border: 1px solid #D1D1D1;
  border-radius: 4px;
  box-shadow: none;
  box-sizing: border-box;
  margin-left: 15px;
  margin-bottom: 0px;
}

.plot-menu select:focus {
  border: 1px solid #33C3F0;
  outline: 0;
}

/* Single background declaration: an earlier `transparent` value was dead,
   since the later `#fff` in the same rule always won. */
.plot-menu button {
  display: inline-block;
  height: 30px;
  padding: 0 10px;
  color: #555;
  text-align: center;
  font-size: 12px;
  font-weight: 600;
  text-transform: uppercase;
  text-decoration: none;
  white-space: nowrap;
  border-radius: 4px;
  border: 1px solid #bbb;
  cursor: pointer;
  box-sizing: border-box;
  background-color: #fff;
  margin-left: 15px;
}
53 |
--------------------------------------------------------------------------------
/ax/utils/report/tests/test_render.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.utils.common.testutils import TestCase
8 | from ax.utils.report.render import (
9 | h2_html,
10 | h3_html,
11 | link_html,
12 | list_item_html,
13 | p_html,
14 | render_report_elements,
15 | table_cell_html,
16 | table_heading_cell_html,
17 | table_html,
18 | table_row_html,
19 | unordered_list_html,
20 | )
21 |
22 |
class RenderTest(TestCase):
    def testRenderReportElements(self):
        # Build one of each HTML helper's output and render them together; the
        # test passes as long as rendering does not raise.
        text = "foobar"
        elements = [
            p_html(text),
            h2_html(text),
            h3_html(text),
            list_item_html(text),
            unordered_list_html(["foo", "bar"]),
            link_html("foo", "bar"),
            table_cell_html(text),
            table_cell_html(text, width="100px"),
            table_heading_cell_html(text),
            table_row_html(["foo", "bar"]),
            table_html(["foo", "bar"]),
        ]
        render_report_elements("test", elements)
39 |
--------------------------------------------------------------------------------
/ax/exceptions/storage.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | from ax.exceptions.core import AxError
9 |
10 |
class JSONDecodeError(AxError):
    """Signals a failure while decoding an Ax object from JSON."""
15 |
16 |
class JSONEncodeError(AxError):
    """Signals a failure while encoding an Ax object to JSON."""
21 |
22 |
class SQADecodeError(AxError):
    """Signals a failure while decoding an Ax object from its SQA representation."""
27 |
28 |
class SQAEncodeError(AxError):
    """Signals a failure while encoding an Ax object into its SQA representation."""
33 |
34 |
class ImmutabilityError(AxError):
    """Signals an attempt to modify an object that must not be updated."""
39 |
40 |
class IncorrectDBConfigurationError(AxError):
    """Raised when an attempt is made to save or load an object, but the
    current engine and session factory are set up incorrectly to process
    the call (e.g. the current session factory would connect to the wrong
    database for the call).
    """
49 |
--------------------------------------------------------------------------------
/ax/core/tests/test_map_metric.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.core.map_metric import MapMetric
8 | from ax.utils.common.testutils import TestCase
9 |
10 |
# Expected `str()` of a MapMetric named "m1".
METRIC_STRING = "MapMetric('m1')"


class MapMetricTest(TestCase):
    def setUp(self):
        # NOTE(review): overrides TestCase.setUp without calling super();
        # kept as-is so the test fixture behaves exactly as before.
        pass

    def testInit(self):
        rendered = str(MapMetric(name="m1", lower_is_better=False))
        self.assertEqual(rendered, METRIC_STRING)

    def testEq(self):
        baseline = MapMetric(name="m1", lower_is_better=False)
        # Same name and same direction -> equal.
        self.assertEqual(baseline, MapMetric(name="m1", lower_is_better=False))
        # Only the optimization direction differs -> not equal.
        self.assertNotEqual(baseline, MapMetric(name="m1", lower_is_better=True))

    def testClone(self):
        original = MapMetric(name="m1", lower_is_better=False)
        copy = original.clone()
        self.assertEqual(original, copy)

    def testSortable(self):
        first = MapMetric(name="m1", lower_is_better=False)
        second = MapMetric(name="m2", lower_is_better=False)
        self.assertLess(first, second)
38 |
--------------------------------------------------------------------------------
/ax/runners/synthetic.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Any, Dict, Iterable, Optional, Set
8 |
9 | from ax.core.base_trial import BaseTrial, TrialStatus
10 | from ax.core.runner import Runner
11 |
12 |
class SyntheticRunner(Runner):
    """Runner that deploys nothing.

    It only assigns each trial a deployment name and, optionally, echoes back
    a piece of dummy metadata for testing.
    """

    def __init__(self, dummy_metadata: Optional[str] = None) -> None:
        self.dummy_metadata = dummy_metadata

    def run(self, trial: BaseTrial) -> Dict[str, Any]:
        """Return run metadata with a ``name`` of ``<experiment name>_<index>``,
        or just ``<index>`` when the experiment has no name."""
        experiment = trial.experiment
        index_str = str(trial.index)
        if experiment.has_name:
            deployed_name = experiment.name + "_" + index_str
        else:
            deployed_name = index_str

        metadata = {"name": deployed_name}
        if self.dummy_metadata:
            # Only attached when a (truthy) dummy value was configured.
            metadata["dummy_metadata"] = self.dummy_metadata
        return metadata

    def poll_trial_status(
        self, trials: Iterable[BaseTrial]
    ) -> Dict[TrialStatus, Set[int]]:
        """Report every given trial as completed."""
        completed_indices = set()
        for trial in trials:
            completed_indices.add(trial.index)
        return {TrialStatus.COMPLETED: completed_indices}
39 |
--------------------------------------------------------------------------------
/ax/plot/tests/test_helper.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | from ax.plot.helper import arm_name_to_sort_key, extend_range
9 | from ax.utils.common.testutils import TestCase
10 |
11 |
class HelperTest(TestCase):
    def test_extend_range(self):
        # An inverted range is rejected.
        with self.assertRaises(ValueError):
            extend_range(lower=1, upper=-1)
        cases = [
            ({"lower": -1, "upper": 1}, (-1.2, 1.2)),
            ({"lower": -1, "upper": 0, "percent": 30}, (-1.3, 0.3)),
            ({"lower": 0, "upper": 1, "percent": 50}, (-0.5, 1.5)),
        ]
        for kwargs, expected_range in cases:
            self.assertEqual(extend_range(**kwargs), expected_range)

    def test_arm_name_to_sort_key(self):
        # Sorted descending by this key: names that are not purely numeric
        # ("control", "3_x", ...) come first, then numeric arm names in
        # natural (not lexicographic) order.
        scenarios = [
            (
                ["0_0", "1_10", "1_2", "10_0", "control"],
                ["control", "0_0", "1_2", "1_10", "10_0"],
            ),
            (
                ["0_0", "0", "1_10", "3_2_x", "3_x", "1_2", "control"],
                ["control", "3_x", "3_2_x", "0", "0_0", "1_2", "1_10"],
            ),
        ]
        for names, expected_order in scenarios:
            ordered = sorted(names, key=arm_name_to_sort_key, reverse=True)
            self.assertEqual(ordered, expected_order)
30 |
--------------------------------------------------------------------------------
/ax/storage/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import enum
8 |
9 | from ax.core.experiment import DataType # noqa F401
10 |
11 |
class DomainType(enum.Enum):
    """Class for enumerating domain types (fixed, range, choice)."""

    # NOTE(review): since this enum lives in ax/storage, these integer values
    # presumably identify the domain type in stored records. Treat them as a
    # persisted contract and do not renumber; confirm against the encoder.
    FIXED: int = 0
    RANGE: int = 1
    CHOICE: int = 2
18 |
19 |
class MetricIntent(enum.Enum):
    """Class for enumerating metric use types.

    Each member names a role a metric can play (objective, outcome
    constraint, tracking, ...). Each string value is the lower-cased member
    name.
    """

    # NOTE(review): the string values are presumably what storage records as
    # a metric's intent; changing them would likely break loading previously
    # saved experiments. Confirm against the SQA encoder/decoder.
    OBJECTIVE: str = "objective"
    MULTI_OBJECTIVE: str = "multi_objective"
    SCALARIZED_OBJECTIVE: str = "scalarized_objective"
    # Additional objective is not yet supported in Ax open-source.
    ADDITIONAL_OBJECTIVE: str = "additional_objective"
    OUTCOME_CONSTRAINT: str = "outcome_constraint"
    SCALARIZED_OUTCOME_CONSTRAINT: str = "scalarized_outcome_constraint"
    OBJECTIVE_THRESHOLD: str = "objective_threshold"
    TRACKING: str = "tracking"
32 |
33 |
class ParameterConstraintType(enum.Enum):
    """Class for enumerating parameter constraint types.

    Linear constraint is base type whereas other constraint types are
    special types of linear constraints.
    """

    # NOTE(review): integer codes presumably persisted by the storage layer;
    # do not renumber without a migration. Confirm.
    LINEAR: int = 0
    ORDER: int = 1
    SUM: int = 2
44 |
--------------------------------------------------------------------------------
/ax/plot/tests/test_diagnostic.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import plotly.graph_objects as go
8 | from ax.modelbridge.cross_validation import cross_validate
9 | from ax.modelbridge.registry import Models
10 | from ax.plot.base import AxPlotConfig
11 | from ax.plot.diagnostic import (
12 | interact_cross_validation,
13 | interact_cross_validation_plotly,
14 | )
15 | from ax.utils.common.testutils import TestCase
16 | from ax.utils.testing.core_stubs import get_branin_experiment
17 | from ax.utils.testing.mock import fast_botorch_optimize
18 |
19 |
class DiagnosticTest(TestCase):
    @fast_botorch_optimize
    def test_cross_validation(self):
        experiment = get_branin_experiment(with_batch=True)
        experiment.trials[0].run()
        model_bridge = Models.BOTORCH(
            # Model bridge kwargs
            experiment=experiment,
            data=experiment.fetch_data(),
        )
        cv_results = cross_validate(model_bridge)
        # Both the raw plotly figure and the Ax plot config must build.
        plotly_figure = interact_cross_validation_plotly(cv_results)
        self.assertIsInstance(plotly_figure, go.Figure)
        ax_config = interact_cross_validation(cv_results)
        self.assertIsInstance(ax_config, AxPlotConfig)
36 |
--------------------------------------------------------------------------------
/ax/utils/report/resources/base_template.html:
--------------------------------------------------------------------------------
1 |
4 |
5 |
6 | {{experiment_name}}
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | {% if headfoot %}
16 |
17 |
18 | {% if trial_index is none %}
19 | Report for {{experiment_name}}
20 | {% else %}
21 | Report for {{ batch_noun }} #{{ trial_index }} of
22 | {{experiment_name}}
23 | {% endif %}
24 |
25 |
Powered by Ax
26 |
27 | {% endif %}
28 |
29 | {% block content %}
30 | {% endblock %}
31 |
32 | {% if headfoot %}
33 |
34 | {% endif %}
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/website/static/img/dice-solid.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
17 |
--------------------------------------------------------------------------------
/ax/models/random/uniform.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional
8 |
9 | import numpy as np
10 | from ax.models.random.base import RandomModel
11 | from scipy.stats import uniform
12 |
13 |
class UniformGenerator(RandomModel):
    """Random generator drawing points from the standard uniform distribution.

    No surrogate model is involved, so ``fit`` and ``predict`` are not
    implemented here.

    Args:
        deduplicate: Forwarded to ``RandomModel``.
        seed: An optional seed value for the underlying PRNG.
    """

    def __init__(self, deduplicate: bool = False, seed: Optional[int] = None) -> None:
        super().__init__(deduplicate=deduplicate, seed=seed)
        # Dedicated NumPy PRNG, so a fixed seed gives reproducible draws.
        self._rs = np.random.RandomState(seed=seed)

    def _gen_samples(self, n: int, tunable_d: int) -> np.ndarray:
        """Draw samples with scipy's ``uniform`` (default loc=0, scale=1).

        Args:
            n: Number of samples to generate.
            tunable_d: Dimension of samples to generate.

        Returns:
            samples: An (n x d) array of random points.
        """
        sample_shape = (n, tunable_d)
        return uniform.rvs(size=sample_shape, random_state=self._rs)  # pyre-ignore
41 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/hd_embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from dataclasses import asdict
7 |
8 | from ax.benchmark.benchmark_problem import BenchmarkProblem
9 | from ax.core.parameter import ParameterType, RangeParameter
10 | from ax.core.search_space import SearchSpace
11 |
12 |
def embed_higher_dimension(
    problem: BenchmarkProblem, total_dimensionality: int
) -> BenchmarkProblem:
    """Return a copy of ``problem`` padded with irrelevant dummy dimensions.

    The original parameters and parameter constraints are kept. Extra float
    parameters ``embedding_dummy_{i}`` on ``[0, 1]`` are appended until the
    search space has ``total_dimensionality`` parameters. The new problem's
    name gets a ``_{total_dimensionality}d`` suffix.

    Args:
        problem: The benchmark problem to embed. It must be a dataclass
            instance, since it is copied via ``dataclasses.asdict``.
        total_dimensionality: Number of parameters in the resulting search
            space.

    Returns:
        A new problem of the same class as ``problem``.

    Raises:
        ValueError: If ``total_dimensionality`` is smaller than the number of
            parameters ``problem`` already has.
    """
    num_original_dimensions = len(problem.search_space.parameters)
    num_dummy_dimensions = total_dimensionality - num_original_dimensions
    if num_dummy_dimensions < 0:
        # A negative count would make `range` empty and silently produce a
        # problem whose "_{N}d" name misstates its real dimensionality.
        raise ValueError(
            f"Cannot embed problem {problem.name!r} with "
            f"{num_original_dimensions} parameters into "
            f"{total_dimensionality} dimensions."
        )

    search_space = SearchSpace(
        parameters=[
            *problem.search_space.parameters.values(),
            *[
                RangeParameter(
                    name=f"embedding_dummy_{i}",
                    parameter_type=ParameterType.FLOAT,
                    lower=0,
                    upper=1,
                )
                for i in range(num_dummy_dimensions)
            ],
        ],
        parameter_constraints=problem.search_space.parameter_constraints,
    )

    # `asdict` recursively copies every field; the copied search space is
    # then replaced with the embedded one.
    problem_kwargs = asdict(problem)
    problem_kwargs["name"] = f"{problem_kwargs['name']}_{total_dimensionality}d"
    problem_kwargs["search_space"] = search_space

    return problem.__class__(**problem_kwargs)
39 |
--------------------------------------------------------------------------------
/ax/service/utils/early_stopping.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from typing import Dict, Optional, Set
7 |
8 | from ax.core.experiment import Experiment
9 | from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
10 | from ax.utils.common.typeutils import not_none
11 |
12 |
def should_stop_trials_early(
    early_stopping_strategy: Optional[BaseEarlyStoppingStrategy],
    trial_indices: Set[int],
    experiment: Experiment,
) -> Dict[int, Optional[str]]:
    """Ask an early-stopping strategy which running trials to stop.

    Args:
        early_stopping_strategy: Strategy that decides, given the experiment's
            state, whether a trial should be stopped. ``None`` disables early
            stopping.
        trial_indices: Indices of the trials to consider.
        experiment: The experiment containing those trials.

    Returns:
        Mapping from the indices of trials to stop to an optional message
        giving the reason. Empty when no strategy is supplied.
    """
    if early_stopping_strategy is not None:
        strategy = not_none(early_stopping_strategy)
        return strategy.should_stop_trials_early(
            trial_indices=trial_indices, experiment=experiment
        )
    return {}
37 |
--------------------------------------------------------------------------------
/ax/utils/notebook/plotting.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.plot.base import AxPlotConfig, AxPlotTypes
8 | from ax.plot.render import _js_requires, _wrap_js, plot_config_to_html
9 | from ax.utils.common.logger import get_logger
10 | from IPython.display import display
11 | from plotly.offline import init_notebook_mode, iplot
12 |
13 |
logger = get_logger(__name__)


def init_notebook_plotting(offline=False):
    """Prepare the current notebook for rendering Ax plots.

    Args:
        offline: Forwarded to ``_js_requires``. Presumably it chooses between
            an inline (offline) Plotly bundle and loading it remotely; confirm
            in ``ax.plot.render``.
    """
    injected_js = _wrap_js(_js_requires(offline=offline))
    display({"text/html": injected_js}, raw=True)
    logger.info("Injecting Plotly library into cell. Do not overwrite or delete cell.")
    init_notebook_mode()
23 |
24 |
def render(plot_config: AxPlotConfig, inject_helpers=False) -> None:
    """Display a plot config in the notebook, dispatching on its plot type.

    - ``GENERIC``: the data is handed directly to Plotly's ``iplot``.
    - ``HTML``: the data must already be a MIME bundle with a ``text/html`` key.
    - anything else: the config is converted to HTML and displayed.
    """
    kind = plot_config.plot_type
    if kind == AxPlotTypes.GENERIC:
        iplot(plot_config.data)
        return
    if kind == AxPlotTypes.HTML:
        # NOTE(review): `assert` is stripped under `python -O`.
        assert "text/html" in plot_config.data
        display(plot_config.data, raw=True)
        return
    html = plot_config_to_html(plot_config, inject_helpers=inject_helpers)
    display({"text/html": html}, raw=True)
37 |
--------------------------------------------------------------------------------
/ax/storage/json_store/save.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import json
8 | from typing import Any, Callable, Dict, Type
9 |
10 | from ax.core.experiment import Experiment
11 | from ax.storage.json_store.encoder import object_to_json
12 | from ax.storage.json_store.registry import (
13 | CORE_CLASS_ENCODER_REGISTRY,
14 | CORE_ENCODER_REGISTRY,
15 | )
16 |
17 |
def save_experiment(
    experiment: Experiment,
    filepath: str,
    encoder_registry: Dict[
        Type, Callable[[Any], Dict[str, Any]]
    ] = CORE_ENCODER_REGISTRY,
    class_encoder_registry: Dict[
        Type, Callable[[Any], Dict[str, Any]]
    ] = CORE_CLASS_ENCODER_REGISTRY,
) -> None:
    """Save experiment to file.

    1) Convert Ax experiment to JSON-serializable dictionary.
    2) Serialize that dictionary to a JSON string.
    3) Write the string to ``filepath``.

    Serialization is finished before the file is opened. A failure while
    encoding therefore leaves any existing file at ``filepath`` untouched,
    rather than truncating it to an empty file.

    Args:
        experiment: The experiment to save.
        filepath: Destination path; must end in ``.json``.
        encoder_registry: Encoders keyed by type, passed to ``object_to_json``.
        class_encoder_registry: Class encoders keyed by type, passed to
            ``object_to_json``.

    Raises:
        ValueError: If ``experiment`` is not an ``Experiment`` or ``filepath``
            does not end in ``.json``.
    """
    if not isinstance(experiment, Experiment):
        raise ValueError("Can only save instances of Experiment")

    if not filepath.endswith(".json"):
        raise ValueError("Filepath must end in .json")

    json_experiment = object_to_json(
        experiment,
        encoder_registry=encoder_registry,
        class_encoder_registry=class_encoder_registry,
    )
    serialized = json.dumps(json_experiment)
    # Opening in "w" mode truncates the file, so only do it once there is
    # something valid to write.
    with open(filepath, "w") as file:
        file.write(serialized)
46 |
--------------------------------------------------------------------------------
/ax/utils/common/base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from __future__ import annotations
8 |
9 | import abc
10 | from typing import Optional
11 |
12 | from ax.utils.common.equality import equality_typechecker, object_attribute_dicts_equal
13 |
14 |
class Base:
    """Base class for core Ax classes.

    Provides attribute-wise equality (``__eq__``) and a ``db_id`` property
    used for SQA storage.

    Note: ``__eq__`` is defined here without ``__hash__``, so Python sets
    ``__hash__`` to ``None`` on this class. Instances are unhashable unless a
    subclass defines its own ``__hash__``.
    """

    # Class-level default; an instance gets its own value once the setter
    # assigns one.
    _db_id: Optional[int] = None

    @property
    def db_id(self) -> Optional[int]:
        """Identifier of this object in SQA storage, or ``None`` if unset."""
        return self._db_id

    @db_id.setter
    def db_id(self, db_id: int) -> None:
        self._db_id = db_id

    @equality_typechecker
    def __eq__(self, other: Base) -> bool:
        """Compare two objects attribute by attribute.

        Delegates to ``object_attribute_dicts_equal`` on both ``__dict__``s.
        NOTE(review): ``equality_typechecker`` presumably handles an ``other``
        of a different type before this body runs; confirm in
        ``ax.utils.common.equality``.
        """
        return object_attribute_dicts_equal(
            one_dict=self.__dict__, other_dict=other.__dict__
        )
35 |
36 |
class SortableBase(Base, metaclass=abc.ABCMeta):
    """``Base`` extension whose instances support ``<`` via ``_unique_id``."""

    @property
    @abc.abstractmethod
    def _unique_id(self) -> str:
        """String that distinguishes this instance from all others attached
        to the same experiment; also serves as the ordering key."""

    def __lt__(self, other: SortableBase) -> bool:
        own_key = self._unique_id
        other_key = other._unique_id
        return own_key < other_key
50 |
--------------------------------------------------------------------------------
/ax/models/tests/test_randomforest.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from ax.core.search_space import SearchSpaceDigest
9 | from ax.models.torch.randomforest import RandomForest
10 | from ax.utils.common.testutils import TestCase
11 | from botorch.utils.datasets import FixedNoiseDataset
12 |
13 |
class RandomForestTest(TestCase):
    def testRFModel(self):
        # One noisy dataset per outcome ("y1", "y2"), 10 points in 2-D each.
        datasets = []
        for _ in range(2):
            datasets.append(
                FixedNoiseDataset(
                    X=torch.rand(10, 2), Y=torch.rand(10, 1), Yvar=torch.rand(10, 1)
                )
            )

        model = RandomForest(num_trees=5)
        digest = SearchSpaceDigest(feature_names=["x1", "x2"], bounds=[(0, 1)] * 2)
        model.fit(
            datasets=datasets,
            metric_names=["y1", "y2"],
            search_space_digest=digest,
        )
        self.assertEqual(len(model.models), 2)
        self.assertEqual(len(model.models[0].estimators_), 5)

        # Means are (n, outcomes); covariances are (n, outcomes, outcomes).
        mean, cov = model.predict(torch.rand(5, 2))
        self.assertEqual(mean.shape, torch.Size((5, 2)))
        self.assertEqual(cov.shape, torch.Size((5, 2, 2)))

        mean, cov = model.cross_validate(datasets=datasets, X_test=torch.rand(3, 2))
        self.assertEqual(mean.shape, torch.Size((3, 2)))
        self.assertEqual(cov.shape, torch.Size((3, 2, 2)))
42 |
--------------------------------------------------------------------------------
/ax/service/tests/test_early_stopping.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from ax.service.utils import early_stopping as early_stopping_utils
7 | from ax.utils.common.testutils import TestCase
8 | from ax.utils.testing.core_stubs import (
9 | DummyEarlyStoppingStrategy,
10 | get_branin_experiment,
11 | )
12 |
13 |
class TestEarlyStoppingUtils(TestCase):
    """Testing the early stopping utilities functionality that is not tested in
    main `AxClient` testing suite (`TestServiceAPI`)."""

    def setUp(self):
        self.branin_experiment = get_branin_experiment()

    def test_should_stop_trials_early(self):
        expected = {
            1: "Stopped due to testing.",
            3: "Stopped due to testing.",
        }
        actual = early_stopping_utils.should_stop_trials_early(
            early_stopping_strategy=DummyEarlyStoppingStrategy(expected),
            # `should_stop_trials_early` is annotated to take `Set[int]`.
            trial_indices={1, 2, 3},
            experiment=self.branin_experiment,
        )
        self.assertEqual(actual, expected)

    def test_should_stop_trials_early_no_strategy(self):
        actual = early_stopping_utils.should_stop_trials_early(
            early_stopping_strategy=None,
            trial_indices={1, 2, 3},
            experiment=self.branin_experiment,
        )
        # Without a strategy, nothing is stopped.
        self.assertEqual(actual, {})
41 |
--------------------------------------------------------------------------------
/ax/core/tests/test_metric.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.core.metric import Metric
8 | from ax.utils.common.testutils import TestCase
9 | from ax.utils.testing.core_stubs import get_branin_metric, get_factorial_metric
10 |
11 |
# Expected `str()` of a Metric named "m1".
METRIC_STRING = "Metric('m1')"


class MetricTest(TestCase):
    def setUp(self):
        # NOTE(review): overrides TestCase.setUp without calling super();
        # kept as-is so the test fixture behaves exactly as before.
        pass

    def testInit(self):
        rendered = str(Metric(name="m1", lower_is_better=False))
        self.assertEqual(rendered, METRIC_STRING)

    def testEq(self):
        baseline = Metric(name="m1", lower_is_better=False)
        # Same name and same direction -> equal.
        self.assertEqual(baseline, Metric(name="m1", lower_is_better=False))
        # Only the optimization direction differs -> not equal.
        self.assertNotEqual(baseline, Metric(name="m1", lower_is_better=True))

    def testClone(self):
        plain = Metric(name="m1", lower_is_better=False)
        self.assertEqual(plain, plain.clone())

        branin = get_branin_metric(name="branin")
        self.assertEqual(branin, branin.clone())

        factorial = get_factorial_metric(name="factorial")
        self.assertEqual(factorial, factorial.clone())

    def testSortable(self):
        first = Metric(name="m1", lower_is_better=False)
        second = Metric(name="m2", lower_is_better=False)
        self.assertLess(first, second)
45 |
--------------------------------------------------------------------------------
/ax/plot/tests/test_slices.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import plotly.graph_objects as go
8 | from ax.modelbridge.registry import Models
9 | from ax.plot.base import AxPlotConfig
10 | from ax.plot.slice import (
11 | interact_slice,
12 | interact_slice_plotly,
13 | plot_slice,
14 | plot_slice_plotly,
15 | )
16 | from ax.utils.common.testutils import TestCase
17 | from ax.utils.testing.core_stubs import get_branin_experiment
18 | from ax.utils.testing.mock import fast_botorch_optimize
19 |
20 |
class SlicesTest(TestCase):
    @fast_botorch_optimize
    def testSlices(self):
        experiment = get_branin_experiment(with_batch=True)
        experiment.trials[0].run()
        model = Models.BOTORCH(
            # Model bridge kwargs
            experiment=experiment,
            data=experiment.fetch_data(),
        )
        first_param = model.parameters[0]
        first_metric = list(model.metric_names)[0]

        # Each slice-plot flavor must build successfully.
        self.assertIsInstance(
            plot_slice_plotly(model, first_param, first_metric), go.Figure
        )
        self.assertIsInstance(interact_slice_plotly(model), go.Figure)
        self.assertIsInstance(
            plot_slice(model, first_param, first_metric), AxPlotConfig
        )
        self.assertIsInstance(interact_slice(model), AxPlotConfig)
42 |
--------------------------------------------------------------------------------
/ax/utils/testing/torch_stubs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from __future__ import annotations
8 |
9 | from typing import Dict
10 |
11 | import torch
12 |
13 |
def get_optimizer_kwargs() -> Dict[str, int]:
    """Return tiny acquisition-optimizer settings for use in test stubs."""
    return dict(num_restarts=2, raw_samples=2, maxiter=2, batch_limit=1)
16 |
17 |
def get_torch_test_data(
    dtype=torch.float,
    cuda: bool = False,
    constant_noise: bool = True,
    task_features=None,
    offset: float = 0.0,
):
    """Build a tiny fixed training set for torch model tests.

    Returns:
        ``(Xs, Ys, Yvars, bounds, task_features, feature_names, metric_names)``.
        ``Xs``, ``Ys`` and ``Yvars`` are one-element lists holding ``2 x 3``,
        ``2 x 1`` and ``2 x 1`` tensors of ``dtype`` on CPU (or CUDA when
        ``cuda``). Every value, including the bounds, is shifted by ``offset``,
        except that ``Yvars`` is filled with ``1.0`` when ``constant_noise``.
    """
    device = torch.device("cuda" if cuda else "cpu")
    tkwargs = {"device": device, "dtype": dtype}

    raw_X = ((1.0, 2.0, 3.0), (2.0, 3.0, 4.0))
    Xs = [torch.tensor([[x + offset for x in row] for row in raw_X], **tkwargs)]
    Ys = [torch.tensor([[y + offset] for y in (3.0, 4.0)], **tkwargs)]
    Yvars = [torch.tensor([[v + offset] for v in (0.0, 2.0)], **tkwargs)]
    if constant_noise:
        Yvars[0].fill_(1.0)

    raw_bounds = ((0.0, 1.0), (1.0, 4.0), (2.0, 5.0))
    bounds = [(lo + offset, hi + offset) for lo, hi in raw_bounds]

    if task_features is None:
        task_features = []
    feature_names = ["x1", "x2", "x3"]
    metric_names = ["y", "r"]
    return Xs, Ys, Yvars, bounds, task_features, feature_names, metric_names
48 |
--------------------------------------------------------------------------------
/ax/utils/common/timeutils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from datetime import datetime, timedelta
8 | from time import time
9 | from typing import Generator
10 |
11 | import pandas as pd
12 |
13 |
# strftime/strptime pattern for DS strings, e.g. "2020-01-31".
DS_FRMT = "%Y-%m-%d"


def to_ds(ts: datetime) -> str:
    """Render ``ts`` as a ``YYYY-MM-DD`` DS string (time of day is dropped)."""
    formatted = datetime.strftime(ts, DS_FRMT)
    return formatted


def to_ts(ds: str) -> datetime:
    """Parse a ``YYYY-MM-DD`` DS string into a ``datetime`` at midnight."""
    parsed = datetime.strptime(ds, DS_FRMT)
    return parsed
25 |
26 |
27 | def _ts_to_pandas(ts: int) -> pd.Timestamp:
28 | """Convert int timestamp into pandas timestamp."""
29 | return pd.Timestamp(datetime.fromtimestamp(ts))
30 |
31 |
32 | def _pandas_ts_to_int(ts: pd.Timestamp) -> int:
33 | """Convert int timestamp into pandas timestamp."""
34 | # pyre-fixme[7]: Expected `int` but got `float`.
35 | return ts.to_pydatetime().timestamp()
36 |
37 |
def current_timestamp_in_millis() -> int:
    """Return the current Unix time in whole milliseconds (rounded)."""
    millis = time() * 1000
    return round(millis)
41 |
42 |
def timestamps_in_range(
    start: datetime, end: datetime, delta: timedelta
) -> Generator[datetime, None, None]:
    """Lazily yield ``start``, ``start + delta``, ... while not past ``end``.

    Both endpoints are inclusive. ``delta`` must be positive, otherwise the
    generator never terminates when ``start <= end``.
    """
    current = start
    while True:
        if current > end:
            return
        yield current
        current = current + delta
53 |
--------------------------------------------------------------------------------
/ax/core/tests/test_types.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.core.types import merge_model_predict
8 | from ax.utils.common.testutils import TestCase
9 |
10 |
class TypesTest(TestCase):
    def setUp(self):
        self.num_arms = 2
        metric_names = ("m1", "m2")
        means = {"m1": [0.0, 0.5], "m2": [0.1, 0.6]}
        # Fresh zero lists for every (metric, metric) covariance entry.
        covariances = {
            row: {col: [0.0, 0.0] for col in metric_names} for row in metric_names
        }
        self.predict = (means, covariances)

    def testMergeModelPredict(self):
        extra_means = {"m1": [0.6], "m2": [0.7]}
        extra_covs = {row: {col: [0.0] for col in ("m1", "m2")} for row in ("m1", "m2")}
        merged = merge_model_predict(self.predict, (extra_means, extra_covs))
        # Two existing arms plus one appended arm.
        self.assertEqual(len(merged[0]["m1"]), 3)

    def testMergeModelPredictFail(self):
        mismatched_inputs = [
            # Means are missing "m2" while covariances have it.
            (
                {"m1": [0.6]},
                {"m1": {"m1": [0.0], "m2": [0.0]}, "m2": {"m1": [0.0], "m2": [0.0]}},
            ),
            # Covariances are missing the "m2" row.
            (
                {"m1": [0.6], "m2": [0.7]},
                {"m1": {"m1": [0.0], "m2": [0.0]}},
            ),
        ]
        for extra_means, extra_covs in mismatched_inputs:
            with self.assertRaises(ValueError):
                merge_model_predict(self.predict, (extra_means, extra_covs))
43 |
--------------------------------------------------------------------------------
/ax/plot/tests/test_traces.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import plotly.graph_objects as go
9 | from ax.modelbridge.registry import Models
10 | from ax.plot.base import AxPlotConfig
11 | from ax.plot.trace import (
12 | optimization_trace_single_method,
13 | optimization_trace_single_method_plotly,
14 | )
15 | from ax.utils.common.testutils import TestCase
16 | from ax.utils.testing.core_stubs import get_branin_experiment
17 | from ax.utils.testing.mock import fast_botorch_optimize
18 |
19 |
class TracesTest(TestCase):
    @fast_botorch_optimize
    def testTraces(self):
        experiment = get_branin_experiment(with_batch=True)
        experiment.trials[0].run()
        model = Models.BOTORCH(
            # Model bridge kwargs
            experiment=experiment,
            data=experiment.fetch_data(),
        )
        metric_name = list(model.metric_names)[0]
        trace_rows = [[1, 2, 3], [4, 5, 6]]

        # Each builder gets its own freshly built array.
        figure = optimization_trace_single_method_plotly(
            np.array(trace_rows), metric_name, optimization_direction="minimize"
        )
        self.assertIsInstance(figure, go.Figure)
        config = optimization_trace_single_method(
            np.array(trace_rows), metric_name, optimization_direction="minimize"
        )
        self.assertIsInstance(config, AxPlotConfig)
43 |
--------------------------------------------------------------------------------
/scripts/patch_site_config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
import argparse
import re
from typing import Optional
9 |
10 |
def patch_config(
    config_file: str, base_url: Optional[str] = None, disable_algolia: bool = True
) -> None:
    """Patch a Docusaurus ``siteConfig.js`` file in place.

    Args:
        config_file: Path to the ``siteConfig.js`` file to rewrite.
        base_url: If not None, every literal ``baseUrl = '/';`` is replaced
            with ``baseUrl = '<base_url>';``.
        disable_algolia: If True, every literal
            ``const includeAlgolia = true;`` is replaced with
            ``const includeAlgolia = false;``.
    """
    # Use a context manager so the read handle is closed before the same path
    # is reopened for writing.
    with open(config_file, "r") as infile:
        config = infile.read()

    # Both patterns are literal text. Plain substring replacement avoids
    # ``re.sub`` interpreting backslashes / group references in ``base_url``.
    if base_url is not None:
        config = config.replace(
            "baseUrl = '/';", "baseUrl = '{}';".format(base_url)
        )
    if disable_algolia is True:
        config = config.replace(
            "const includeAlgolia = true;", "const includeAlgolia = false;"
        )

    with open(config_file, "w") as outfile:
        outfile.write(config)
25 |
26 |
if __name__ == "__main__":
    # Command-line entry point: parse arguments and patch the given
    # siteConfig.js in place. Note that --disable_algolia defaults to False
    # here (store_true), unlike patch_config's own default of True.
    parser = argparse.ArgumentParser(
        description="Patch Docusaurus siteConfig.js file when building site."
    )
    parser.add_argument(
        "-f",
        "--config_file",
        metavar="path",
        required=True,
        help="Path to configuration file.",
    )
    parser.add_argument(
        "-b",
        "--base_url",
        type=str,
        required=False,
        help="Value for baseUrl.",
        default=None,
    )
    parser.add_argument(
        "--disable_algolia",
        required=False,
        action="store_true",
        help="Disable algolia.",
    )
    args = parser.parse_args()
    patch_config(args.config_file, args.base_url, args.disable_algolia)
54 |
--------------------------------------------------------------------------------
/ax/utils/common/tests/test_equality.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from datetime import datetime
8 |
9 | import pandas as pd
10 | from ax.utils.common.equality import (
11 | dataframe_equals,
12 | datetime_equals,
13 | equality_typechecker,
14 | same_elements,
15 | )
16 | from ax.utils.common.testutils import TestCase
17 |
18 |
class EqualityTest(TestCase):
    def testEqualityTypechecker(self):
        @equality_typechecker
        def eq(x, y):
            return x == y

        # 5 == 5.0 is True in plain Python; the decorated version must say no
        # because the operand types differ.
        self.assertFalse(eq(5, 5.0))
        self.assertTrue(eq(5, 5))

    def testListsEquals(self):
        cases = (
            ([0], [0, 1], False),
            ([1, 0], [0, 2], False),
            ([1, 0], [0, 1], True),
        )
        for left, right, expected in cases:
            if expected:
                self.assertTrue(same_elements(left, right))
            else:
                self.assertFalse(same_elements(left, right))

    def testDatetimeEquals(self):
        timestamp = datetime.now()
        self.assertTrue(datetime_equals(None, None))
        self.assertFalse(datetime_equals(None, timestamp))
        self.assertTrue(datetime_equals(timestamp, timestamp))

    def testDataframeEquals(self):
        base = pd.DataFrame.from_records([{"x": 100, "y": 200}])
        # Same values, columns listed in a different order.
        reordered = pd.DataFrame.from_records([{"y": 200, "x": 100}])
        different = pd.DataFrame.from_records([{"x": 100, "y": 300}])

        self.assertTrue(dataframe_equals(pd.DataFrame(), pd.DataFrame()))
        self.assertTrue(dataframe_equals(base, reordered))
        self.assertFalse(dataframe_equals(base, different))
47 |
--------------------------------------------------------------------------------
/website/static/js/plotUtils.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | // helper functions used across multiple plots
// Format an [r, g, b] array as a CSS color string, e.g. [1, 2, 3] -> 'rgb(1,2,3)'.
function rgb(rgb_array) {
  const channels = rgb_array.join();
  return `rgb(${channels})`;
}
12 |
// Return a reversed shallow copy of `arr`, leaving the input untouched.
function copy_and_reverse(arr) {
  return arr.slice().reverse();
}
18 |
// [min, max] of `grid`; both ends are log10-transformed when `is_log` is truthy.
function axis_range(grid, is_log) {
  const lo = Math.min(...grid);
  const hi = Math.max(...grid);
  if (is_log) {
    return [Math.log10(lo), Math.log10(hi)];
  }
  return [lo, hi];
}
24 |
// When `rel` is exactly `true`, convert each mean in `f` (with standard error
// in `sd`) into a percent change relative to the status-quo arm's value for
// `metric`, read from arm_data.in_sample[arm_data.status_quo_name].
// Otherwise return the inputs unchanged. Returns [means, standard_errors].
function relativize_data(f, sd, rel, arm_data, metric) {
  const f_final = rel === true ? [] : f;
  const sd_final = rel === true ? [] : sd;

  if (rel === true) {
    const f_sq = (
      arm_data['in_sample'][arm_data['status_quo_name']]['y'][metric]
    );
    const sd_sq = (
      arm_data['in_sample'][arm_data['status_quo_name']]['se'][metric]
    );

    for (let i = 0; i < f.length; i++) {
      // Declared with const so it stays local: an undeclared assignment leaks
      // into the global scope (and throws in strict mode).
      const res = relativize(f[i], sd[i], f_sq, sd_sq);
      f_final.push(100 * res[0]);
      sd_final.push(100 * res[1]);
    }
  }

  return [f_final, sd_final];
}
47 |
// Relative difference of a treatment mean m_t (sem sem_t) against a control
// mean m_c (sem sem_c). Returns [r_hat, se] where
//   r_hat = (m_t - m_c) / |m_c| - sem_c^2 * m_t / |m_c|^3
//   se    = sqrt((sem_t^2 + (m_t / m_c * sem_c)^2) / m_c^2)
function relativize(m_t, sem_t, m_c, sem_c) {
  // Locals are declared so they do not leak into the global scope.
  const r_hat = (
    (m_t - m_c) / Math.abs(m_c) -
    Math.pow(sem_c, 2) * m_t / Math.pow(Math.abs(m_c), 3)
  );
  const variance = (
    (Math.pow(sem_t, 2) + Math.pow((m_t / m_c * sem_c), 2)) /
    Math.pow(m_c, 2)
  );
  return [r_hat, Math.sqrt(variance)];
}
58 | }
59 |
--------------------------------------------------------------------------------
/ax/exceptions/generation_strategy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import typing # noqa F401, this is to enable type-checking
8 |
9 | from ax.exceptions.core import AxError, OptimizationComplete
10 |
11 |
class MaxParallelismReachedException(AxError):
    """Raised when a generation step's cap on concurrently running trials is hit.

    The cap comes from ``GenerationStep.max_parallelism``. On receiving this
    exception, callers should wait until more running trials complete with
    data before asking for new trials.
    """

    def __init__(self, step_index: int, model_name: str, num_running: int) -> None:
        message = (
            f"Maximum parallelism for generation step #{step_index} ({model_name})"
            f" has been reached: {num_running} trials are currently 'running'. Some "
            "trials need to be completed before more trials can be generated. See "
            "https://ax.dev/docs/bayesopt.html to understand why limited parallelism "
            "improves performance of Bayesian optimization."
        )
        super().__init__(message)
27 |
28 |
class GenerationStrategyCompleted(OptimizationComplete):
    """Raised to signal that the generation strategy has run to completion."""
35 |
36 |
class GenerationStrategyRepeatedPoints(GenerationStrategyCompleted):
    """Raised when the generation strategy keeps re-suggesting points that have
    already been sampled.

    Subclasses ``GenerationStrategyCompleted``, so handlers for that exception
    also catch this one.
    """
43 |
--------------------------------------------------------------------------------
/ax/exceptions/data_provider.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Any, Iterable
8 |
9 |
class DataProviderError(Exception):
    """Base Exception for Ax DataProviders.

    The type of the data provider must be included.
    The raw error is stored in the data_provider_error section,
    and an Ax-friendly message is stored as the actual error message.

    Args:
        message: Ax-friendly description of the failure.
        data_provider: Name/type of the data provider that failed.
        data_provider_error: The raw error produced by the data provider.
    """

    def __init__(
        self, message: str, data_provider: str, data_provider_error: Any
    ) -> None:
        # Forward the constructor arguments to ``Exception`` so ``self.args``
        # is populated. Otherwise ``repr()`` shows no arguments and unpickling
        # (which re-calls ``cls(*self.args)``) fails with a TypeError.
        super().__init__(message, data_provider, data_provider_error)
        self.message = message
        self.data_provider = data_provider
        self.data_provider_error = data_provider_error

    def __str__(self) -> str:
        return (
            "{message}. \n Error thrown by: {dp} data provider \n"
            + "Native {dp} data provider error: {dp_error}"
        ).format(
            dp=self.data_provider,
            message=self.message,
            dp_error=self.data_provider_error,
        )
34 |
35 |
class MissingDataError(Exception):
    """Raised when data could not be found for some trials.

    Args:
        missing_trial_indexes: Indexes of the trials whose data is missing.
    """

    def __init__(self, missing_trial_indexes: Iterable[int]) -> None:
        # Materialize once so one-shot iterators (e.g. generators) can be both
        # formatted into the message and stored in ``args``.
        indexes = list(missing_trial_indexes)
        missing_trial_str = ", ".join(str(index) for index in indexes)
        self.message: str = (
            f"Unable to find data for the following trials: {missing_trial_str} "
            "consider updating the data fetching kwargs or manually fetching "
            "data via `refetch_data()`"
        )
        # Populate ``self.args`` with the constructor argument so the exception
        # survives pickling (unpickling re-calls ``cls(*self.args)``).
        super().__init__(indexes)

    def __str__(self) -> str:
        return self.message
47 |
--------------------------------------------------------------------------------
/ax/plot/tests/test_fitted_scatter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import plotly.graph_objects as go
8 | from ax.modelbridge.registry import Models
9 | from ax.plot.base import AxPlotConfig
10 | from ax.plot.scatter import interact_fitted, interact_fitted_plotly
11 | from ax.utils.common.testutils import TestCase
12 | from ax.utils.testing.core_stubs import get_branin_experiment
13 | from ax.utils.testing.mock import fast_botorch_optimize
14 |
15 |
class FittedScatterTest(TestCase):
    @fast_botorch_optimize
    def test_fitted_scatter(self):
        exp = get_branin_experiment(with_str_choice_param=True, with_batch=True)
        exp.trials[0].run()
        model = Models.BOTORCH(
            # Model bridge kwargs
            experiment=exp,
            data=exp.fetch_data(),
        )
        # Assert that each type of plot can be constructed successfully
        plot = interact_fitted_plotly(model=model, rel=False)
        self.assertIsInstance(plot, go.Figure)
        plot = interact_fitted(model=model, rel=False)
        self.assertIsInstance(plot, AxPlotConfig)

        # Make sure all parameters and metrics are displayed in tooltips
        tooltips = list(exp.parameters.keys()) + list(exp.metrics.keys())
        for d in plot.data["data"]:
            # Only check scatter plots hoverovers
            if d["type"] != "scatter":
                continue
            for text in d["text"]:
                for tt in tooltips:
                    # assertIn reports the missing name and the tooltip text
                    # on failure, unlike assertTrue(tt in text).
                    self.assertIn(tt, text)
41 |
--------------------------------------------------------------------------------
/ax/utils/testing/metrics/backend_simulator_map.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Any, List
8 |
9 | from ax.core.base_trial import BaseTrial
10 | from ax.core.map_data import MapData
11 | from ax.metrics.noisy_function_map import NoisyFunctionMapMetric
12 |
13 |
class BackendSimulatorTimestampMapMetric(NoisyFunctionMapMetric):
    """Map metric that reads trial timing from the experiment runner's
    ``BackendSimulator`` and reports data keyed by timestamp.

    Subclasses must implement ``convert_to_timestamps``.
    """

    def fetch_trial_data(
        self, trial: BaseTrial, noisy: bool = True, **kwargs: Any
    ) -> MapData:
        """Fetch timestamped map data for a single trial.

        The observation window starts at the simulated trial's start time and
        ends at its ``sim_completed_time``, or at the simulator's current time
        when ``sim_completed_time`` is None.
        """
        simulator = trial.experiment.runner.simulator  # pyre-ignore[16]
        sim_trial = simulator.get_sim_trial_by_index(trial.index)
        if sim_trial.sim_completed_time is None:
            end_time = simulator.time
        else:
            end_time = sim_trial.sim_completed_time
        timestamps = self.convert_to_timestamps(
            start_time=sim_trial.sim_start_time, end_time=end_time
        )
        map_kwargs = {"map_keys": ["timestamp"], "timestamp": timestamps}
        return NoisyFunctionMapMetric.fetch_trial_data(
            self, trial=trial, noisy=noisy, **kwargs, **map_kwargs
        )

    def convert_to_timestamps(self, start_time: float, end_time: float) -> List[float]:
        """Return the timestamps between ``start_time`` and ``end_time`` at
        which observations exist. Must be overridden by subclasses."""
        raise NotImplementedError
41 |
--------------------------------------------------------------------------------
/website/tutorials.json:
--------------------------------------------------------------------------------
1 | {
2 | "API Comparison": [
3 | {
4 | "id": "gpei_hartmann_loop",
5 | "title": "Loop API"
6 | },
7 | {
8 | "id": "gpei_hartmann_service",
9 | "title": "Service API"
10 | },
11 | {
12 | "id": "gpei_hartmann_developer",
13 | "title": "Developer API"
14 | }
15 | ],
16 | "Deep Dives": [
17 | {
18 | "id": "visualizations",
19 | "title": "Visualizations"
20 | },
21 | {
22 | "id": "generation_strategy",
23 | "title": "Generation Strategy"
24 | },
25 | {
26 | "id": "scheduler",
27 | "title": "Scheduler"
28 | },
29 | {
30 | "id": "modular_botax",
31 | "title": "Modular `BoTorchModel`"
32 | }
33 | ],
34 | "Bayesian Optimization": [
35 | {
36 | "id": "tune_cnn",
37 | "title": "Hyperparameter Optimization for PyTorch"
38 | },
39 | {
40 | "id": "raytune_pytorch_cnn",
41 | "title": "Hyperparameter Optimization via Raytune"
42 | },
43 | {
44 | "id": "multi_task",
45 | "title": "Multi-Task Modeling"
46 | },
47 | {
48 | "id": "multiobjective_optimization",
49 | "title": "Multi-Objective Optimization"
50 | },
51 | {
52 | "id": "saasbo",
53 | "title": "High-Dimensional Bayesian Optimization with Sparse Axis-Aligned Subspaces (SAASBO)"
54 | },
55 | {
56 | "id": "saasbo_nehvi",
57 | "title": "Fully Bayesian, High-Dimensional, Multi-Objective Optimization"
58 | }
59 | ],
60 | "Field Experiments": [
61 | {
62 | "id": "factorial",
63 | "title": "Bandit Optimization"
64 | },
65 | {
66 | "dir": "human_in_the_loop",
67 | "id": "human_in_the_loop",
68 | "title": "Human-in-the-Loop Optimization"
69 | }
70 | ]
71 | }
72 |
--------------------------------------------------------------------------------
/ax/storage/sqa_store/reduced_state.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List
8 |
9 | from ax.storage.sqa_store.sqa_classes import SQAGeneratorRun
10 | from sqlalchemy.orm import defaultload, lazyload, strategy_options
11 | from sqlalchemy.orm.attributes import InstrumentedAttribute
12 |
13 |
# Generator-run columns holding model configuration and state. These can be
# large, so reduced-state loading defers them.
GR_LARGE_MODEL_ATTRS: List[InstrumentedAttribute] = [  # pyre-ignore[9]
    SQAGeneratorRun.model_kwargs,
    SQAGeneratorRun.bridge_kwargs,
    SQAGeneratorRun.model_state_after_gen,
    SQAGeneratorRun.gen_metadata,
]


# Generator-run attributes duplicated on every trial (search-space pieces and
# metrics); loaded lazily for experiments whose configuration is immutable.
GR_PARAMS_METRICS_COLS = ["parameters", "parameter_constraints", "metrics"]
27 |
28 |
def get_query_options_to_defer_immutable_duplicates() -> List[strategy_options.Load]:
    """Build query options that lazily load attributes duplicated per trial.

    Search-space attributes and metrics are copied onto every generator run.
    For experiments with an immutable search space and optimization config they
    do not need eager loading. Returns one ``lazyload`` option per entry of
    ``GR_PARAMS_METRICS_COLS``, scoped under ``generator_runs``.
    """
    return [lazyload("generator_runs." + col) for col in GR_PARAMS_METRICS_COLS]
37 |
38 |
def get_query_options_to_defer_large_model_cols() -> List[strategy_options.Load]:
    """Build query options that defer the model-state columns of generator runs.

    These columns can be large and are not needed on every generator run when
    an experiment and generation strategy are loaded in reduced state. Returns
    one ``defer`` option (under ``generator_runs``) per ``GR_LARGE_MODEL_ATTRS``
    entry, in the same order.
    """
    options = []
    for attr in GR_LARGE_MODEL_ATTRS:
        options.append(defaultload("generator_runs").defer(attr.key))
    return options
47 |
--------------------------------------------------------------------------------
/ax/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.core import (
8 | Arm,
9 | BatchTrial,
10 | ChoiceParameter,
11 | ComparisonOp,
12 | Data,
13 | Experiment,
14 | FixedParameter,
15 | GeneratorRun,
16 | Metric,
17 | MultiObjective,
18 | MultiObjectiveOptimizationConfig,
19 | Objective,
20 | ObjectiveThreshold,
21 | OptimizationConfig,
22 | OrderConstraint,
23 | OutcomeConstraint,
24 | Parameter,
25 | ParameterConstraint,
26 | ParameterType,
27 | RangeParameter,
28 | Runner,
29 | SearchSpace,
30 | SumConstraint,
31 | Trial,
32 | )
33 | from ax.modelbridge import Models
34 | from ax.service import OptimizationLoop, optimize
35 | from ax.storage import json_load, json_save
36 |
37 |
# Resolve the package version from ``ax.version``; if importing it raises any
# Exception, fall back to the placeholder string "Unknown".
try:
    # pyre-fixme[21]: Could not find a module... to import `ax.version`.
    from ax.version import version as __version__
except Exception:
    __version__ = "Unknown"

# Public names exported by ``from ax import *``; each is imported above in
# this module.
__all__ = [
    "Arm",
    "BatchTrial",
    "ChoiceParameter",
    "ComparisonOp",
    "Data",
    "Experiment",
    "FixedParameter",
    "GeneratorRun",
    "Metric",
    "Models",
    "MultiObjective",
    "MultiObjectiveOptimizationConfig",
    "Objective",
    "ObjectiveThreshold",
    "OptimizationConfig",
    "OptimizationLoop",
    "OrderConstraint",
    "OutcomeConstraint",
    "Parameter",
    "ParameterConstraint",
    "ParameterType",
    "RangeParameter",
    "Runner",
    "SearchSpace",
    "SumConstraint",
    "Trial",
    "optimize",
    "json_save",
    "json_load",
]
75 |
--------------------------------------------------------------------------------
/ax/benchmark/tests/test_benchmark_method.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from ax.benchmark.benchmark_method import BenchmarkMethod
7 | from ax.exceptions.core import UserInputError
8 | from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
9 | from ax.modelbridge.registry import Models
10 | from ax.service.scheduler import SchedulerOptions
11 | from ax.utils.common.testutils import TestCase
12 |
13 |
class TestBenchmarkMethod(TestCase):
    @staticmethod
    def _sobol_generation_strategy() -> GenerationStrategy:
        """Single-step, 10-trial Sobol strategy shared by the tests below."""
        return GenerationStrategy(
            steps=[GenerationStep(model=Models.SOBOL, num_trials=10)],
            name="SOBOL",
        )

    def test_benchmark_method(self):
        strategy = self._sobol_generation_strategy()
        scheduler_options = SchedulerOptions(total_trials=10)
        method = BenchmarkMethod(
            name="Sobol10",
            generation_strategy=strategy,
            scheduler_options=scheduler_options,
        )

        self.assertEqual(method.generation_strategy, strategy)
        self.assertEqual(method.scheduler_options, scheduler_options)

    def test_total_trials_none(self):
        strategy = self._sobol_generation_strategy()
        # Built without total_trials, which BenchmarkMethod must reject.
        scheduler_options = SchedulerOptions()

        with self.assertRaisesRegex(
            UserInputError, "SchedulerOptions.total_trials may not be None"
        ):
            BenchmarkMethod(
                name="Sobol10",
                generation_strategy=strategy,
                scheduler_options=scheduler_options,
            )
51 |
--------------------------------------------------------------------------------
/sphinx/source/service.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax.service
5 | ===================================
6 |
7 | .. automodule:: ax.service
8 | .. currentmodule:: ax.service
9 |
10 |
11 | Ax Client
12 | ~~~~~~~~~~~~~~~~
13 |
14 | .. automodule:: ax.service.ax_client
15 | :members:
16 | :undoc-members:
17 | :show-inheritance:
18 |
19 | Managed Loop
20 | ~~~~~~~~~~~~~~~~
21 |
22 | .. automodule:: ax.service.managed_loop
23 | :members:
24 | :undoc-members:
25 | :show-inheritance:
26 |
27 |
28 | Scheduler
29 | ~~~~~~~~~~
30 |
31 | .. automodule:: ax.service.scheduler
32 | :members:
33 | :undoc-members:
34 | :show-inheritance:
35 |
36 | .. automodule:: ax.service.utils.scheduler_options
37 | :members:
38 | :undoc-members:
39 | :show-inheritance:
40 |
41 | Utils
42 | ------
43 |
44 | Best Point Identification
45 | ~~~~~~~~~~~~~~~~~~~~~~~~~
46 |
47 | .. automodule:: ax.service.utils.best_point_mixin
48 | :members:
49 | :undoc-members:
50 | :show-inheritance:
51 |
52 |
53 | .. automodule:: ax.service.utils.best_point
54 | :members:
55 | :undoc-members:
56 | :show-inheritance:
57 |
58 |
59 | Instantiation
60 | ~~~~~~~~~~~~~~
61 |
62 | .. automodule:: ax.service.utils.instantiation
63 | :members:
64 | :undoc-members:
65 | :show-inheritance:
66 |
67 |
68 | Reporting
69 | ~~~~~~~~~~~~~~
70 |
71 | .. automodule:: ax.service.utils.report_utils
72 | :members:
73 | :undoc-members:
74 | :show-inheritance:
75 |
76 |
77 | WithDBSettingsBase
78 | ~~~~~~~~~~~~~~~~~~
79 |
80 | .. automodule:: ax.service.utils.with_db_settings_base
81 | :members:
82 | :undoc-members:
83 | :show-inheritance:
84 |
85 |
86 | EarlyStopping
87 | ~~~~~~~~~~~~~
88 |
89 | .. automodule:: ax.service.utils.early_stopping
90 | :members:
91 | :undoc-members:
92 | :show-inheritance:
93 |
--------------------------------------------------------------------------------
/ax/benchmark/methods/saasbo.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from ax.benchmark.benchmark_method import BenchmarkMethod
7 | from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
8 | from ax.modelbridge.registry import Models
9 | from ax.service.scheduler import SchedulerOptions
10 |
11 |
def _get_saasbo_method(model: Models, name: str) -> BenchmarkMethod:
    """Build the shared Sobol -> ``model`` benchmark method used by SAASBO variants.

    The strategy runs 5 Sobol trials (requiring 3 observed before moving on),
    then a final step using ``model`` with ``num_trials=-1`` and
    ``max_parallelism=1``. The scheduler is capped at 30 total trials.

    Args:
        model: Model registry entry for the second generation step.
        name: Name for both the generation strategy and the benchmark method.
    """
    generation_strategy = GenerationStrategy(
        name=name,
        steps=[
            GenerationStep(model=Models.SOBOL, num_trials=5, min_trials_observed=3),
            GenerationStep(
                model=model,
                num_trials=-1,
                max_parallelism=1,
            ),
        ],
    )

    scheduler_options = SchedulerOptions(total_trials=30)

    return BenchmarkMethod(
        name=generation_strategy.name,
        generation_strategy=generation_strategy,
        scheduler_options=scheduler_options,
    )


def get_saasbo_default() -> BenchmarkMethod:
    """SAASBO benchmark method: Sobol followed by ``Models.FULLYBAYESIAN``."""
    return _get_saasbo_method(
        model=Models.FULLYBAYESIAN, name="SOBOL+FULLYBAYESIAN::default"
    )


def get_saasbo_moo_default() -> BenchmarkMethod:
    """SAASBO MOO benchmark method: Sobol followed by ``Models.FULLYBAYESIANMOO``."""
    return _get_saasbo_method(
        model=Models.FULLYBAYESIANMOO, name="SOBOL+FULLYBAYESIANMOO::default"
    )
54 |
--------------------------------------------------------------------------------
/ax/utils/common/docutils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | """Support functions for sphinx et. al
10 | """
11 |
12 |
13 | from typing import Any, Callable, TypeVar
14 |
15 |
_T = TypeVar("_T")


# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
# pyre-ignore[34]: T77127616
def copy_doc(src: Callable[..., Any]) -> Callable[[_T], _T]:
    """A decorator that copies the docstring of another object

    Since ``sphinx`` actually loads the python modules to grab the docstrings
    this works with both ``sphinx`` and the ``help`` function.

    .. code:: python

        class Cat(Mammal):

            @property
            @copy_doc(Mammal.is_feline)
            def is_feline(self) -> bool:
                ...

    Args:
        src: Object whose ``__doc__`` is copied.

    Returns:
        A decorator that sets ``dst.__doc__ = src.__doc__`` and returns ``dst``.

    Raises:
        ValueError: If ``src`` has no docstring, or (when the decorator is
            applied) if the decorated object already has one.
    """
    # It would be tempting to try to get the doc through the class the method
    # is bound to (via __self__) but decorators are called before __self__ is
    # assigned.
    # One other solution would be to use a decorator on classes that would fill
    # all the missing docstrings but we want to be able to detect syntactically
    # when docstrings are copied to keep things nice and simple

    if src.__doc__ is None:
        # Not every documented object has __qualname__ (e.g. a ``property``);
        # fall back to repr() so the intended ValueError is still raised.
        raise ValueError(
            f"{getattr(src, '__qualname__', repr(src))} has no docstring to copy"
        )

    # Named distinctly from the outer function to avoid shadowing it.
    def _copy_doc(dst: _T) -> _T:
        if dst.__doc__ is not None:
            raise ValueError(
                f"{getattr(dst, '__qualname__', repr(dst))} already has a docstring"
            )
        dst.__doc__ = src.__doc__
        return dst

    return _copy_doc
55 |
--------------------------------------------------------------------------------
/ax/models/tests/test_discrete.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | from ax.models.discrete_base import DiscreteModel
9 | from ax.utils.common.testutils import TestCase
10 |
11 |
class DiscreteModelTest(TestCase):
    def setUp(self):
        # No shared fixtures: each test builds its own DiscreteModel.
        pass

    def test_discrete_model_get_state(self):
        model = DiscreteModel()
        state = model._get_state()
        self.assertEqual(state, {})

    def test_discrete_model_feature_importances(self):
        model = DiscreteModel()
        with self.assertRaises(NotImplementedError):
            model.feature_importances()

    def testDiscreteModelFit(self):
        # The base fit must accept this minimal input without raising.
        model = DiscreteModel()
        model.fit(
            Xs=[[[0]]],
            Ys=[[0]],
            Yvars=[[1]],
            parameter_values=[[0, 1]],
            outcome_names=[],
        )

    def testdiscreteModelPredict(self):
        model = DiscreteModel()
        with self.assertRaises(NotImplementedError):
            model.predict([[0]])

    def testdiscreteModelGen(self):
        model = DiscreteModel()
        weights = np.array([1])
        with self.assertRaises(NotImplementedError):
            model.gen(n=1, parameter_values=[[0, 1]], objective_weights=weights)

    def testdiscreteModelCrossValidate(self):
        model = DiscreteModel()
        with self.assertRaises(NotImplementedError):
            model.cross_validate(
                Xs_train=[[[0]]],
                Ys_train=[[1]],
                Yvars_train=[[1]],
                X_test=[[1]],
            )
53 |
--------------------------------------------------------------------------------
/website/static/img/ax.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ax/models/discrete/eb_thompson.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | from typing import List, Tuple
9 |
10 | import numpy as np
11 | from ax.models.discrete.thompson import ThompsonSampler
12 | from ax.utils.common.logger import get_logger
13 | from ax.utils.stats.statstools import positive_part_james_stein
14 |
15 |
16 | logger: logging.Logger = get_logger(__name__)
17 |
18 |
class EmpiricalBayesThompsonSampler(ThompsonSampler):
    """Generator for Thompson sampling using Empirical Bayes estimates.

    The generator applies positive-part James-Stein Estimator to the data
    passed in via `fit` and then performs Thompson Sampling.
    """

    def _fit_Ys_and_Yvars(
        self, Ys: List[List[float]], Yvars: List[List[float]], outcome_names: List[str]
    ) -> Tuple[List[List[float]], List[List[float]]]:
        """Apply shrinkage to each outcome independently.

        Args:
            Ys: One list of observed means per outcome.
            Yvars: One list of observed variances per outcome, parallel to ``Ys``.
            outcome_names: Not used by this implementation.

        Returns:
            The shrunk means and variances, in the same per-outcome layout.
        """
        newYs = []
        newYvars = []
        for i, (Y, Yvar) in enumerate(zip(Ys, Yvars)):
            newY, newYvar = self._apply_shrinkage(Y, Yvar, i)
            newYs.append(newY)
            newYvars.append(newYvar)
        return newYs, newYvars

    def _apply_shrinkage(
        self, Y: List[float], Yvar: List[float], outcome: int
    ) -> Tuple[List[float], List[float]]:
        """Apply positive-part James-Stein shrinkage to a single outcome.

        Variances are converted to standard errors for the estimator and
        squared back afterwards. If the estimator raises ``ValueError``, a
        warning is logged and the unshrunk means/standard errors are used.

        Args:
            Y: Observed means for this outcome.
            Yvar: Observed variances for this outcome.
            outcome: Outcome index, used only in the warning message.
        """
        npY = np.array(Y)
        npYsem = np.sqrt(Yvar)
        try:
            npY, npYsem = positive_part_james_stein(means=npY, sems=npYsem)
        except ValueError as e:
            logger.warning(
                str(e) + f" Raw (unshrunk) estimates used for outcome: {outcome}"
            )
        return npY.tolist(), (npYsem**2).tolist()
53 |
--------------------------------------------------------------------------------
/ax/metrics/botorch_test_problem.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from typing import Any, Optional
7 |
8 | import pandas as pd
9 | from ax.core.base_trial import BaseTrial
10 | from ax.core.data import Data
11 | from ax.core.metric import Metric
12 |
13 |
class BotorchTestProblemMetric(Metric):
    """A Metric for retrieving information from a BotorchTestProblemRunner.

    A BotorchTestProblemRunner will attach the result of a call to
    BaseTestProblem.forward per Arm on a given trial, and this Metric will extract
    the proper value from the resulting tensor given its index.

    Args:
        name: Metric name.
        noise_sd: Value reported as the ``sem`` of every observation.
        index: Position to read from each arm's result when the result is a
            list; None when each arm's result is a single value.
    """

    def __init__(self, name: str, noise_sd: float, index: Optional[int] = None) -> None:
        super().__init__(name=name)
        self.noise_sd = noise_sd
        self.index = index

    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> Data:
        """Build a ``Data`` object with one row per arm of ``trial``."""
        # run_metadata["Ys"] maps arm name to either a list of results (read at
        # ``self.index``) or a single float (used as-is when index is None).
        arm_names = list(trial.arms_by_name)
        ys_by_arm = trial.run_metadata["Ys"]
        if self.index is None:
            mean = [ys_by_arm[name] for name in arm_names]
        else:
            mean = [ys_by_arm[name][self.index] for name in arm_names]

        df = pd.DataFrame(
            {
                "arm_name": arm_names,
                "metric_name": self.name,
                "mean": mean,
                # If no noise_std is returned then Botorch evaluated the true function
                "sem": self.noise_sd,
                "trial_index": trial.index,
            }
        )

        return Data(df=df)
51 |
--------------------------------------------------------------------------------
/ax/utils/testing/unittest_conventions.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import importlib
8 | import pathlib
9 | import sys
10 | import unittest
11 |
12 | import __test_modules__
13 | from ax.utils.common import testutils
14 |
15 |
def get_all_subclasses(cls):
    """Recursively yield every direct and indirect subclass of ``cls``.

    Classes are produced in depth-first pre-order. A class reachable through
    several parents (multiple inheritance) is yielded only once.
    """
    seen = set()

    def _walk(klass):
        # ``__subclasses__`` only lists direct descendants, so recurse.
        for sub in klass.__subclasses__():
            if sub in seen:
                continue
            seen.add(sub)
            yield sub
            yield from _walk(sub)

    yield from _walk(cls)
21 |
22 |
class TestUnittestConventions(testutils.TestCase):
    def test_uses_ae_unittest(self):
        """Ensure every test class inherits from Ax's own ``TestCase``.

        The Ax base class does more than ``unittest.TestCase`` (for instance it
        makes sure none of Python's deprecated ``assert`` functions are used),
        so its use is enforced across all test modules.
        """
        module_names = set(__test_modules__.TEST_MODULES)
        # Import every test module so that all TestCase subclasses are defined.
        for module_name in module_names:
            importlib.import_module(module_name)
        candidates = [
            test_cls
            for test_cls in get_all_subclasses(unittest.TestCase)
            if test_cls.__module__ in module_names
        ]
        required_base = testutils.TestCase
        for test_cls in candidates:
            with self.subTest(test_cls.__name__):
                if issubclass(test_cls, required_base):
                    continue
                module_path = pathlib.Path(sys.modules[test_cls.__module__].__file__)
                tests_root = pathlib.Path(__test_modules__.__file__).parent
                self.fail(
                    f"in {module_path.relative_to(tests_root)}: "
                    f"{test_cls.__qualname__} should inherit from "
                    f"{required_base.__module__}.{required_base.__name__}"
                )
50 |
--------------------------------------------------------------------------------
/ax/core/tests/test_runner.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from unittest import mock
8 |
9 | from ax.core.base_trial import BaseTrial
10 | from ax.core.runner import Runner
11 | from ax.utils.common.testutils import TestCase
12 | from ax.utils.testing.core_stubs import get_batch_trial, get_trial
13 |
14 |
class DummyRunner(Runner):
    """Test double whose ``run`` returns metadata derived from the trial index."""

    def run(self, trial: BaseTrial):
        metadatum = f"value_for_trial_{trial.index}"
        return {"metadatum": metadatum}
18 |
19 |
class RunnerTest(TestCase):
    """Exercises the default behaviour of the base ``Runner`` via DummyRunner."""

    def setUp(self):
        self.dummy_runner = DummyRunner()
        self.trials = [get_trial(), get_batch_trial()]

    def test_base_runner_staging_required(self):
        staging = self.dummy_runner.staging_required
        self.assertFalse(staging)

    def test_base_runner_stop(self):
        # The base class provides no stopping implementation.
        self.assertRaises(
            NotImplementedError, self.dummy_runner.stop, trial=mock.Mock(), reason=""
        )

    def test_base_runner_clone(self):
        cloned = self.dummy_runner.clone()
        self.assertIsInstance(cloned, DummyRunner)
        self.assertEqual(cloned, self.dummy_runner)

    def test_base_runner_run_multiple(self):
        per_trial = self.dummy_runner.run_multiple(trials=self.trials)
        expected = {
            trial.index: {"metadatum": f"value_for_trial_{trial.index}"}
            for trial in self.trials
        }
        self.assertEqual(per_trial, expected)
        # No trials in, no metadata out.
        self.assertEqual({}, self.dummy_runner.run_multiple(trials=[]))

    def test_base_runner_poll_trial_status(self):
        self.assertRaises(
            NotImplementedError,
            self.dummy_runner.poll_trial_status,
            trials=self.trials,
        )

    def test_poll_available_capacity(self):
        capacity = self.dummy_runner.poll_available_capacity()
        self.assertEqual(capacity, -1)
51 |
--------------------------------------------------------------------------------
/ax/metrics/jenatton.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from typing import Any, Optional
7 |
8 | import pandas as pd
9 | from ax.core.base_trial import BaseTrial
10 | from ax.core.data import Data
11 | from ax.core.metric import Metric
12 | from ax.utils.common.typeutils import not_none
13 |
14 |
class JenattonMetric(Metric):
    """Metric that evaluates the Jenatton test function on each arm.

    The function is hierarchical: ``x1`` selects a branch, ``x2`` or ``x3``
    selects a leaf within it, and only the parameters on the active path
    (one of ``x4``..``x7`` plus ``r8`` or ``r9``) are used. Every observation
    is reported with an SEM of 0.
    """

    def __init__(
        self,
        name: str = "jenatton",
    ) -> None:
        super().__init__(name=name)

    @staticmethod
    def _f(
        x1: Optional[int] = None,
        x2: Optional[int] = None,
        x3: Optional[int] = None,
        x4: Optional[float] = None,
        x5: Optional[float] = None,
        x6: Optional[float] = None,
        x7: Optional[float] = None,
        r8: Optional[float] = None,
        r9: Optional[float] = None,
    ) -> float:
        """Evaluate the function; ``not_none`` raises if an active input is None."""
        if x1 == 0:
            active, offset, shift = (x4, 0.1, r8) if x2 == 0 else (x5, 0.2, r8)
        else:
            active, offset, shift = (x6, 0.3, r9) if x3 == 0 else (x7, 0.4, r9)
        return not_none(active) ** 2 + offset + not_none(shift)

    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> Data:
        arms_by_name = trial.arms_by_name
        # pyre-ignore [6]
        means = [self._f(**arm.parameters) for arm in arms_by_name.values()]
        frame = pd.DataFrame(
            {
                "arm_name": list(arms_by_name),
                "metric_name": self.name,
                "mean": means,
                "sem": 0,
                "trial_index": trial.index,
            }
        )
        return Data(df=frame)
59 |
--------------------------------------------------------------------------------
/ax/core/tests/test_parameter_distribution.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.core.parameter_distribution import ParameterDistribution
8 | from ax.exceptions.core import UserInputError
9 | from ax.utils.common.testutils import TestCase
10 | from scipy.stats._continuous_distns import norm_gen
11 | from scipy.stats._distn_infrastructure import rv_frozen
12 |
13 |
class ParameterDistributionTest(TestCase):
    def test_parameter_distribution(self):
        normal = ParameterDistribution(
            parameters=["x1"],
            distribution_class="norm",
            distribution_parameters={"loc": 0.0, "scale": 1.0},
            multiplicative=True,
        )
        self.assertTrue(normal.multiplicative)
        frozen = normal.distribution
        self.assertEqual(normal.parameters, ["x1"])
        self.assertIsInstance(frozen, rv_frozen)
        self.assertIsInstance(frozen.dist, norm_gen)
        for key, expected in (("loc", 0.0), ("scale", 1.0)):
            self.assertEqual(frozen.kwds[key], expected)

        # String representation.
        self.assertEqual(
            str(normal),
            "ParameterDistribution(parameters=['x1'], distribution_class=norm, "
            "distribution_parameters={'loc': 0.0, 'scale': 1.0}, "
            "multiplicative=True)",
        )

        # An unknown distribution name is accepted at construction time but
        # fails when the distribution is accessed.
        unknown = ParameterDistribution(
            parameters=["x1"],
            distribution_class="dummy_dist",
            distribution_parameters={},
        )
        self.assertFalse(unknown.multiplicative)
        with self.assertRaises(UserInputError):
            unknown.distribution
50 |
--------------------------------------------------------------------------------
/ax/global_stopping/strategies/base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from abc import ABC, abstractmethod
8 | from typing import Any, Dict, Tuple
9 |
10 | from ax.core.experiment import Experiment
11 | from ax.utils.common.base import Base
12 |
13 |
class BaseGlobalStoppingStrategy(ABC, Base):
    """Abstract interface for deciding when to stop an entire optimization.

    This differs from ``BaseEarlyStoppingStrategy``, which decides whether an
    individual trial with partial results should be stopped before it fully
    completes. A global stopping strategy instead decides whether the whole
    optimization should end (e.g. because the expected marginal gains of
    further evaluations no longer justify the cost of running those trials).
    """

    def __init__(self, min_trials: int) -> None:
        """Initialize the stopping strategy.

        Args:
            min_trials: Minimum number of trials before the stopping strategy
                kicks in.
        """
        self.min_trials = min_trials

    @abstractmethod
    def should_stop_optimization(
        self,
        experiment: Experiment,
        **kwargs: Dict[str, Any],
    ) -> Tuple[bool, str]:
        """Decide whether the optimization should stop.

        A typical use is ending the optimization loop once the objective no
        longer appears to improve.

        Args:
            experiment: Experiment holding the trials and other contextual data.

        Returns:
            A tuple of (should_stop, reason): a boolean telling whether the
            optimization should stop and a string explaining why.
        """
        ...  # pragma: nocover
54 |
--------------------------------------------------------------------------------
/website/core/Footer.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | const React = require('react');
9 |
/**
 * Footer component for the project website.
 *
 * NOTE(review): the JSX returned by render() is elided in this listing
 * (original lines 29-63 are missing); confirm its contents against the
 * real file before editing this component.
 */
class Footer extends React.Component {

  /**
   * Build the URL of a documentation page.
   *
   * @param doc Page path, appended last.
   * @param language Optional language segment; omitted when falsy.
   * @returns baseUrl, then `${docsUrl}/` when config.docsUrl is set, then
   *   `${language}/` when a language is given, then doc.
   */
  docUrl(doc, language) {
    const baseUrl = this.props.config.baseUrl;
    const docsUrl = this.props.config.docsUrl;
    const docsPart = `${docsUrl ? `${docsUrl}/` : ''}`;
    const langPart = `${language ? `${language}/` : ''}`;
    return `${baseUrl}${docsPart}${langPart}${doc}`;
  }

  /**
   * Build the URL of a site page: baseUrl, then `${language}/` when a
   * language is given, then doc. Unlike docUrl, no docsUrl segment is added.
   */
  pageUrl(doc, language) {
    const baseUrl = this.props.config.baseUrl;
    return baseUrl + (language ? `${language}/` : '') + doc;
  }

  render() {
    // NOTE(review): presumably consumed by the elided JSX below; confirm.
    const currentYear = new Date().getFullYear();

    return (

    );
  }
}
66 |
67 | module.exports = Footer;
68 |
--------------------------------------------------------------------------------
/ax/service/tests/test_best_point.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from ax.service.utils.best_point_mixin import BestPointMixin
7 | from ax.utils.common.testutils import TestCase
8 | from ax.utils.common.typeutils import not_none
9 | from ax.utils.testing.core_stubs import get_experiment_with_observations
10 |
11 |
class TestBestPointMixin(TestCase):
    def test_get_trace(self):
        trace = BestPointMixin.get_trace  # short alias

        # Single objective, minimized: the trace keeps the best value so far.
        experiment = get_experiment_with_observations(
            observations=[[11], [10], [9], [15], [5]], minimize=True
        )
        self.assertEqual(trace(experiment), [11, 10, 9, 9, 5])
        # Same data, but maximized through an overriding optimization config.
        maximize_config = not_none(experiment.optimization_config).clone()
        maximize_config.objective.minimize = False
        self.assertEqual(trace(experiment, maximize_config), [11, 11, 11, 15, 15])

        # Scalarized objective.
        experiment = get_experiment_with_observations(
            observations=[[1, 1], [2, 2], [3, 3]],
            scalarized=True,
        )
        self.assertEqual(trace(experiment), [2, 4, 6])

        # Multiple objectives.
        experiment = get_experiment_with_observations(
            observations=[[1, 1], [1, 2], [3, 3], [2, 4], [2, 1]],
        )
        self.assertEqual(trace(experiment), [1, 2, 9, 11, 11])

        # With outcome constraints.
        experiment = get_experiment_with_observations(
            observations=[[1, 1, 1], [1, 2, -1], [3, 3, -1], [2, 4, 1], [2, 1, 1]],
            constrained=True,
        )
        self.assertEqual(trace(experiment), [1, 1, 1, 8, 8])

        # With the first objective being minimized.
        experiment = get_experiment_with_observations(
            observations=[[1, 1], [-1, 2], [3, 3], [-2, 4], [2, 1]], minimize=True
        )
        self.assertEqual(trace(experiment), [0, 2, 2, 8, 8])
52 |
--------------------------------------------------------------------------------
/website/static/img/ax_lockup.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ax/core/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # flake8: noqa F401
8 | from ax.core.arm import Arm
9 | from ax.core.batch_trial import BatchTrial
10 | from ax.core.data import Data
11 | from ax.core.experiment import Experiment
12 | from ax.core.generator_run import GeneratorRun
13 | from ax.core.metric import Metric
14 | from ax.core.objective import MultiObjective, Objective
15 | from ax.core.observation import Observation, ObservationData, ObservationFeatures
16 | from ax.core.optimization_config import (
17 | MultiObjectiveOptimizationConfig,
18 | OptimizationConfig,
19 | )
20 | from ax.core.outcome_constraint import (
21 | ComparisonOp,
22 | ObjectiveThreshold,
23 | OutcomeConstraint,
24 | )
25 | from ax.core.parameter import (
26 | ChoiceParameter,
27 | FixedParameter,
28 | Parameter,
29 | ParameterType,
30 | RangeParameter,
31 | )
32 | from ax.core.parameter_constraint import (
33 | OrderConstraint,
34 | ParameterConstraint,
35 | SumConstraint,
36 | )
37 | from ax.core.parameter_distribution import ParameterDistribution
38 | from ax.core.risk_measures import RiskMeasure
39 | from ax.core.runner import Runner
40 | from ax.core.search_space import SearchSpace
41 | from ax.core.trial import Trial
42 | from ax.core.types import TParameterization
43 |
44 |
# Public names re-exported by ``ax.core``. Every entry must be imported above;
# a missing name makes ``from ax.core import *`` raise AttributeError.
__all__ = [
    "Arm",
    "BatchTrial",
    "ChoiceParameter",
    "ComparisonOp",
    "Data",
    "Experiment",
    "FixedParameter",
    "GeneratorRun",
    "Metric",
    "MultiObjective",
    "MultiObjectiveOptimizationConfig",
    "Objective",
    "ObjectiveThreshold",
    "OptimizationConfig",
    "OrderConstraint",
    "OutcomeConstraint",
    "Parameter",
    "ParameterConstraint",
    "ParameterDistribution",
    "ParameterType",
    "RangeParameter",
    "RiskMeasure",
    "Runner",
    "SearchSpace",
    "SumConstraint",
    "Trial",
]
74 |
--------------------------------------------------------------------------------
/website/static/img/ax_logo_lockup.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/website/static/img/ax_lockup_white.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ax/modelbridge/transforms/cap_parameter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List, Optional, TYPE_CHECKING
8 |
9 | from ax.core.observation import ObservationData, ObservationFeatures
10 | from ax.core.parameter import RangeParameter
11 | from ax.core.search_space import SearchSpace
12 | from ax.modelbridge.transforms.base import Transform
13 | from ax.models.types import TConfig
14 | from ax.utils.common.typeutils import checked_cast
15 |
16 | if TYPE_CHECKING:
17 | # import as module to make sphinx-autodoc-typehints happy
18 | from ax import modelbridge as modelbridge_module # noqa F401 # pragma: no cover
19 |
20 |
class CapParameter(Transform):
    """Cap (set) the upper bound of selected range parameters.

    Expects a config of the form ``{parameter_name: new_upper_range_value}``.
    Parameters not named in the config are left untouched. This transform only
    transforms the search space.
    """

    def __init__(
        self,
        search_space: SearchSpace,
        observation_features: List[ObservationFeatures],
        observation_data: List[ObservationData],
        modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
        config: Optional[TConfig] = None,
    ) -> None:
        self.config = config or {}
        # Search-space parameters for which the config provides a cap.
        self.transform_parameters = set(search_space.parameters).intersection(
            self.config
        )

    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Apply ``update_range`` to each capped parameter; return the same object."""
        for name, parameter in search_space.parameters.items():
            if name not in self.transform_parameters:
                continue
            if not isinstance(parameter, RangeParameter):
                raise NotImplementedError(
                    "Can only cap range parameters currently."
                )
            new_upper = self.config.get(name)
            checked_cast(RangeParameter, parameter).update_range(upper=new_upper)
        return search_space
52 |
--------------------------------------------------------------------------------
/docs/why-ax.md:
--------------------------------------------------------------------------------
1 | ---
2 | id: why-ax
3 | title: Why Ax?
4 | sidebar_label: Why Ax?
5 | ---
6 |
Developers and researchers alike face problems that confront them with a large space of possible ways to configure something — whether those are "magic numbers" used for infrastructure or compiler flags, learning rates or other hyperparameters in machine learning, or images and calls-to-action used in marketing promotions. Selecting and tuning these configurations can take considerable time and resources, and can affect the quality of user experiences. Ax is a machine learning system that helps automate this process, so that researchers and developers can determine how to get the most out of their software in an optimally efficient way.
8 |
Ax is a platform for optimizing any kind of experiment, including machine learning experiments, A/B tests, and simulations. Ax can optimize discrete configurations (e.g., variants of an A/B test) using multi-armed bandit optimization, and continuous- or integer-valued configurations (e.g., floating-point or integer parameters) using Bayesian optimization. This makes it suitable for a wide range of applications.
10 |
11 | Ax has been successfully applied to a variety of product, infrastructure, ML, and research applications at Facebook.
12 |
13 | # Unique capabilities
14 | - **Support for noisy functions**. Results of A/B tests and simulations with reinforcement learning agents often exhibit high amounts of noise. Ax supports [state-of-the-art algorithms](https://research.facebook.com/blog/2018/09/efficient-tuning-of-online-systems-using-bayesian-optimization/) which work better than traditional Bayesian optimization in high-noise settings.
15 | - **Customization**. Ax's developer API makes it easy to integrate custom data modeling and decision algorithms. This allows developers to build their own custom optimization services with minimal overhead.
16 | - **Multi-modal experimentation**. Ax has first-class support for running and combining data from different types of experiments, such as "offline" simulation data and "online" data from real-world experiments.
- **Multi-objective optimization**. Ax supports multi-objective and constrained optimization, both of which are common in real-world problems, like improving load time without increasing data use.
18 |
--------------------------------------------------------------------------------
/ax/plot/tests/test_contours.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import plotly.graph_objects as go
8 | from ax.modelbridge.registry import Models
9 | from ax.plot.base import AxPlotConfig
10 | from ax.plot.contour import (
11 | interact_contour,
12 | interact_contour_plotly,
13 | plot_contour,
14 | plot_contour_plotly,
15 | )
16 | from ax.utils.common.testutils import TestCase
17 | from ax.utils.testing.core_stubs import get_branin_experiment
18 | from ax.utils.testing.mock import fast_botorch_optimize
19 |
20 |
class ContoursTest(TestCase):
    @fast_botorch_optimize
    def testContours(self):
        """Smoke-test every contour plotting entry point on a Branin experiment."""
        exp = get_branin_experiment(with_str_choice_param=True, with_batch=True)
        exp.trials[0].run()
        model = Models.BOTORCH(
            # Model bridge kwargs
            experiment=exp,
            data=exp.fetch_data(),
        )
        param_x, param_y = model.parameters[0], model.parameters[1]
        metric_name = list(model.metric_names)[0]

        # Assert that each type of plot can be constructed successfully.
        plot = plot_contour_plotly(model, param_x, param_y, metric_name)
        self.assertIsInstance(plot, go.Figure)
        plot = interact_contour_plotly(model, metric_name)
        self.assertIsInstance(plot, go.Figure)
        plot = interact_contour(model, metric_name)
        self.assertIsInstance(plot, AxPlotConfig)
        plot = plot_contour(model, param_x, param_y, metric_name)
        self.assertIsInstance(plot, AxPlotConfig)

        # Make sure all parameters and metrics are displayed in tooltips.
        tooltips = list(exp.parameters.keys()) + list(exp.metrics.keys())
        for d in plot.data["data"]:
            # Only check the hover text of scatter traces.
            if d["type"] != "scatter":
                continue
            for text in d["text"]:
                for tt in tooltips:
                    self.assertIn(tt, text)
54 |
--------------------------------------------------------------------------------
/ax/utils/common/tests/test_typeutils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Dict
8 |
9 | import numpy as np
10 | from ax.utils.common.testutils import TestCase
11 | from ax.utils.common.typeutils import (
12 | checked_cast,
13 | checked_cast_complex,
14 | checked_cast_dict,
15 | checked_cast_list,
16 | checked_cast_optional,
17 | not_none,
18 | numpy_type_to_python_type,
19 | )
20 |
21 |
class TestTypeUtils(TestCase):
    def test_not_none(self):
        self.assertEqual(not_none("not_none"), "not_none")
        self.assertRaises(ValueError, not_none, None)

    def test_checked_cast(self):
        self.assertEqual(checked_cast(float, 2.0), 2.0)
        # An int is rejected where a float is required.
        self.assertRaises(ValueError, checked_cast, float, 2)

    def test_checked_cast_complex(self):
        int_to_str = Dict[int, str]
        self.assertEqual(checked_cast_complex(int_to_str, {1: "one"}), {1: "one"})
        self.assertRaises(ValueError, checked_cast_complex, int_to_str, {"one": 1})

    def test_checked_cast_list(self):
        self.assertEqual(checked_cast_list(float, [1.0, 2.0]), [1.0, 2.0])
        self.assertRaises(ValueError, checked_cast_list, float, [1.0, 2])

    def test_checked_cast_optional(self):
        self.assertIsNone(checked_cast_optional(float, None))
        self.assertRaises(ValueError, checked_cast_optional, float, 2)

    def test_checked_cast_dict(self):
        self.assertEqual(checked_cast_dict(str, int, {"some": 1}), {"some": 1})
        # First a wrong value type, then a wrong key type.
        for bad_mapping in ({"some": 1.0}, {1: 1}):
            self.assertRaises(ValueError, checked_cast_dict, str, int, bad_mapping)

    def test_numpy_type_to_python_type(self):
        for value, expected in ((np.int64(2), int), (np.float64(2), float)):
            self.assertIs(type(numpy_type_to_python_type(value)), expected)
59 |
--------------------------------------------------------------------------------
/ax/core/tests/test_risk_measures.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.core.risk_measures import RISK_MEASURE_NAME_TO_CLASS, RiskMeasure, VaR
8 | from ax.exceptions.core import UserInputError
9 | from ax.utils.common.testutils import TestCase
10 |
11 |
class TestRiskMeasure(TestCase):
    def test_risk_measure(self):
        rm = RiskMeasure(
            risk_measure="VaR",
            options={"alpha": 0.8, "n_w": 5},
        )
        self.assertEqual(rm.risk_measure, "VaR")
        self.assertEqual(rm.options, {"alpha": 0.8, "n_w": 5})
        rm_module = rm.module
        self.assertIsInstance(rm_module, VaR)
        self.assertEqual(rm_module.alpha, 0.8)
        self.assertEqual(rm_module.n_w, 5)
        self.assertFalse(rm.is_multi_output)

        # Test repr.
        expected_repr = (
            "RiskMeasure(risk_measure=VaR, options={'alpha': 0.8, 'n_w': 5})"
        )
        self.assertEqual(str(rm), expected_repr)

        # Test clone.
        rm_clone = rm.clone()
        self.assertEqual(str(rm), str(rm_clone))

        # Test unknown risk measure.
        with self.assertRaisesRegex(UserInputError, "constructing"):
            RiskMeasure(
                risk_measure="VVar",
                options={},
            )
        # Test invalid options.
        with self.assertRaisesRegex(UserInputError, "constructing"):
            RiskMeasure(
                risk_measure="VaR",
                options={"alpha": 5, "n_w": 5},
            )

    def test_custom_risk_measure(self):
        """Test using user-defined risk measures."""
        from unittest import mock

        class CustomRM(VaR):
            pass

        # Register the custom class only for the duration of this test.
        # patch.dict restores the module-level registry afterwards, so the
        # "custom" entry does not leak into other tests.
        with mock.patch.dict(RISK_MEASURE_NAME_TO_CLASS, {"custom": CustomRM}):
            rm = RiskMeasure(
                risk_measure="custom",
                options={"alpha": 0.8, "n_w": 5},
            )
            self.assertEqual(rm.risk_measure, "custom")
            self.assertIsInstance(rm.module, CustomRM)
63 |
--------------------------------------------------------------------------------
/ax/storage/sqa_store/sqa_enum.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import enum
8 | from typing import Any, Dict, List
9 |
10 | from ax.storage.sqa_store.db import NAME_OR_TYPE_FIELD_LENGTH
11 | from sqlalchemy import types
12 |
13 |
class BaseNullableEnum(types.TypeDecorator):
    """SQLAlchemy column type that stores a Python ``Enum`` member by value.

    ``None`` passes through unchanged in both directions. Subclasses choose
    the underlying column type via ``impl``.
    """

    cache_ok = True

    def __init__(self, enum: Any, *arg: List[Any], **kw: Dict[Any, Any]) -> None:
        types.TypeDecorator.__init__(self, *arg, **kw)
        # Lookup tables from the Enum class: name -> member and value -> member.
        self._member_map = enum._member_map_
        self._value2member_map = enum._value2member_map_

    def process_bind_param(self, value: Any, dialect: Any) -> Any:
        """Turn an Enum member into the raw value written to the database."""
        if value is None:
            return value  # pragma: no cover
        if not isinstance(value, enum.Enum):
            raise TypeError("Value is not an instance of Enum.")  # pragma: no cover
        # The lookup is by member *name* in this type's enum.
        known_member = self._member_map.get(value.name)
        if known_member is not None:
            return known_member._value_
        raise ValueError(  # pragma: no cover
            "Member '{value}' is not a supported enum: {members}".format(
                value=value, members=list(self._member_map.keys())
            )
        )

    def process_result_value(self, value: Any, dialect: Any) -> Any:
        """Turn a raw database value back into the matching Enum member."""
        if value is None:
            return value  # pragma: no cover
        known_member = self._value2member_map.get(value)
        if known_member is not None:
            return known_member
        raise ValueError(  # pragma: no cover
            f"Value '{value}' is not one of the supported "
            + "enum values: {supported_values}".format(
                supported_values=list(self._value2member_map.keys())
            )
        )
48 |
49 |
class IntEnum(BaseNullableEnum):
    """Nullable enum column whose underlying SQL type is ``SmallInteger``."""

    # pyre-fixme[8]: Attribute has type `SmallInteger`; used as
    # `Type[sqlalchemy.sql.sqltypes.SmallInteger]`.
    impl: types.SmallInteger = types.SmallInteger
54 |
55 |
class StringEnum(BaseNullableEnum):
    """Nullable enum column stored as VARCHAR(NAME_OR_TYPE_FIELD_LENGTH)."""

    impl = types.VARCHAR(NAME_OR_TYPE_FIELD_LENGTH)
58 |
--------------------------------------------------------------------------------
/ax/early_stopping/strategies/logical.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from typing import Any, Dict, Optional, Set
7 |
8 | from ax.core.experiment import Experiment
9 | from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
10 |
11 |
class LogicalEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
    """Base for strategies that combine two child early-stopping strategies.

    Subclasses define how the decisions of ``left`` and ``right`` are merged.
    """

    def __init__(
        self,
        left: BaseEarlyStoppingStrategy,
        right: BaseEarlyStoppingStrategy,
        seconds_between_polls: int = 60,
        true_objective_metric_name: Optional[str] = None,
    ) -> None:
        super().__init__(
            seconds_between_polls=seconds_between_polls,
            true_objective_metric_name=true_objective_metric_name,
        )
        self.left, self.right = left, right
27 |
28 |
class AndEarlyStoppingStrategy(LogicalEarlyStoppingStrategy):
    """Stop a trial only if both ``left`` and ``right`` recommend stopping it."""

    def should_stop_trials_early(
        self,
        trial_indices: Set[int],
        experiment: Experiment,
        **kwargs: Dict[str, Any],
    ) -> Dict[int, Optional[str]]:
        """Return the trials that both child strategies want to stop.

        Args:
            trial_indices: Indices of the candidate trials.
            experiment: Experiment the trials belong to.
            kwargs: Forwarded unchanged to both child strategies.

        Returns:
            A dict mapping each trial index flagged by both children to their
            comma-joined reasons. ``None`` reasons are left out instead of being
            rendered as the string "None". If neither child gives a reason, the
            value is ``None``.
        """
        left = self.left.should_stop_trials_early(
            trial_indices=trial_indices, experiment=experiment, **kwargs
        )
        right = self.right.should_stop_trials_early(
            trial_indices=trial_indices, experiment=experiment, **kwargs
        )
        decisions: Dict[int, Optional[str]] = {}
        for trial, left_reason in left.items():
            if trial not in right:
                continue
            reasons = [r for r in (left_reason, right[trial]) if r is not None]
            decisions[trial] = ", ".join(map(str, reasons)) if reasons else None
        return decisions
46 |
47 |
class OrEarlyStoppingStrategy(LogicalEarlyStoppingStrategy):
    """Stop a trial if either ``left`` or ``right`` recommends stopping it."""

    def should_stop_trials_early(
        self,
        trial_indices: Set[int],
        experiment: Experiment,
        **kwargs: Dict[str, Any],
    ) -> Dict[int, Optional[str]]:
        """Return the union of both children's decisions.

        ``left`` is queried before ``right``. When both flag the same trial,
        the reason from ``right`` is kept.
        """
        combined = dict(
            self.left.should_stop_trials_early(
                trial_indices=trial_indices, experiment=experiment, **kwargs
            )
        )
        combined.update(
            self.right.should_stop_trials_early(
                trial_indices=trial_indices, experiment=experiment, **kwargs
            )
        )
        return combined
63 |
--------------------------------------------------------------------------------
/ax/utils/common/tests/test_logger.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | from tempfile import NamedTemporaryFile
9 | from unittest.mock import patch
10 |
11 | from ax.utils.common.logger import build_file_handler, get_logger
12 | from ax.utils.common.testutils import TestCase
13 |
14 |
15 | BASE_LOGGER_NAME = f"ax.{__name__}"
16 |
17 |
class LoggerTest(TestCase):
    def setUp(self):
        self.warning_string = "Test warning"

    def testLogger(self):
        logger = get_logger(BASE_LOGGER_NAME + ".testLogger")
        # Verify it doesn't crash
        logger.warning(self.warning_string)
        # Patch it and verify we actually called it. The mock is set onto the
        # Python logger object directly, so a leaked patch would affect other
        # tests (e.g. under pytest); the context manager undoes it even if the
        # assertion fails.
        with patch.object(logger, "warning") as mock_warning:
            logger.warning(self.warning_string)
            mock_warning.assert_called_once_with(self.warning_string)

    def testLoggerWithFile(self):
        with NamedTemporaryFile() as tf:
            logger = get_logger(BASE_LOGGER_NAME + ".testLoggerWithFile")
            handler = build_file_handler(tf.name)
            logger.addHandler(handler)
            try:
                logger.info(self.warning_string)
                output = str(tf.read())
            finally:
                # Detach and close the handler so it neither keeps the temp
                # file open nor receives records from later tests.
                logger.removeHandler(handler)
                handler.close()
            self.assertIn(BASE_LOGGER_NAME, output)
            self.assertIn(self.warning_string, output)

    def testLoggerOutputNameWithFile(self):
        with NamedTemporaryFile() as tf:
            logger = get_logger(BASE_LOGGER_NAME + ".testLoggerOutputNameWithFile")
            handler = build_file_handler(tf.name)
            logger.addHandler(handler)
            try:
                adapter = logging.LoggerAdapter(
                    logger, {"output_name": "my_output_name"}
                )
                adapter.warning(self.warning_string)
                output = str(tf.read())
            finally:
                logger.removeHandler(handler)
                handler.close()
            self.assertIn("my_output_name", output)
            self.assertIn(self.warning_string, output)
56 |
--------------------------------------------------------------------------------
/ax/models/tests/test_torch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from ax.core.search_space import SearchSpaceDigest
9 | from ax.models.torch_base import TorchModel
10 | from ax.utils.common.testutils import TestCase
11 | from botorch.utils.datasets import FixedNoiseDataset
12 |
13 |
class TorchModelTest(TestCase):
    """Exercises the default implementations on the ``TorchModel`` base class."""

    def setUp(self):
        # Single-point dataset reused by the fit / cross-validate / update tests.
        self.dataset = FixedNoiseDataset(
            X=torch.zeros(1), Y=torch.zeros(1), Yvar=torch.ones(1)
        )

    def testTorchModelFit(self):
        # Only checks that the base-class ``fit`` runs without raising.
        model = TorchModel()
        digest = SearchSpaceDigest(feature_names=["x1"], bounds=[(0, 1)])
        model.fit(
            datasets=[self.dataset], metric_names=["y"], search_space_digest=digest
        )

    def testTorchModelPredict(self):
        model = TorchModel()
        self.assertRaises(NotImplementedError, model.predict, torch.zeros(1))

    def testTorchModelGen(self):
        model = TorchModel()
        self.assertRaises(
            NotImplementedError,
            model.gen,
            n=1,
            bounds=[(0, 1)],
            objective_weights=torch.ones(1),
        )

    def testNumpyTorchBestPoint(self):
        # The base-class ``best_point`` returns ``None``.
        best = TorchModel().best_point(
            bounds=[(0, 1)], objective_weights=torch.ones(1)
        )
        self.assertIsNone(best)

    def testTorchModelCrossValidate(self):
        model = TorchModel()
        empty_digest = SearchSpaceDigest(feature_names=[], bounds=[])
        with self.assertRaises(NotImplementedError):
            model.cross_validate(
                datasets=[self.dataset],
                metric_names=["y"],
                X_test=torch.ones(1),
                search_space_digest=empty_digest,
            )

    def testTorchModelUpdate(self):
        model = TorchModel()
        empty_digest = SearchSpaceDigest(feature_names=[], bounds=[])
        with self.assertRaises(NotImplementedError):
            model.update(
                datasets=[self.dataset],
                metric_names=["y"],
                search_space_digest=empty_digest,
            )
64 |
--------------------------------------------------------------------------------
/ax/models/tests/test_full_factorial.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | import numpy as np
9 | from ax.models.discrete.full_factorial import FullFactorialGenerator
10 | from ax.utils.common.testutils import TestCase
11 |
12 |
class FullFactorialGeneratorTest(TestCase):
    """Tests for ``FullFactorialGenerator.gen``."""

    def testFullFactorial(self):
        generator = FullFactorialGenerator()
        parameter_values = [[1, 2], ["foo", "bar"]]
        generated_points, weights, _ = generator.gen(
            n=-1, parameter_values=parameter_values, objective_weights=np.ones(1)
        )
        # Every combination, enumerated with the last parameter varying fastest.
        expected_points = [[1, "foo"], [1, "bar"], [2, "foo"], [2, "bar"]]
        self.assertEqual(generated_points, expected_points)
        # Each generated point has weight 1.
        self.assertEqual(weights, [1] * len(expected_points))

    def testFullFactorialValidation(self):
        # The calls below are deliberately NOT tuple-unpacked: unpacking a
        # wrongly-shaped return value would itself raise ValueError and so mask
        # a missing validation check inside ``gen``.

        # Raise error because cardinality (2 * 2 * 2 = 8) exceeds max cardinality
        generator = FullFactorialGenerator(max_cardinality=5, check_cardinality=True)
        parameter_values = [[1, 2], ["foo", "bar"], [True, False]]
        with self.assertRaises(ValueError):
            generator.gen(
                n=-1, parameter_values=parameter_values, objective_weights=np.ones(1)
            )

        # Raise error because n != -1
        generator = FullFactorialGenerator()
        parameter_values = [[1, 2], ["foo", "bar"]]
        with self.assertRaises(ValueError):
            generator.gen(
                n=5, parameter_values=parameter_values, objective_weights=np.ones(1)
            )

    def testFullFactorialFixedFeatures(self):
        generator = FullFactorialGenerator(max_cardinality=5, check_cardinality=True)
        parameter_values = [[1, 2], ["foo", "bar"]]
        generated_points, weights, _ = generator.gen(
            n=-1,
            parameter_values=parameter_values,
            objective_weights=np.ones(1),
            # Pin the parameter at index 1 to "foo".
            fixed_features={1: "foo"},
        )
        expected_points = [[1, "foo"], [2, "foo"]]
        self.assertEqual(generated_points, expected_points)
        self.assertEqual(weights, [1] * len(expected_points))
53 |
--------------------------------------------------------------------------------
/ax/benchmark/tests/test_scored_benchmark.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from ax.benchmark.scored_benchmark import (
7 | scored_benchmark_full_run,
8 | scored_benchmark_replication,
9 | scored_benchmark_test,
10 | )
11 | from ax.utils.common.testutils import TestCase
12 | from ax.utils.testing.benchmark_stubs import (
13 | get_aggregated_benchmark_result,
14 | get_single_objective_benchmark_problem,
15 | get_sobol_gpei_benchmark_method,
16 | )
17 | from ax.utils.testing.mock import fast_botorch_optimize
18 |
19 |
class TestProblems(TestCase):
    """Smoke tests for the scored-benchmark entry points."""

    @fast_botorch_optimize
    def test_scored_benchmark_replication(self) -> None:
        scored_result = scored_benchmark_replication(
            problem=get_single_objective_benchmark_problem(),
            method=get_sobol_gpei_benchmark_method(),
        )

        self.assertEqual(len(scored_result.score_trace), 4)
        # Score should never be over 100 (a score of exactly 100 is allowed).
        self.assertTrue((scored_result.score_trace <= 100).all())

    @fast_botorch_optimize
    def test_scored_benchmark_test(self) -> None:
        aggregated_scored_result = scored_benchmark_test(
            problem=get_single_objective_benchmark_problem(),
            method=get_sobol_gpei_benchmark_method(),
            num_replications=2,
        )

        self.assertEqual(len(aggregated_scored_result.score_trace), 4)
        # Neither the mean nor the median score should ever be over 100.
        self.assertTrue((aggregated_scored_result.score_trace["mean"] <= 100).all())
        self.assertTrue(
            (aggregated_scored_result.score_trace["median"] <= 100).all()
        )

    @fast_botorch_optimize
    def test_scored_benchmark_full_run(self) -> None:
        aggregated_scored_results = scored_benchmark_full_run(
            problems_baseline_results=[
                (
                    get_single_objective_benchmark_problem(),
                    get_aggregated_benchmark_result(),
                )
            ],
            methods=[get_sobol_gpei_benchmark_method()],
            num_replications=2,
        )

        # One (problem, method) pair in, one aggregated result out.
        self.assertEqual(len(aggregated_scored_results), 1)
63 |
--------------------------------------------------------------------------------
/sphinx/source/metrics.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax.metrics
5 | ===================================
6 |
7 | .. automodule:: ax.metrics
8 | .. currentmodule:: ax.metrics
9 |
10 |
11 | BoTorch Test Problem
12 | ~~~~~~~~~~~~~~~~~~~~
13 |
14 | .. automodule:: ax.metrics.botorch_test_problem
15 | :members:
16 | :undoc-members:
17 | :show-inheritance:
18 |
19 | Branin
20 | ~~~~~~
21 |
22 | .. automodule:: ax.metrics.branin
23 | :members:
24 | :undoc-members:
25 | :show-inheritance:
26 |
27 | Branin Map
28 | ~~~~~~~~~~
29 |
30 | .. automodule:: ax.metrics.branin_map
31 | :members:
32 | :undoc-members:
33 | :show-inheritance:
34 |
35 | Chemistry
36 | ~~~~~~~~~
37 |
38 | .. automodule:: ax.metrics.chemistry
39 | :members:
40 | :undoc-members:
41 | :show-inheritance:
42 |
43 | Curve
44 | ~~~~~~~~~
45 |
46 | .. automodule:: ax.metrics.curve
47 | :members:
48 | :undoc-members:
49 | :show-inheritance:
50 |
51 |
52 | Factorial
53 | ~~~~~~~~~
54 |
55 | .. automodule:: ax.metrics.factorial
56 | :members:
57 | :undoc-members:
58 | :show-inheritance:
59 |
60 | Hartmann6
61 | ~~~~~~~~~
62 |
63 | .. automodule:: ax.metrics.hartmann6
64 | :members:
65 | :undoc-members:
66 | :show-inheritance:
67 |
68 | Jenatton
69 | ~~~~~~~~~
70 |
71 | .. automodule:: ax.metrics.jenatton
72 | :members:
73 | :undoc-members:
74 | :show-inheritance:
75 |
76 |
77 | L2 Norm
78 | ~~~~~~~
79 |
80 | .. automodule:: ax.metrics.l2norm
81 | :members:
82 | :undoc-members:
83 | :show-inheritance:
84 |
85 | Noisy Functions
86 | ~~~~~~~~~~~~~~~
87 |
88 | .. automodule:: ax.metrics.noisy_function
89 | :members:
90 | :undoc-members:
91 | :show-inheritance:
92 |
93 | Noisy Function Map
94 | ~~~~~~~~~~~~~~~~~~~~
95 |
96 | .. automodule:: ax.metrics.noisy_function_map
97 | :members:
98 | :undoc-members:
99 | :show-inheritance:
100 |
101 | Sklearn
102 | ~~~~~~~
103 |
104 | .. automodule:: ax.metrics.sklearn
105 | :members:
106 | :undoc-members:
107 | :show-inheritance:
108 |
109 | Tensorboard
110 | ~~~~~~~~~~~
111 |
112 | .. automodule:: ax.metrics.tensorboard
113 | :members:
114 | :undoc-members:
115 | :show-inheritance:
116 |
117 |
118 | TorchX
119 | ~~~~~~~~~~~
120 |
121 | .. automodule:: ax.metrics.torchx
122 | :members:
123 | :undoc-members:
124 | :show-inheritance:
125 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/synthetic/hss/jenatton.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from ax.benchmark.benchmark_problem import SingleObjectiveBenchmarkProblem
7 | from ax.core.objective import Objective
8 | from ax.core.optimization_config import OptimizationConfig
9 | from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
10 | from ax.core.search_space import HierarchicalSearchSpace
11 | from ax.metrics.jenatton import JenattonMetric
12 | from ax.runners.synthetic import SyntheticRunner
13 |
14 |
def get_jenatton_benchmark_problem() -> SingleObjectiveBenchmarkProblem:
    """Build the Jenatton benchmark problem on a hierarchical search space.

    ``x1`` activates either the ``x2``/``r8`` branch (value 0) or the
    ``x3``/``r9`` branch (value 1). ``x2`` then activates ``x4`` or ``x5``, and
    ``x3`` activates ``x6`` or ``x7``. All non-choice parameters are floats in
    [0, 1]. The single ``JenattonMetric`` objective is minimized, and the
    problem's known optimal value is 0.1.
    """

    def unit_float(name: str) -> RangeParameter:
        # Float range parameter on [0.0, 1.0].
        return RangeParameter(
            name=name, parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0
        )

    def binary_switch(name: str, dependents: dict) -> ChoiceParameter:
        # Integer {0, 1} choice whose value selects which parameters are active.
        return ChoiceParameter(
            name=name,
            parameter_type=ParameterType.INT,
            values=[0, 1],
            dependents=dependents,
        )

    parameters = [
        binary_switch("x1", {0: ["x2", "r8"], 1: ["x3", "r9"]}),
        binary_switch("x2", {0: ["x4"], 1: ["x5"]}),
        binary_switch("x3", {0: ["x6"], 1: ["x7"]}),
    ]
    parameters.extend(unit_float(f"x{i}") for i in range(4, 8))
    parameters.append(unit_float("r8"))
    parameters.append(unit_float("r9"))

    return SingleObjectiveBenchmarkProblem(
        name="Jenatton",
        search_space=HierarchicalSearchSpace(parameters=parameters),
        optimization_config=OptimizationConfig(
            objective=Objective(metric=JenattonMetric(), minimize=True)
        ),
        runner=SyntheticRunner(),
        optimal_value=0.1,
    )
65 |
--------------------------------------------------------------------------------
/ax/plot/tests/test_feature_importances.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.modelbridge.base import ModelBridge
8 | from ax.modelbridge.registry import Models
9 | from ax.plot.base import AxPlotConfig
10 | from ax.plot.feature_importances import (
11 | plot_feature_importance_by_feature,
12 | plot_feature_importance_by_feature_plotly,
13 | plot_feature_importance_by_metric,
14 | plot_feature_importance_by_metric_plotly,
15 | plot_relative_feature_importance,
16 | plot_relative_feature_importance_plotly,
17 | )
18 | from ax.utils.common.testutils import TestCase
19 | from ax.utils.testing.core_stubs import get_branin_experiment
20 | from ax.utils.testing.mock import fast_botorch_optimize
21 | from plotly import graph_objects as go
22 |
DUMMY_CAPTION = "test_caption"


def get_modelbridge() -> ModelBridge:
    """Build a ``Models.BOTORCH`` model bridge on a Branin experiment.

    The experiment is created with a batch trial, which is run before its data
    is fetched and handed to the model bridge.
    """
    experiment = get_branin_experiment(with_batch=True)
    experiment.trials[0].run()
    data = experiment.fetch_data()
    # Model bridge kwargs
    return Models.BOTORCH(experiment=experiment, data=data)
34 |
35 |
class FeatureImportancesTest(TestCase):
    """Checks that every feature-importance plot can be constructed."""

    @fast_botorch_optimize
    def testFeatureImportances(self):
        model = get_modelbridge()

        # By-feature plotly figure, without and then with a caption; the
        # caption should show up as the figure's only annotation.
        figure = plot_feature_importance_by_feature_plotly(model=model)
        self.assertIsInstance(figure, go.Figure)
        captioned = plot_feature_importance_by_feature_plotly(
            model=model, caption=DUMMY_CAPTION
        )
        self.assertIsInstance(captioned, go.Figure)
        annotations = captioned.layout.annotations
        self.assertEqual(len(annotations), 1)
        self.assertEqual(annotations[0].text, DUMMY_CAPTION)

        # Remaining builders: the *_plotly variants return a plotly Figure,
        # the others an AxPlotConfig. Called in the original order.
        expectations = (
            (plot_feature_importance_by_feature, AxPlotConfig),
            (plot_feature_importance_by_metric_plotly, go.Figure),
            (plot_feature_importance_by_metric, AxPlotConfig),
            (plot_relative_feature_importance_plotly, go.Figure),
            (plot_relative_feature_importance, AxPlotConfig),
        )
        for build_plot, expected_type in expectations:
            self.assertIsInstance(build_plot(model=model), expected_type)
59 |
--------------------------------------------------------------------------------
/scripts/insert_api_refs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import ast
9 | import glob
10 | import re
11 |
12 |
13 | def list_functions(source_glob):
14 | """
15 | List all of the functions and classes defined
16 | """
17 | defined = []
18 | # Iterate through each source file
19 | for sp in glob.glob(source_glob):
20 | module_name = sp[:-3]
21 | module_name = module_name.replace("/", ".")
22 | # Parse the source file into an AST
23 | node = ast.parse(open(sp).read())
24 | # Extract the names of all functions and classes defined in this file
25 | defined.extend(
26 | (n.name, module_name + "." + n.name)
27 | for n in node.body
28 | if (isinstance(n, ast.FunctionDef) or isinstance(n, ast.ClassDef))
29 | )
30 | return defined
31 |
32 |
33 | def replace_backticks(source_path, docs_path):
34 | markdown_glob = docs_path + "/*.md"
35 | source_glob = source_path + "/**/*.py"
36 | methods = list_functions(source_glob)
37 | for f in glob.glob(markdown_glob):
38 | for n, m in methods:
39 | # Match backquoted mentions of the function/class name which are
40 | # not already links
41 | pattern = "(? AxPlotConfig:
19 | """
20 | Calculates and plots the marginal effects -- the effect of changing one
21 | factor away from the randomized distribution of the experiment and fixing it
22 | at a particular level.
23 |
24 | Args:
25 | model: Model to use for estimating effects
26 | metric: The metric for which to plot marginal effects.
27 |
28 | Returns:
29 | AxPlotConfig of the marginal effects
30 | """
31 | plot_data, _, _ = get_plot_data(model, {}, {metric})
32 |
33 | arm_dfs = []
34 | for arm in plot_data.in_sample.values():
35 | arm_df = pd.DataFrame(arm.parameters, index=[arm.name])
36 | arm_df["mean"] = arm.y_hat[metric]
37 | arm_df["sem"] = arm.se_hat[metric]
38 | arm_dfs.append(arm_df)
39 | effect_table = marginal_effects(pd.concat(arm_dfs, 0))
40 |
41 | varnames = effect_table["Name"].unique()
42 | data: List[Any] = []
43 | for varname in varnames:
44 | var_df = effect_table[effect_table["Name"] == varname]
45 | data += [
46 | go.Bar(
47 | x=var_df["Level"],
48 | y=var_df["Beta"],
49 | error_y={"type": "data", "array": var_df["SE"]},
50 | name=varname,
51 | )
52 | ]
53 | fig = subplots.make_subplots(
54 | cols=len(varnames),
55 | rows=1,
56 | subplot_titles=list(varnames),
57 | print_grid=False,
58 | shared_yaxes=True,
59 | )
60 | for idx, item in enumerate(data):
61 | fig.append_trace(item, 1, idx + 1)
62 | fig.layout.showlegend = False
63 | # fig.layout.margin = go.layout.Margin(l=2, r=2)
64 | fig.layout.title = "Marginal Effects by Factor"
65 | fig.layout.yaxis = {
66 | "title": "% better than experiment average",
67 | "hoverformat": ".{}f".format(DECIMALS),
68 | }
69 | return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
70 |
--------------------------------------------------------------------------------
/ax/modelbridge/tests/test_cap_parameter_transform.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
8 | from ax.core.search_space import SearchSpace
9 | from ax.exceptions.core import UnsupportedError
10 | from ax.modelbridge.transforms.cap_parameter import CapParameter
11 | from ax.utils.common.testutils import TestCase
12 | from ax.utils.testing.core_stubs import get_robust_search_space
13 |
14 |
class CapParameterTest(TestCase):
    """Tests for the ``CapParameter`` search-space transform."""

    def setUp(self):
        # A float range parameter "a" on [1, 3] and a string choice parameter "b".
        range_param = RangeParameter(
            "a", lower=1, upper=3, parameter_type=ParameterType.FLOAT
        )
        choice_param = ChoiceParameter(
            "b", parameter_type=ParameterType.STRING, values=["a", "b", "c"]
        )
        self.search_space = SearchSpace(parameters=[range_param, choice_param])

    @staticmethod
    def _cap(search_space, config):
        """Build a ``CapParameter`` transform with no observations."""
        return CapParameter(
            search_space=search_space,
            observation_features=[],
            observation_data=[],
            config=config,
        )

    def test_transform_search_space(self):
        # Capping "a" at 2 lowers its upper bound on the passed-in search space.
        self._cap(self.search_space, {"a": "2"}).transform_search_space(
            self.search_space
        )
        self.assertEqual(self.search_space.parameters.get("a").upper, 2)
        # Capping the choice parameter "b" is not implemented.
        choice_cap = self._cap(self.search_space, {"b": "2"})
        with self.assertRaises(NotImplementedError):
            choice_cap.transform_search_space(self.search_space)

    def test_w_parameter_distributions(self):
        rss = get_robust_search_space()
        # Transform a non-distributional parameter.
        self._cap(rss, {"z": "2"}).transform_search_space(rss)
        self.assertEqual(rss.parameters.get("z").upper, 2)
        # Error with distributional parameter.
        distributional_cap = self._cap(rss, {"x": "2"})
        with self.assertRaisesRegex(UnsupportedError, "transform is not supported"):
            distributional_cap.transform_search_space(rss)
66 |
--------------------------------------------------------------------------------
/ax/core/tests/test_arm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ax.core.arm import Arm
8 | from ax.utils.common.testutils import TestCase
9 |
10 |
class ArmTest(TestCase):
    """Tests for ``Arm`` construction, naming, equality, cloning and ordering."""

    def setUp(self):
        # No shared fixtures; each test constructs its own arms.
        pass

    @staticmethod
    def _params():
        # A fresh dict on every call, with keys inserted in y, x, z order.
        return {"y": 0.25, "x": 0.75, "z": 75}

    def testInit(self):
        unnamed = Arm(parameters=self._params())
        self.assertEqual(
            str(unnamed), "Arm(parameters={'y': 0.25, 'x': 0.75, 'z': 75})"
        )

        named = Arm(parameters=self._params(), name="status_quo")
        self.assertEqual(
            str(named),
            "Arm(name='status_quo', parameters={'y': 0.25, 'x': 0.75, 'z': 75})",
        )

    def testNameValidation(self):
        arm = Arm(parameters=self._params())
        self.assertFalse(arm.has_name)
        # Reading the name of an unnamed arm raises.
        self.assertRaises(ValueError, getattr, arm, "name")
        arm.name = "0_0"
        # Renaming an already-named arm raises.
        self.assertRaises(ValueError, setattr, arm, "name", "1_0")

    def testNameOrShortSignature(self):
        named = Arm(parameters=self._params(), name="0_0")
        self.assertEqual(named.name_or_short_signature, "0_0")

        # Without a name, the last four characters of the signature are used.
        unnamed = Arm(parameters=self._params())
        self.assertEqual(unnamed.name_or_short_signature, unnamed.signature[-4:])

    def testEq(self):
        # Parameter insertion order does not affect equality ...
        reference = Arm(parameters=self._params())
        self.assertEqual(reference, Arm(parameters={"z": 75, "x": 0.75, "y": 0.25}))
        # ... but parameter values do.
        self.assertNotEqual(
            reference, Arm(parameters={"z": 5, "x": 0.75, "y": 0.25})
        )

        # Same parameters and same name: equal; different name: not equal.
        first = Arm(name="0_0", parameters=self._params())
        self.assertEqual(first, Arm(name="0_0", parameters=self._params()))
        self.assertNotEqual(first, Arm(name="0_1", parameters=self._params()))

    def testClone(self):
        original = Arm(parameters=self._params())
        cloned = original.clone()
        self.assertIsNot(original, cloned)
        self.assertEqual(original, cloned)
        # The parameters dict is copied rather than shared.
        self.assertIsNot(original.parameters, cloned.parameters)

    def testSortable(self):
        self.assertLess(
            Arm(parameters=self._params()), Arm(parameters={"z": 0, "x": 0, "y": 0})
        )
67 |
--------------------------------------------------------------------------------