├── ax
├── metrics
│ ├── curve.py
│ ├── tests
│ │ ├── test_curve.py
│ │ └── __init__.py
│ ├── chemistry_data.zip
│ ├── l2norm.py
│ ├── __init__.py
│ ├── hartmann6.py
│ └── branin.py
├── utils
│ ├── testing
│ │ ├── manifest.py
│ │ ├── __init__.py
│ │ ├── metrics
│ │ │ ├── __init__.py
│ │ │ └── backend_simulator_map.py
│ │ ├── tests
│ │ │ └── __init__.py
│ │ ├── utils_testing_stubs.py
│ │ ├── test_init_files.py
│ │ └── torch_stubs.py
│ ├── __init__.py
│ ├── stats
│ │ ├── __init__.py
│ │ └── tests
│ │ │ └── __init__.py
│ ├── common
│ │ ├── __init__.py
│ │ ├── tests
│ │ │ ├── __init__.py
│ │ │ ├── test_serialization.py
│ │ │ ├── test_func_enum.py
│ │ │ ├── test_docutils.py
│ │ │ ├── test_kwargutils.py
│ │ │ └── test_typeutils.py
│ │ ├── typeutils_nonnative.py
│ │ ├── mock.py
│ │ ├── typeutils_torch.py
│ │ ├── deprecation.py
│ │ ├── random.py
│ │ ├── func_enum.py
│ │ ├── timeutils.py
│ │ ├── docutils.py
│ │ └── base.py
│ ├── notebook
│ │ ├── __init__.py
│ │ └── plotting.py
│ ├── report
│ │ ├── __init__.py
│ │ ├── tests
│ │ │ ├── __init__.py
│ │ │ └── test_render.py
│ │ └── resources
│ │ │ ├── __init__.py
│ │ │ ├── simple_template.html
│ │ │ ├── sufficient_statistic.html
│ │ │ └── base_template.html
│ ├── tutorials
│ │ └── __init__.py
│ ├── flake8_plugins
│ │ └── __init__.py
│ ├── measurement
│ │ ├── __init__.py
│ │ └── tests
│ │ │ └── __init__.py
│ └── sensitivity
│ │ ├── __init__.py
│ │ └── tests
│ │ └── __init__.py
├── benchmark
│ ├── problems
│ │ ├── synthetic
│ │ │ ├── __init__.py
│ │ │ ├── hss
│ │ │ │ └── __init__.py
│ │ │ └── discretized
│ │ │ │ └── __init__.py
│ │ ├── __init__.py
│ │ ├── hpo
│ │ │ └── __init__.py
│ │ ├── runtime_funcs.py
│ │ └── hd_embedding.py
│ ├── __init__.py
│ ├── methods
│ │ ├── __init__.py
│ │ └── sobol.py
│ ├── benchmark_test_functions
│ │ └── __init__.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── methods
│ │ │ └── __init__.py
│ │ └── problems
│ │ │ └── __init__.py
│ ├── benchmark_step_runtime_function.py
│ ├── benchmark_trial_metadata.py
│ └── benchmark_test_function.py
├── plot
│ ├── __init__.py
│ ├── css
│ │ ├── __init__.py
│ │ └── base.css
│ ├── js
│ │ ├── __init__.py
│ │ ├── common
│ │ │ ├── __init__.py
│ │ │ ├── css.js
│ │ │ ├── plotly_requires.js
│ │ │ ├── plotly_offline.js
│ │ │ └── plotly_online.js
│ │ └── generic_plotly.js
│ └── tests
│ │ ├── __init__.py
│ │ ├── test_parallel_coordinates.py
│ │ ├── test_helper.py
│ │ ├── test_slices.py
│ │ └── test_diagnostic.py
├── core
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_risk_measures.py
│ │ ├── test_auxiliary.py
│ │ └── test_map_metric.py
│ ├── map_metric.py
│ ├── auxiliary.py
│ └── __init__.py
├── exceptions
│ ├── __init__.py
│ ├── model.py
│ ├── constants.py
│ ├── storage.py
│ └── data_provider.py
├── preview
│ ├── __init__.py
│ ├── api
│ │ ├── utils
│ │ │ ├── __init__.py
│ │ │ ├── instantiation
│ │ │ │ └── __init__.py
│ │ │ └── storage.py
│ │ ├── protocols
│ │ │ ├── __init__.py
│ │ │ ├── metric.py
│ │ │ └── runner.py
│ │ ├── types.py
│ │ └── __init__.py
│ └── modelbridge
│ │ └── __init__.py
├── analysis
│ ├── tests
│ │ ├── __init__.py
│ │ └── test_utils.py
│ ├── plotly
│ │ ├── tests
│ │ │ └── __init__.py
│ │ ├── surface
│ │ │ ├── __init__.py
│ │ │ └── utils.py
│ │ ├── arm_effects
│ │ │ └── __init__.py
│ │ ├── __init__.py
│ │ └── plotly_analysis.py
│ ├── markdown
│ │ └── __init__.py
│ ├── __init__.py
│ └── healthcheck
│ │ ├── __init__.py
│ │ ├── healthcheck_analysis.py
│ │ ├── tests
│ │ └── test_should_generate_candidates.py
│ │ └── should_generate_candidates.py
├── models
│ ├── discrete
│ │ ├── __init__.py
│ │ └── eb_thompson.py
│ ├── random
│ │ ├── __init__.py
│ │ └── uniform.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_base.py
│ │ ├── test_discrete.py
│ │ └── test_randomforest.py
│ ├── torch
│ │ ├── __init__.py
│ │ ├── tests
│ │ │ └── __init__.py
│ │ └── botorch_modular
│ │ │ ├── __init__.py
│ │ │ └── input_constructors
│ │ │ └── __init__.py
│ ├── __init__.py
│ ├── types.py
│ └── winsorization_config.py
├── runners
│ ├── tests
│ │ ├── __init__.py
│ │ └── test_single_running_trial_mixin.py
│ ├── __init__.py
│ └── synthetic.py
├── service
│ ├── utils
│ │ ├── __init__.py
│ │ ├── best_point_utils.py
│ │ └── early_stopping.py
│ ├── tests
│ │ ├── __init__.py
│ │ └── test_early_stopping.py
│ └── __init__.py
├── modelbridge
│ ├── tests
│ │ └── __init__.py
│ ├── transforms
│ │ ├── __init__.py
│ │ ├── tests
│ │ │ └── test_rounding_transform.py
│ │ ├── search_space_to_float.py
│ │ └── deprecated_transform_mixin.py
│ └── __init__.py
├── storage
│ ├── json_store
│ │ ├── tests
│ │ │ └── __init__.py
│ │ ├── __init__.py
│ │ ├── load.py
│ │ └── save.py
│ ├── sqa_store
│ │ ├── tests
│ │ │ └── __init__.py
│ │ ├── __init__.py
│ │ ├── structs.py
│ │ ├── timestamp.py
│ │ └── reduced_state.py
│ ├── __init__.py
│ └── tests
│ │ ├── test_registry_bundle.py
│ │ └── test_botorch_modular_registry.py
├── early_stopping
│ ├── tests
│ │ └── __init__.py
│ ├── __init__.py
│ └── strategies
│ │ └── __init__.py
├── global_stopping
│ ├── tests
│ │ └── __init__.py
│ ├── __init__.py
│ └── strategies
│ │ └── __init__.py
└── __init__.py
├── website
├── static
│ ├── .nojekyll
│ ├── CNAME
│ ├── img
│ │ ├── favicon.png
│ │ ├── oss_logo.png
│ │ ├── favicon
│ │ │ └── favicon.ico
│ │ ├── ax_wireframe.svg
│ │ ├── database-solid.svg
│ │ ├── th-large-solid.svg
│ │ ├── dice-solid.svg
│ │ ├── ax.svg
│ │ ├── ax_lockup.svg
│ │ ├── ax_logo_lockup.svg
│ │ └── ax_lockup_white.svg
│ └── js
│ │ ├── mathjax.js
│ │ └── plotUtils.js
├── versioned_docs
│ └── .gitkeep
├── versioned_sidebars
│ └── .gitkeep
├── sidebars.json
├── package.json
├── core
│ └── Footer.js
└── tutorials.json
├── .github
├── ISSUE_TEMPLATE
│ ├── config.yaml
│ ├── 3_general_support.yaml
│ ├── 2_feature_request.yaml
│ └── 1_bug_report.yaml
└── workflows
│ ├── tutorials.yml
│ ├── lint.yml
│ ├── cron_pinned.yml
│ ├── build-and-test.yml
│ ├── reusable_tutorials.yml
│ └── deploy.yml
├── docs
├── assets
│ ├── gp_opt.png
│ ├── bo_1d_opt.gif
│ ├── mab_probs.png
│ ├── mab_regret.png
│ ├── gp_posterior.png
│ ├── mab_animate.gif
│ ├── bandit_allocation.png
│ └── example_shrinkage.png
├── algo-overview.md
└── why-ax.md
├── CHANGELOG.md
├── requirements-fmt.txt
├── sphinx
└── source
│ ├── ax.rst
│ ├── health_check.rst
│ ├── global_stopping.rst
│ ├── index.rst
│ ├── runners.rst
│ ├── exceptions.rst
│ ├── early_stopping.rst
│ ├── metrics.rst
│ ├── preview.rst
│ └── service.rst
├── pyproject.toml
├── pytest.ini
├── scripts
├── import_ax.py
└── patch_site_config.py
├── .flake8
├── .pre-commit-config.yaml
└── LICENSE
/ax/metrics/curve.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ax/utils/testing/manifest.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/website/static/.nojekyll:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ax/metrics/tests/test_curve.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/website/static/CNAME:
--------------------------------------------------------------------------------
1 | ax.dev
2 |
--------------------------------------------------------------------------------
/website/versioned_docs/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/website/versioned_sidebars/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yaml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 |
--------------------------------------------------------------------------------
/docs/assets/gp_opt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/gp_opt.png
--------------------------------------------------------------------------------
/docs/assets/bo_1d_opt.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/bo_1d_opt.gif
--------------------------------------------------------------------------------
/docs/assets/mab_probs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/mab_probs.png
--------------------------------------------------------------------------------
/docs/assets/mab_regret.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/mab_regret.png
--------------------------------------------------------------------------------
/ax/metrics/chemistry_data.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/ax/metrics/chemistry_data.zip
--------------------------------------------------------------------------------
/docs/assets/gp_posterior.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/gp_posterior.png
--------------------------------------------------------------------------------
/docs/assets/mab_animate.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/mab_animate.gif
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | Ax uses GitHub tags for managing releases. See changelog [here](https://github.com/facebook/Ax/releases).
2 |
--------------------------------------------------------------------------------
/website/static/img/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/website/static/img/favicon.png
--------------------------------------------------------------------------------
/website/static/img/oss_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/website/static/img/oss_logo.png
--------------------------------------------------------------------------------
/docs/assets/bandit_allocation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/bandit_allocation.png
--------------------------------------------------------------------------------
/docs/assets/example_shrinkage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/docs/assets/example_shrinkage.png
--------------------------------------------------------------------------------
/website/static/img/favicon/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahandmohammadrehzaii/Ax/HEAD/website/static/img/favicon/favicon.ico
--------------------------------------------------------------------------------
/requirements-fmt.txt:
--------------------------------------------------------------------------------
1 | # generated by `pyfmt --requirements`
2 | black==24.4.2
3 | ruff-api==0.1.0
4 | stdlibs==2024.1.28
5 | ufmt==2.8.0
6 | usort==1.0.8.post1
7 |
--------------------------------------------------------------------------------
/docs/algo-overview.md:
--------------------------------------------------------------------------------
1 | ---
2 | id: algo-overview
3 | title: Overview
4 | ---
5 |
6 | Ax supports:
7 | * Bandit optimization
8 | * Empirical Bayes with Thompson sampling
9 | * Bayesian optimization
10 |
--------------------------------------------------------------------------------
/sphinx/source/ax.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax
5 | ===================================
6 |
7 | .. automodule:: ax
8 | :members:
9 | :noindex:
10 |
11 | .. currentmodule:: ax
12 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/synthetic/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/ax/plot/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/benchmark/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/core/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/exceptions/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/plot/css/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/plot/js/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/preview/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/stats/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/analysis/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/benchmark/methods/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
--------------------------------------------------------------------------------
/ax/metrics/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/models/discrete/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/models/random/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/models/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/models/torch/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/plot/js/common/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/runners/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/service/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/common/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/notebook/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/report/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/testing/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/tutorials/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/analysis/plotly/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/hpo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
--------------------------------------------------------------------------------
/ax/modelbridge/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/models/torch/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/preview/api/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/preview/modelbridge/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/common/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/flake8_plugins/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/measurement/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/report/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/sensitivity/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/testing/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/testing/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=34.4", "wheel", "setuptools_scm"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [tool.usort]
6 | first_party_detection = false
7 |
8 | [tool.ufmt]
9 | formatter = "ruff-api"
10 |
--------------------------------------------------------------------------------
/ax/storage/json_store/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/storage/sqa_store/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/measurement/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/report/resources/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/utils/sensitivity/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/benchmark/benchmark_test_functions/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/synthetic/hss/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
--------------------------------------------------------------------------------
/ax/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
--------------------------------------------------------------------------------
/ax/preview/api/utils/instantiation/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/synthetic/discretized/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
--------------------------------------------------------------------------------
/ax/plot/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
--------------------------------------------------------------------------------
/ax/service/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
--------------------------------------------------------------------------------
/ax/benchmark/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
--------------------------------------------------------------------------------
/ax/utils/stats/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
--------------------------------------------------------------------------------
/ax/benchmark/tests/methods/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
--------------------------------------------------------------------------------
/ax/early_stopping/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
--------------------------------------------------------------------------------
/ax/global_stopping/tests/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
--------------------------------------------------------------------------------
/ax/modelbridge/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
--------------------------------------------------------------------------------
/ax/benchmark/tests/problems/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
--------------------------------------------------------------------------------
/ax/models/torch/botorch_modular/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
--------------------------------------------------------------------------------
/website/sidebars.json:
--------------------------------------------------------------------------------
1 | {
2 | "docs": {
3 | "Introduction": ["why-ax"],
4 | "Getting Started": ["installation", "api", "glossary"],
5 | "Algorithms": ["bayesopt", "banditopt"],
6 | "Components": ["core", "trial-evaluation", "data", "models", "storage"]
7 | }
8 | }
9 |
--------------------------------------------------------------------------------
/ax/models/torch/botorch_modular/input_constructors/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
--------------------------------------------------------------------------------
/ax/utils/report/resources/simple_template.html:
--------------------------------------------------------------------------------
1 |
4 | {% extends "base_template.html" %}
5 | {% block content %}
6 | {% for element in html_elements %}
7 | {{element}}
8 | {% endfor %}
9 | {% endblock %}
10 |
--------------------------------------------------------------------------------
/ax/plot/js/generic_plotly.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | Plotly.newPlot(
9 | {{id}},
10 | {{data}},
11 | {{layout}},
12 | {"showLink": false}
13 | );
14 |
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | # Configuration for pytest.
[pytest]
filterwarnings =
    # Filter out paramz and sklearn deprecation warnings (matched by module name).
    ignore::DeprecationWarning:.*paramz.*
    ignore::DeprecationWarning:.*sklearn.*
    # Filter out numpy non-integer indices warning (matched by message text;
    # filter syntax is action:message:category:module).
    ignore:.*using a non-integer array as obj in delete.*:DeprecationWarning
9 |
--------------------------------------------------------------------------------
/ax/early_stopping/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | from ax.early_stopping import strategies
11 |
12 | __all__ = ["strategies"]
13 |
--------------------------------------------------------------------------------
/ax/global_stopping/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | from ax.global_stopping import strategies
11 |
12 | __all__ = ["strategies"]
13 |
--------------------------------------------------------------------------------
/website/static/img/ax_wireframe.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ax/service/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.service.managed_loop import OptimizationLoop, optimize
10 |
11 |
12 | __all__ = ["OptimizationLoop", "optimize"]
13 |
--------------------------------------------------------------------------------
/ax/storage/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.storage.json_store import load as json_load, save as json_save
10 |
11 |
12 | __all__ = ["json_save", "json_load"]
13 |
--------------------------------------------------------------------------------
/ax/plot/js/common/css.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | var css = document.createElement('style');
9 | css.type = 'text/css';
10 | css.innerHTML = "{{css}}";
11 | document.getElementsByTagName("head")[0].appendChild(css);
12 |
--------------------------------------------------------------------------------
/scripts/import_ax.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import ax
8 | from ax.service.ax_client import AxClient
9 |
10 |
if __name__ == "__main__":
    # Smoke check: both the top-level package and the service client resolved.
    for imported in (ax, AxClient):
        assert imported is not None
14 |
--------------------------------------------------------------------------------
/ax/analysis/plotly/surface/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | from ax.analysis.plotly.surface.contour import ContourPlot
9 | from ax.analysis.plotly.surface.slice import SlicePlot
10 |
11 | __all__ = ["ContourPlot", "SlicePlot"]
12 |
--------------------------------------------------------------------------------
/ax/plot/js/common/plotly_requires.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | require(['plotly'], function(Plotly) {
9 | window.PLOTLYENV = window.PLOTLYENV || {};
10 | window.PLOTLYENV.BASE_URL = 'https://plot.ly';
11 | {{script}}
12 | });
13 |
--------------------------------------------------------------------------------
/ax/analysis/markdown/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | from ax.analysis.markdown.markdown_analysis import (
9 | MarkdownAnalysis,
10 | MarkdownAnalysisCard,
11 | )
12 |
13 | __all__ = ["MarkdownAnalysis", "MarkdownAnalysisCard"]
14 |
--------------------------------------------------------------------------------
/sphinx/source/health_check.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax.health_check
5 | ===============
6 |
7 | .. automodule:: ax.health_check
8 | .. currentmodule:: ax.health_check
9 |
10 | Ax Experiment Health Checks
11 | ---------------------------
12 |
13 | Search Space
14 | ~~~~~~~~~~~~
15 |
16 | .. automodule:: ax.health_check.search_space
17 | :members:
18 | :undoc-members:
19 | :show-inheritance:
20 |
--------------------------------------------------------------------------------
/website/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "scripts": {
3 | "examples": "docusaurus-examples",
4 | "start": "docusaurus-start",
5 | "build": "docusaurus-build",
6 | "publish-gh-pages": "docusaurus-publish",
7 | "write-translations": "docusaurus-write-translations",
8 | "version": "docusaurus-version",
9 | "rename-version": "docusaurus-rename-version"
10 | },
11 | "devDependencies": {
12 | "docusaurus": "^1.7.2"
13 | }
14 | }
15 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 |
3 | # E704: the linter doesn't parse types properly
4 | # T499, T484: silence mypy since using pyre for typechecking
5 | # W503: black and flake8 disagree on how to place operators
6 | # E231: black and flake8 disagree on whitespace after ','
7 | # E203: black and flake8 disagree on whitespace before ':'
8 | ignore = T484, T499, W503, E704, E231, E203
9 |
10 | # Black really wants lines to be 88 chars...
11 | max-line-length = 88
12 |
--------------------------------------------------------------------------------
/ax/plot/js/common/plotly_offline.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | if (!window.Plotly) {
9 | define('plotly', function(require, exports, module) {
10 | {{library}}
11 | });
12 | require(['plotly'], function(Plotly) {
13 | window.Plotly = Plotly;
14 | });
15 | }
16 |
--------------------------------------------------------------------------------
/ax/preview/api/protocols/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.preview.api.protocols.metric import IMetric
10 | from ax.preview.api.protocols.runner import IRunner
11 |
12 | __all__ = [
13 | "IMetric",
14 | "IRunner",
15 | ]
16 |
--------------------------------------------------------------------------------
/ax/storage/json_store/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.storage.json_store.load import load_experiment as json_load
10 | from ax.storage.json_store.save import save_experiment as json_save
11 |
12 |
13 | __all__ = ["json_load", "json_save"]
14 |
--------------------------------------------------------------------------------
/ax/preview/api/types.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | from typing import Mapping
9 |
# A single parameter value: one of int, float, str, or bool.
TParameterValue = int | float | str | bool
# A full parameterization; presumably keyed by parameter name (verify against
# callers).
TParameterization = Mapping[str, TParameterValue]

# Metric name => mean | (mean, sem)
# i.e. each metric maps either to a bare mean or to a (mean, sem) tuple, where
# sem is presumably the standard error of the mean.
TOutcome = Mapping[str, float | tuple[float, float]]
15 |
--------------------------------------------------------------------------------
/ax/runners/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # flake8: noqa F401
10 | from ax.runners.simulated_backend import SimulatedBackendRunner
11 | from ax.runners.synthetic import SyntheticRunner
12 |
13 |
14 | __all__ = ["SimulatedBackendRunner", "SyntheticRunner"]
15 |
--------------------------------------------------------------------------------
/ax/plot/js/common/plotly_online.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | requirejs.config({
9 | paths: {
10 | plotly: ['https://cdn.plot.ly/plotly-latest.min'],
11 | },
12 | });
13 | if (!window.Plotly) {
14 | require(['plotly'], function(plotly) {
15 | window.Plotly = plotly;
16 | });
17 | }
18 |
--------------------------------------------------------------------------------
/ax/analysis/plotly/arm_effects/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-unsafe
7 |
8 | from ax.analysis.plotly.arm_effects.insample_effects import InSampleEffectsPlot
9 | from ax.analysis.plotly.arm_effects.predicted_effects import PredictedEffectsPlot
10 |
11 | __all__ = ["PredictedEffectsPlot", "InSampleEffectsPlot"]
12 |
--------------------------------------------------------------------------------
/.github/workflows/tutorials.yml:
--------------------------------------------------------------------------------
1 | name: Tutorials
2 |
3 | on:
4 | workflow_dispatch:
5 | push:
6 | branches: [ main ]
7 | paths:
8 | - "tutorials/**"
9 | pull_request:
10 | branches: [ main ]
11 | paths:
12 | - "tutorials/**"
13 |
14 |
15 | jobs:
16 |
17 | build-tutorials-with-latest-botorch:
18 | name: Tutorials with latest BoTorch
19 | uses: ./.github/workflows/reusable_tutorials.yml
20 | with:
21 | smoke_test: true
22 | pinned_botorch: false
23 |
--------------------------------------------------------------------------------
/ax/metrics/l2norm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import numpy as np
10 | import numpy.typing as npt
11 | from ax.metrics.noisy_function import NoisyFunctionMetric
12 |
13 |
class L2NormMetric(NoisyFunctionMetric):
    """Metric whose ``f`` is the Euclidean (L2) norm of its input array."""

    def f(self, x: npt.NDArray) -> float:
        """Return the square root of the sum of the squared entries of ``x``."""
        squares = x**2
        total = squares.sum()
        return np.sqrt(total)
17 |
--------------------------------------------------------------------------------
/ax/global_stopping/strategies/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy
10 | from ax.global_stopping.strategies.improvement import ImprovementGlobalStoppingStrategy
11 |
12 |
13 | __all__ = [
14 | "BaseGlobalStoppingStrategy",
15 | "ImprovementGlobalStoppingStrategy",
16 | ]
17 |
--------------------------------------------------------------------------------
/ax/analysis/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | from ax.analysis.analysis import (
9 | Analysis,
10 | AnalysisCard,
11 | AnalysisCardLevel,
12 | display_cards,
13 | )
14 | from ax.analysis.summary import Summary
15 | from ax.analysis.markdown import * # noqa
16 | from ax.analysis.plotly import * # noqa
17 |
18 | __all__ = ["Analysis", "AnalysisCard", "AnalysisCardLevel", "display_cards", "Summary"]
19 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: Lint
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 | workflow_dispatch:
9 |
10 |
11 | jobs:
12 |
13 | lint:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - uses: actions/checkout@v4
17 |
18 | - name: Set up Python
19 | uses: actions/setup-python@v5
20 | with:
21 | python-version: "3.10"
22 |
23 | - name: Install dependencies
24 | run: pip install pre-commit
25 |
26 | - name: Run pre-commit
27 | run: pre-commit run --all-files --show-diff-on-failure
28 |
--------------------------------------------------------------------------------
/ax/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # flake8: noqa F401
10 | from ax.metrics.branin import BraninMetric
11 | from ax.metrics.chemistry import ChemistryMetric
12 | from ax.metrics.factorial import FactorialMetric
13 | from ax.metrics.sklearn import SklearnMetric
14 |
15 | __all__ = [
16 | "BraninMetric",
17 | "ChemistryMetric",
18 | "FactorialMetric",
19 | "SklearnMetric",
20 | ]
21 |
--------------------------------------------------------------------------------
/ax/storage/sqa_store/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # Importing the validation module is necessary so SQLAlchemy knows about its event listeners;
10 | # see https://fburl.com/8mn7yjt2
11 | from ax.storage.sqa_store import validation
12 | from ax.storage.sqa_store.load import load_experiment as sqa_load
13 | from ax.storage.sqa_store.save import save_experiment as sqa_save
14 |
15 |
16 | __all__ = ["sqa_load", "sqa_save"]
17 |
18 | del validation
19 |
--------------------------------------------------------------------------------
/sphinx/source/global_stopping.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax.global_stopping
5 | ==================
6 |
7 | .. automodule:: ax.global_stopping
8 | .. currentmodule:: ax.global_stopping
9 |
10 | Strategies
11 | ----------
12 |
13 | Base Strategies
14 | ~~~~~~~~~~~~~~~
15 |
16 | .. automodule:: ax.global_stopping.strategies.base
17 | :members:
18 | :undoc-members:
19 | :show-inheritance:
20 |
21 | `ImprovementGlobalStoppingStrategy`
22 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
23 |
24 | .. automodule:: ax.global_stopping.strategies.improvement
25 | :members:
26 | :undoc-members:
27 | :show-inheritance:
28 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/runtime_funcs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | from collections.abc import Mapping
9 |
10 | from ax.core.arm import Arm
11 | from ax.core.types import TParamValue
12 |
13 |
def int_from_params(
    params: Mapping[str, TParamValue], n_possibilities: int = 10
) -> int:
    """
    Get an int between 0 and n_possibilities - 1, using a hash of the parameters.

    The value is taken from the trailing hex digits of the arm's ``md5hash``.
    Only as many trailing digits as are needed to cover ``n_possibilities``
    are used. For ``n_possibilities <= 16`` the result therefore depends on
    the last hex digit alone, exactly as before. Larger values of
    ``n_possibilities`` can still yield every integer in range; a single
    digit would cap the result at 15.

    Args:
        params: Parameterization to hash.
        n_possibilities: Number of distinct values to map into; must be >= 1.

    Returns:
        An int in ``[0, n_possibilities - 1]``.

    Raises:
        ValueError: If ``n_possibilities`` is less than 1.
    """
    if n_possibilities < 1:
        raise ValueError(
            f"n_possibilities must be at least 1, got {n_possibilities}."
        )
    arm_hash = Arm.md5hash(parameters=params)
    # Smallest number of hex digits whose value range (16**n_digits) covers
    # n_possibilities; equals 1 for n_possibilities <= 16.
    n_digits = 1
    while 16**n_digits < n_possibilities:
        n_digits += 1
    return int(arm_hash[-n_digits:], base=16) % n_possibilities
22 |
--------------------------------------------------------------------------------
/ax/benchmark/benchmark_step_runtime_function.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | from collections.abc import Mapping
9 | from typing import Protocol, runtime_checkable
10 |
11 | from ax.core.types import TParamValue
12 |
13 |
@runtime_checkable
class TBenchmarkStepRuntimeFunction(Protocol):
    """Structural type for callables mapping an arm's parameters to a step runtime.

    Because the protocol is ``runtime_checkable``, ``isinstance`` checks against
    it only verify that a ``__call__`` attribute exists, not its signature.
    """

    def __call__(self, params: Mapping[str, TParamValue]) -> float:
        """
        Compute how long one step takes for the arm with the given ``params``.

        Every step within that arm uses this same runtime.

        Args:
            params: The arm's parameterization.

        Returns:
            The runtime of each step.
        """
        ...
23 |
--------------------------------------------------------------------------------
/ax/utils/testing/utils_testing_stubs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | from ax.utils.testing.backend_simulator import BackendSimulator, BackendSimulatorOptions
11 |
12 |
def get_backend_simulator_with_trials() -> BackendSimulator:
    """Build a ``BackendSimulator`` that already has three trials submitted.

    The simulator's clock starts at 0.0. It uses update time as start time
    and allows at most two concurrent trials. ``run_trial`` is then called
    with ``(0, 2)``, ``(1, 1)`` and ``(2, 10)`` in that order; these are
    presumably ``(trial_index, runtime)``, so confirm against
    ``BackendSimulator``.
    """
    sim = BackendSimulator(
        options=BackendSimulatorOptions(
            internal_clock=0.0,
            use_update_as_start_time=True,
            max_concurrency=2,
        )
    )
    for run_args in ((0, 2), (1, 1), (2, 10)):
        sim.run_trial(*run_args)
    return sim
22 |
--------------------------------------------------------------------------------
/sphinx/source/index.rst:
--------------------------------------------------------------------------------
1 | .. Ax documentation index file, created by
2 | sphinx-quickstart on Sat Mar 2 00:03:32 2019.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | API Reference
7 | =============
8 |
9 | .. toctree::
10 | :maxdepth: 2
11 |
12 | analysis
13 | ax
14 | benchmark
15 | core
16 | early_stopping
17 | exceptions
18 | global_stopping
19 | health_check
20 | metrics
21 | modelbridge
22 | models
23 | plot
24 | preview
25 | runners
26 | service
27 | storage
28 | telemetry
29 | utils
30 |
31 |
32 | Indices and tables
33 | ==================
34 |
35 | * :ref:`genindex`
36 | * :ref:`modindex`
37 | * :ref:`search`
38 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: local
3 | hooks:
4 | - id: check-requirements-versions
5 | name: Check pre-commit formatting versions
6 | entry: python scripts/check_pre_commit_reqs.py
7 | language: python
8 | always_run: true
9 | pass_filenames: false
10 | additional_dependencies:
11 | - PyYAML
12 |
13 | - repo: https://github.com/omnilib/ufmt
14 | rev: v2.8.0
15 | hooks:
16 | - id: ufmt
17 | additional_dependencies:
18 | - black==24.4.2
19 | - usort==1.0.8.post1
20 | - ruff-api==0.1.0
21 | - stdlibs==2024.1.28
22 | args: [format]
23 |
24 | - repo: https://github.com/pycqa/flake8
25 | rev: 7.0.0
26 | hooks:
27 | - id: flake8
28 |
--------------------------------------------------------------------------------
/ax/utils/report/resources/sufficient_statistic.html:
--------------------------------------------------------------------------------
1 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | {% for cell in cells %}
14 | {% if cell.html %}
15 |
18 |
19 | {{cell.html}}
20 |
21 | {% endif %}
22 | {% endfor %}
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/ax/metrics/hartmann6.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import numpy.typing as npt
10 | from ax.metrics.noisy_function import NoisyFunctionMetric
11 | from ax.utils.common.typeutils import checked_cast
12 | from ax.utils.measurement.synthetic_functions import aug_hartmann6, hartmann6
13 |
14 |
class Hartmann6Metric(NoisyFunctionMetric):
    """Metric whose ``f`` evaluates the ``hartmann6`` synthetic function."""

    def f(self, x: npt.NDArray) -> float:
        """Return ``hartmann6(x)``, type-checked as a float via ``checked_cast``."""
        result = hartmann6(x)
        return checked_cast(float, result)
18 |
19 |
class AugmentedHartmann6Metric(NoisyFunctionMetric):
    """Metric whose ``f`` evaluates the ``aug_hartmann6`` synthetic function."""

    def f(self, x: npt.NDArray) -> float:
        """Return ``aug_hartmann6(x)``, type-checked as a float via ``checked_cast``."""
        result = aug_hartmann6(x)
        return checked_cast(float, result)
23 |
--------------------------------------------------------------------------------
/website/static/js/mathjax.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
// Global MathJax configuration for the website. The `tex2jax` and `HTML-CSS`
// keys suggest MathJax v2-style config (NOTE(review): confirm version).
// Inline math uses $...$ or \(...\); display math uses $$...$$ or \[...\].
window.MathJax = {
  tex2jax: {
    inlineMath: [['$', '$'], ['\\(', '\\)']],
    displayMath: [['$$', '$$'], ['\\[', '\\]']],
    processEscapes: true,
    processEnvironments: true,
  },
  // Center justify equations in code and markdown cells. Note that this
  // doesn't work with Plotly though, hence the !important declaration
  // below.
  displayAlign: 'center',
  'HTML-CSS': {
    styles: {
      '.MathJax_Display': {margin: 0, 'text-align': 'center !important'},
    },
    linebreaks: {automatic: true},
  },
};
26 |
--------------------------------------------------------------------------------
/website/static/img/database-solid.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
12 |
--------------------------------------------------------------------------------
/.github/workflows/cron_pinned.yml:
--------------------------------------------------------------------------------
1 | name: Replicate Nightly Cron with Pinned BoTorch
2 |
3 | on:
4 |
5 | workflow_dispatch:
6 |
7 | jobs:
8 |
9 | tests-and-coverage-minimal:
10 | name: Tests with pinned BoTorch & minimal dependencies
11 | uses: ./.github/workflows/reusable_test.yml
12 | with:
13 | pinned_botorch: true
14 | minimal_dependencies: true
15 | secrets: inherit
16 |
17 | tests-and-coverage-full:
18 | name: Tests with pinned BoTorch & full dependencies
19 | uses: ./.github/workflows/reusable_test.yml
20 | with:
21 | pinned_botorch: true
22 | minimal_dependencies: false
23 | secrets: inherit
24 |
25 | build-tutorials:
26 | name: Build tutorials with pinned BoTorch
27 | uses: ./.github/workflows/reusable_tutorials.yml
28 | with:
29 | smoke_test: false
30 | pinned_botorch: true
31 |
--------------------------------------------------------------------------------
/ax/models/types.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from typing import Any, Union
10 |
11 | from ax.core.optimization_config import OptimizationConfig
12 | from ax.models.winsorization_config import WinsorizationConfig
13 | from botorch.acquisition import AcquisitionFunction
14 |
# Type of a generic configuration dictionary: string keys mapped to one of a
# fixed set of value types (scalars, strings, an AcquisitionFunction, a list of
# strings, nested dicts, an OptimizationConfig, a WinsorizationConfig, or None).
# Presumably used for model/bridge configuration options — verify at call sites.
# pyre-ignore [33]: `TConfig` cannot alias to a type containing `Any`.
TConfig = dict[
    str,
    Union[
        int,
        float,
        str,
        AcquisitionFunction,
        list[str],
        dict[int, Any],
        dict[str, Any],
        OptimizationConfig,
        WinsorizationConfig,
        None,
    ],
]
31 |
--------------------------------------------------------------------------------
/ax/exceptions/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.exceptions.core import AxError
10 |
11 |
class ModelError(AxError):
    """Error signalling that something went wrong during modeling."""
16 |
17 |
class CVNotSupportedError(AxError):
    """Raised when cross validation is requested for a model that does not
    support it.
    """
24 |
25 |
class ModelBridgeMethodNotImplementedError(AxError, NotImplementedError):
    """Raised when a subclass does not implement a ``ModelBridge`` method.

    Because it derives from both ``AxError`` and ``NotImplementedError``, a
    handler for either exception type will catch it.

    NOTE: ``ModelBridge`` may catch and silently discard this error.
    """
33 |
--------------------------------------------------------------------------------
/ax/modelbridge/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # flake8: noqa F401
10 | from ax.modelbridge import transforms
11 | from ax.modelbridge.base import ModelBridge
12 | from ax.modelbridge.factory import (
13 | get_factorial,
14 | get_sobol,
15 | get_thompson,
16 | get_uniform,
17 | Models,
18 | )
19 | from ax.modelbridge.map_torch import MapTorchModelBridge
20 | from ax.modelbridge.torch import TorchModelBridge
21 |
# Public names re-exported by this package (all are imported above).
__all__ = [
    "MapTorchModelBridge",
    "ModelBridge",
    "Models",
    "TorchModelBridge",
    "get_factorial",
    "get_sobol",
    "get_thompson",
    "get_uniform",
    "transforms",
]
33 |
--------------------------------------------------------------------------------
/ax/models/tests/test_base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.models.base import Model
10 | from ax.utils.common.testutils import TestCase
11 |
12 |
class BaseModelTest(TestCase):
    """Checks the default behavior provided by the ``Model`` base class."""

    def test_base_model(self) -> None:
        base = Model()
        state = {"foo": "bar", "two": 3.0}

        # Serializing and deserializing state both hand back an equal dict.
        for convert in (base.serialize_state, base.deserialize_state):
            self.assertEqual(convert(state), state)

        # The base class reports an empty state of its own.
        self.assertEqual(base._get_state(), {})

        # Feature importances are unavailable on the base class.
        with self.assertRaisesRegex(
            NotImplementedError, "Feature importance not available"
        ):
            base.feature_importances()
24 |
--------------------------------------------------------------------------------
/ax/storage/sqa_store/structs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from collections.abc import Callable
10 | from typing import NamedTuple
11 |
12 | from ax.storage.sqa_store.decoder import Decoder
13 | from ax.storage.sqa_store.encoder import Encoder
14 | from ax.storage.sqa_store.sqa_config import SQAConfig
15 |
16 |
class DBSettings(NamedTuple):
    """
    Defines behavior for loading/saving experiment to/from db.
    Either creator or url must be specified as a way to connect to the SQL db.

    NOTE: this class does not itself validate that ``creator`` or ``url`` is
    set. The default ``decoder``/``encoder`` are constructed once, when the
    class is defined, so every instance relying on the defaults shares the
    same ``Decoder``/``Encoder`` objects.
    """

    # Optional callable used to connect to the SQL db (alternative to `url`).
    creator: Callable | None = None
    # Decoder (presumably used when loading); defaults to one built on SQAConfig().
    decoder: Decoder = Decoder(config=SQAConfig())
    # Encoder (presumably used when saving); defaults to one built on SQAConfig().
    encoder: Encoder = Encoder(config=SQAConfig())
    # Optional URL of the SQL db (alternative to `creator`).
    url: str | None = None
27 |
--------------------------------------------------------------------------------
/ax/utils/common/typeutils_nonnative.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from typing import Any
10 |
11 | import numpy as np
12 |
13 |
14 | # pyre-fixme[3]: Return annotation cannot be `Any`.
15 | # pyre-fixme[2]: Parameter annotation cannot be `Any`.
def numpy_type_to_python_type(value: Any) -> Any:
    """Convert a NumPy integer or floating scalar to the Python builtin type.

    Some of our transforms return NumPy scalars (any ``np.integer`` or
    ``np.floating``); those become plain ``int`` / ``float``. Every other
    value, including non-NumPy numbers, is returned unchanged.
    """
    for numpy_kind, builtin in ((np.integer, int), (np.floating, float)):
        if isinstance(value, numpy_kind):
            # pragma: nocover (covered by generator tests)
            return builtin(value)
    return value
25 |
--------------------------------------------------------------------------------
/ax/analysis/plotly/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | from ax.analysis.plotly.cross_validation import CrossValidationPlot
9 | from ax.analysis.plotly.interaction import InteractionPlot
10 | from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
11 | from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
12 | from ax.analysis.plotly.scatter import ScatterPlot
13 | from ax.analysis.plotly.surface.contour import ContourPlot
14 | from ax.analysis.plotly.surface.slice import SlicePlot
15 |
# Public plotly-based analyses re-exported by this package (imported above).
__all__ = [
    "ContourPlot",
    "CrossValidationPlot",
    "InteractionPlot",
    "PlotlyAnalysis",
    "PlotlyAnalysisCard",
    "ParallelCoordinatesPlot",
    "ScatterPlot",
    "SlicePlot",
]
26 |
--------------------------------------------------------------------------------
/ax/plot/tests/test_parallel_coordinates.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 |
9 | from ax.plot.base import AxPlotConfig, AxPlotTypes
10 | from ax.plot.parallel_coordinates import plot_parallel_coordinates
11 | from ax.utils.common.testutils import TestCase
12 | from ax.utils.testing.core_stubs import get_branin_experiment
13 |
14 |
class ParallelCoordinatesTest(TestCase):
    def test_ParallelCoordinates(self) -> None:
        """A parallel-coordinates plot builds for a Branin batch experiment."""
        experiment = get_branin_experiment(with_batch=True)
        first_trial = experiment.trials[0]
        first_trial.run()

        config = plot_parallel_coordinates(experiment=experiment)

        self.assertIsInstance(config, AxPlotConfig)
        self.assertEqual(config.plot_type, AxPlotTypes.GENERIC)
25 |
--------------------------------------------------------------------------------
/website/static/img/th-large-solid.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
14 |
--------------------------------------------------------------------------------
/ax/preview/api/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.preview.api.client import Client
10 | from ax.preview.api.configs import (
11 | ChoiceParameterConfig,
12 | ExperimentConfig,
13 | GenerationStrategyConfig,
14 | OrchestrationConfig,
15 | ParameterScaling,
16 | ParameterType,
17 | RangeParameterConfig,
18 | StorageConfig,
19 | )
20 | from ax.preview.api.types import TOutcome, TParameterization
21 |
# Public API of ``ax.preview.api``: the Client, its config types, and the
# parameterization/outcome type aliases (all imported above).
__all__ = [
    "Client",
    "ChoiceParameterConfig",
    "ExperimentConfig",
    "GenerationStrategyConfig",
    "OrchestrationConfig",
    "ParameterScaling",
    "ParameterType",
    "RangeParameterConfig",
    "StorageConfig",
    "TOutcome",
    "TParameterization",
]
35 |
--------------------------------------------------------------------------------
/ax/utils/common/tests/test_serialization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from typing import NamedTuple
10 |
11 | from ax.utils.common.serialization import named_tuple_to_dict
12 | from ax.utils.common.testutils import TestCase
13 |
14 |
class TestSerializationUtils(TestCase):
    def test_named_tuple_to_dict(self) -> None:
        class Foo(NamedTuple):
            x: int
            y: str

        foo = Foo(x=5, y="g")
        foo_as_dict = {"x": 5, "y": "g"}

        # A NamedTuple on its own becomes a dict of its fields.
        self.assertEqual(named_tuple_to_dict(foo), foo_as_dict)

        # NamedTuples nested inside a dict and a list are converted as well,
        # while the plain tuple and scalar values are kept as they are.
        nested = {"x": 5, "foo": foo, "y": [(1, True), foo]}
        expected = {"x": 5, "foo": foo_as_dict, "y": [(1, True), foo_as_dict]}
        self.assertEqual(named_tuple_to_dict(nested), expected)
29 |
--------------------------------------------------------------------------------
/ax/metrics/branin.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import numpy.typing as npt
10 | from ax.metrics.noisy_function import NoisyFunctionMetric
11 | from ax.utils.common.typeutils import checked_cast
12 | from ax.utils.measurement.synthetic_functions import aug_branin, branin
13 |
14 |
class BraninMetric(NoisyFunctionMetric):
    """``NoisyFunctionMetric`` backed by the two-input ``branin`` function."""

    def f(self, x: npt.NDArray) -> float:
        """Evaluate ``branin`` at the two coordinates of ``x``."""
        first, second = x
        value = branin(x1=first, x2=second)
        return checked_cast(float, value)
19 |
20 |
class NegativeBraninMetric(BraninMetric):
    """``BraninMetric`` with the sign of its value flipped."""

    def f(self, x: npt.NDArray) -> float:
        return -super().f(x)
25 |
26 |
class AugmentedBraninMetric(NoisyFunctionMetric):
    """``NoisyFunctionMetric`` backed by the ``aug_branin`` function."""

    def f(self, x: npt.NDArray) -> float:
        """Evaluate ``aug_branin`` at ``x``, narrowed to ``float``."""
        value = aug_branin(x)
        return checked_cast(float, value)
30 |
--------------------------------------------------------------------------------
/ax/benchmark/methods/sobol.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 |
9 | from ax.benchmark.benchmark_method import BenchmarkMethod
10 | from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
11 | from ax.modelbridge.registry import Models
12 |
13 |
def get_sobol_generation_strategy() -> GenerationStrategy:
    """Build a generation strategy named "Sobol" with a single Sobol step.

    The step is created with ``num_trials=-1`` (presumably meaning no cap on
    how many trials the step may generate — confirm in ``GenerationStep``).
    """
    sobol_step = GenerationStep(model=Models.SOBOL, num_trials=-1)
    return GenerationStrategy(name="Sobol", steps=[sobol_step])
21 |
22 |
def get_sobol_benchmark_method(
    distribute_replications: bool,
    batch_size: int = 1,
) -> BenchmarkMethod:
    """Wrap the Sobol-only generation strategy in a ``BenchmarkMethod``.

    Args:
        distribute_replications: Passed through to ``BenchmarkMethod``.
        batch_size: Passed through to ``BenchmarkMethod``; defaults to 1.

    Returns:
        A ``BenchmarkMethod`` using ``get_sobol_generation_strategy()``.
    """
    strategy = get_sobol_generation_strategy()
    return BenchmarkMethod(
        generation_strategy=strategy,
        batch_size=batch_size,
        distribute_replications=distribute_replications,
    )
32 |
--------------------------------------------------------------------------------
/ax/exceptions/constants.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
# Message for when no Thompson Sampling arm exceeds the minimum weight. The
# ``{min_weight}`` / ``{max_weight}`` placeholders are presumably filled via
# ``str.format``. Each trailing backslash joins the next source line, so the
# message is a single line followed by one final newline.
TS_MIN_WEIGHT_ERROR = """\
No arms generated by Thompson Sampling had weight > min_weight. \
The minimum weight required is {min_weight:2.4}, and the \
maximum weight of any arm generated is {max_weight:2.4}.
"""

# Message for when too few Thompson Sampling samples contain a feasible arm.
TS_NO_FEASIBLE_ARMS_ERROR = """\
Less than 1% of samples have a feasible arm. \
Check your outcome constraints.
"""

# Explanatory text for Cholesky errors. It ends with "Original error: ", so the
# underlying error message is presumably appended right after it.
CHOLESKY_ERROR_ANNOTATION = (
    "Cholesky errors typically occur when the same or very similar "
    "arms are suggested repeatedly. This can mean the model has "
    "already converged and you should avoid running further trials. "
    "It will also help to convert integer or categorical parameters "
    "to float ranges where reasonable.\nOriginal error: "
)
27 |
--------------------------------------------------------------------------------
/ax/core/tests/test_risk_measures.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.core.risk_measures import RiskMeasure
10 | from ax.utils.common.testutils import TestCase
11 |
12 |
class TestRiskMeasure(TestCase):
    def test_risk_measure(self) -> None:
        measure = RiskMeasure(risk_measure="VaR", options={"alpha": 0.8, "n_w": 5})

        # Constructor arguments are exposed as attributes.
        self.assertEqual(measure.risk_measure, "VaR")
        self.assertEqual(measure.options, {"alpha": 0.8, "n_w": 5})

        # The string form shows the risk measure name and its options.
        self.assertEqual(
            str(measure),
            "RiskMeasure(risk_measure=VaR, options={'alpha': 0.8, 'n_w': 5})",
        )

        # A clone renders identically to the original.
        duplicate = measure.clone()
        self.assertEqual(str(measure), str(duplicate))
31 |
--------------------------------------------------------------------------------
/sphinx/source/runners.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax.runners
5 | ==========
6 |
7 | .. automodule:: ax.runners
8 | .. currentmodule:: ax.runners
9 |
10 | BoTorch Test Problem
11 | ~~~~~~~~~~~~~~~~~~~~
12 |
13 | .. automodule:: ax.runners.botorch_test_problem
14 | :members:
15 | :undoc-members:
16 | :show-inheritance:
17 |
18 | SingleRunningTrialMixin
19 | ~~~~~~~~~~~~~~~~~~~~~~~
20 |
21 | .. automodule:: ax.runners.single_running_trial_mixin
22 | :members:
23 | :undoc-members:
24 | :show-inheritance:
25 |
26 | Synthetic Runner
27 | ~~~~~~~~~~~~~~~~
28 |
29 | .. automodule:: ax.runners.synthetic
30 | :members:
31 | :undoc-members:
32 | :show-inheritance:
33 |
34 | Simulated Backend Runner
35 | ~~~~~~~~~~~~~~~~~~~~~~~~
36 |
37 | .. automodule:: ax.runners.simulated_backend
38 | :members:
39 | :undoc-members:
40 | :show-inheritance:
41 |
42 | TorchX Runner
43 | ~~~~~~~~~~~~~
44 |
45 | .. automodule:: ax.runners.torchx
46 | :members:
47 | :undoc-members:
48 | :show-inheritance:
49 |
--------------------------------------------------------------------------------
/ax/analysis/healthcheck/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | from ax.analysis.healthcheck.can_generate_candidates import (
9 | CanGenerateCandidatesAnalysis,
10 | )
11 |
12 | from ax.analysis.healthcheck.constraints_feasibility import (
13 | ConstraintsFeasibilityAnalysis,
14 | )
15 | from ax.analysis.healthcheck.healthcheck_analysis import (
16 | HealthcheckAnalysis,
17 | HealthcheckAnalysisCard,
18 | HealthcheckStatus,
19 | )
20 |
21 | from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis
22 | from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates
23 |
# Public healthcheck analyses and supporting types re-exported by this package
# (all imported above).
__all__ = [
    "ConstraintsFeasibilityAnalysis",
    "CanGenerateCandidatesAnalysis",
    "HealthcheckAnalysis",
    "HealthcheckAnalysisCard",
    "HealthcheckStatus",
    "ShouldGenerateCandidates",
    "SearchSpaceAnalysis",
]
33 |
--------------------------------------------------------------------------------
/ax/storage/sqa_store/timestamp.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import datetime
10 |
11 | from sqlalchemy.engine.interfaces import Dialect
12 | from sqlalchemy.types import Integer, TypeDecorator
13 |
14 |
class IntTimestamp(TypeDecorator):
    """Column type storing ``datetime`` values as integer POSIX timestamps.

    On write, ``datetime.timestamp()`` is truncated with ``int()``, so
    sub-second precision is lost. On read, ``datetime.fromtimestamp()`` is
    used, which returns a naive local-time ``datetime``. ``None`` passes
    through unchanged in both directions.
    """

    impl = Integer
    cache_ok = True

    # pyre-fixme[15]: `process_bind_param` overrides method defined in
    # `TypeDecorator` inconsistently.
    def process_bind_param(
        self, value: datetime.datetime | None, dialect: Dialect
    ) -> int | None:
        """Convert a ``datetime`` into the integer written to the database."""
        return None if value is None else int(value.timestamp())

    def process_result_value(
        self, value: int | None, dialect: Dialect
    ) -> datetime.datetime | None:
        """Convert the stored integer back into a ``datetime``."""
        if value is None:
            return None
        return datetime.datetime.fromtimestamp(value)
33 |
--------------------------------------------------------------------------------
/ax/utils/testing/test_init_files.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import os
10 | from glob import glob
11 |
12 | from ax.utils.common.testutils import TestCase
13 |
DIRS_TO_SKIP = ["ax/fb", "ax/github", "tests"]


class InitTest(TestCase):
    def test_InitFiles(self) -> None:
        """__init__.py files are necessary for the inclusion of the directories
        in pip builds.

        Walks ``./ax`` and requires an ``__init__.py`` in every directory that
        contains ``.py`` files, either directly or in any subdirectory.
        Directories whose path contains any entry of ``DIRS_TO_SKIP`` are
        ignored.
        """
        for root, _, files in os.walk("./ax", topdown=False):
            # The skip patterns use "/" separators. Normalize the walked path
            # so they also match where os.sep differs (e.g. Windows).
            normalized_root = root.replace(os.sep, "/")
            if any(s in normalized_root for s in DIRS_TO_SKIP):
                continue
            if glob(f"{root}/**/*.py", recursive=True):
                with self.subTest(root):
                    self.assertIn(
                        "__init__.py",
                        files,
                        f"directory {root} does not contain an __init__.py file",
                    )
30 |
--------------------------------------------------------------------------------
/sphinx/source/exceptions.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax.exceptions
5 | =============
6 |
7 | .. automodule:: ax.exceptions
8 | .. currentmodule:: ax.exceptions
9 |
10 |
11 | Constants
12 | ~~~~~~~~~
13 |
14 | .. automodule:: ax.exceptions.constants
15 | :members:
16 | :undoc-members:
17 | :show-inheritance:
18 |
19 |
20 | Core
21 | ~~~~
22 |
23 | .. automodule:: ax.exceptions.core
24 | :members:
25 | :undoc-members:
26 | :show-inheritance:
27 |
28 | Data
29 | ~~~~
30 |
31 | .. automodule:: ax.exceptions.data_provider
32 | :members:
33 | :undoc-members:
34 | :show-inheritance:
35 |
36 | Generation Strategy
37 | ~~~~~~~~~~~~~~~~~~~
38 |
39 | .. automodule:: ax.exceptions.generation_strategy
40 | :members:
41 | :undoc-members:
42 | :show-inheritance:
43 |
44 | Model
45 | ~~~~~
46 |
47 | .. automodule:: ax.exceptions.model
48 | :members:
49 | :undoc-members:
50 | :show-inheritance:
51 |
52 | Storage
53 | ~~~~~~~
54 |
55 | .. automodule:: ax.exceptions.storage
56 | :members:
57 | :undoc-members:
58 | :show-inheritance:
59 |
--------------------------------------------------------------------------------
/ax/preview/api/utils/storage.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from ax.preview.api.configs import StorageConfig
7 | from ax.storage.sqa_store.decoder import Decoder
8 | from ax.storage.sqa_store.encoder import Encoder
9 | from ax.storage.sqa_store.sqa_config import SQAConfig
10 | from ax.storage.sqa_store.structs import DBSettings
11 |
12 |
def db_settings_from_storage_config(
    storage_config: StorageConfig,
) -> DBSettings:
    """Construct DBSettings (expected by WithDBSettingsBase) from StorageConfig.

    If the config provides a ``registry_bundle``, that bundle's encoder and
    decoder are used. Otherwise, a fresh ``Encoder`` and ``Decoder`` are built
    on a default ``SQAConfig``. ``creator`` and ``url`` are copied over as is.
    """
    bundle = storage_config.registry_bundle
    if bundle is None:
        encoder, decoder = Encoder(config=SQAConfig()), Decoder(config=SQAConfig())
    else:
        encoder, decoder = bundle.encoder, bundle.decoder

    return DBSettings(
        creator=storage_config.creator,
        url=storage_config.url,
        encoder=encoder,
        decoder=decoder,
    )
30 |
--------------------------------------------------------------------------------
/ax/early_stopping/strategies/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.early_stopping.strategies.base import (
10 | BaseEarlyStoppingStrategy,
11 | EarlyStoppingTrainingData,
12 | ModelBasedEarlyStoppingStrategy,
13 | )
14 | from ax.early_stopping.strategies.logical import (
15 | AndEarlyStoppingStrategy,
16 | LogicalEarlyStoppingStrategy,
17 | OrEarlyStoppingStrategy,
18 | )
19 | from ax.early_stopping.strategies.percentile import PercentileEarlyStoppingStrategy
20 | from ax.early_stopping.strategies.threshold import ThresholdEarlyStoppingStrategy
21 |
22 |
# Public early-stopping strategy classes re-exported by this package (all
# imported above).
__all__ = [
    "BaseEarlyStoppingStrategy",
    "EarlyStoppingTrainingData",
    "ModelBasedEarlyStoppingStrategy",
    "PercentileEarlyStoppingStrategy",
    "ThresholdEarlyStoppingStrategy",
    "AndEarlyStoppingStrategy",
    "OrEarlyStoppingStrategy",
    "LogicalEarlyStoppingStrategy",
]
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Meta Platforms, Inc. and affiliates.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/ax/analysis/healthcheck/healthcheck_analysis.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 | import json
8 | from enum import IntEnum
9 |
10 | from ax.analysis.analysis import Analysis, AnalysisCard
11 | from ax.core.experiment import Experiment
12 | from ax.core.generation_strategy_interface import GenerationStrategyInterface
13 |
14 |
class HealthcheckStatus(IntEnum):
    """Result of a health check.

    ``HealthcheckAnalysisCard.get_status`` rebuilds this enum from the integer
    stored under the ``"status"`` key of the card's JSON blob. The numeric
    values below therefore must not be renumbered.
    """

    PASS = 0
    FAIL = 1
    WARNING = 2
19 |
20 |
class HealthcheckAnalysisCard(AnalysisCard):
    """Analysis card whose JSON ``blob`` carries a health check ``status``."""

    blob_annotation = "healthcheck"

    def get_status(self) -> HealthcheckStatus:
        """Parse ``blob`` as JSON and convert its ``"status"`` entry into a
        ``HealthcheckStatus``."""
        payload = json.loads(self.blob)
        status_value = payload["status"]
        return HealthcheckStatus(status_value)
26 |
27 |
class HealthcheckAnalysis(Analysis):
    """
    An analysis that performs a health check.

    ``compute`` has an empty (``...``) body here, so calling it on this class
    returns ``None``. Concrete health checks are presumably subclasses that
    override it to return a ``HealthcheckAnalysisCard``.
    """

    def compute(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategyInterface | None = None,
    ) -> HealthcheckAnalysisCard: ...
38 |
--------------------------------------------------------------------------------
/ax/core/tests/test_auxiliary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | from ax.core.auxiliary import AuxiliaryExperiment
9 | from ax.utils.common.testutils import TestCase
10 | from ax.utils.testing.core_stubs import get_experiment, get_experiment_with_data
11 |
12 |
class AuxiliaryExperimentTest(TestCase):
    def test_AuxiliaryExperiment(self) -> None:
        """AuxiliaryExperiment keeps its experiment and the expected data."""
        for make_experiment in (get_experiment, get_experiment_with_data):
            experiment = make_experiment()
            expected_data = experiment.lookup_data()

            # Without explicit data, ``data`` equals the experiment's lookup.
            default_aux = AuxiliaryExperiment(experiment=experiment)
            self.assertEqual(default_aux.experiment, experiment)
            self.assertEqual(default_aux.data, expected_data)

            # With explicitly passed data, that data is what gets exposed.
            explicit_aux = AuxiliaryExperiment(
                experiment=experiment, data=experiment.lookup_data()
            )
            self.assertEqual(explicit_aux.experiment, experiment)
            self.assertEqual(explicit_aux.data, expected_data)
29 |
--------------------------------------------------------------------------------
/sphinx/source/early_stopping.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax.early_stopping
5 | =================
6 |
7 | .. automodule:: ax.early_stopping
8 | .. currentmodule:: ax.early_stopping
9 |
10 | Strategies
11 | ----------
12 |
13 | Base Strategies
14 | ~~~~~~~~~~~~~~~
15 |
16 | .. automodule:: ax.early_stopping.strategies.base
17 | :members:
18 | :undoc-members:
19 | :show-inheritance:
20 |
21 | Logical Strategies
22 | ~~~~~~~~~~~~~~~~~~
23 |
24 | .. automodule:: ax.early_stopping.strategies.logical
25 | :members:
26 | :undoc-members:
27 | :show-inheritance:
28 |
29 | `PercentileEarlyStoppingStrategy`
30 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
31 |
32 | .. automodule:: ax.early_stopping.strategies.percentile
33 | :members:
34 | :undoc-members:
35 | :show-inheritance:
36 |
37 | `ThresholdEarlyStoppingStrategy`
38 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
39 |
40 | .. automodule:: ax.early_stopping.strategies.threshold
41 | :members:
42 | :undoc-members:
43 | :show-inheritance:
44 |
45 | Utils
46 | -----
47 |
48 | .. automodule:: ax.early_stopping.utils
49 | :members:
50 | :undoc-members:
51 | :show-inheritance:
52 |
--------------------------------------------------------------------------------
/ax/plot/css/base.css:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
/* Styles for plot hover text and the plot menu's div, select and button controls. */
.hovertext {
  opacity: 0.85;
}

.plot-menu div {
  margin-top: 5px;
  margin-bottom: 5px;
}

.plot-menu select {
  height: 30px;
  padding: 6px 10px;
  background-color: #fff;
  border: 1px solid #D1D1D1;
  border-radius: 4px;
  box-shadow: none;
  box-sizing: border-box;
  margin-left: 15px;
  margin-bottom: 0px;
}

.plot-menu select:focus {
  border: 1px solid #33C3F0;
  outline: 0;
}

.plot-menu button {
  display: inline-block;
  height: 30px;
  padding: 0 10px;
  color: #555;
  text-align: center;
  font-size: 12px;
  font-weight: 600;
  text-transform: uppercase;
  text-decoration: none;
  white-space: nowrap;
  border-radius: 4px;
  border: 1px solid #bbb;
  cursor: pointer;
  box-sizing: border-box;
  background-color: #fff;
  margin-left: 15px;
}
53 |
--------------------------------------------------------------------------------
/ax/modelbridge/transforms/tests/test_rounding_transform.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import numpy as np
10 | from ax.modelbridge.transforms.rounding import (
11 | randomized_onehot_round,
12 | strict_onehot_round,
13 | )
14 | from ax.utils.common.testutils import TestCase
15 |
16 |
class RoundingTest(TestCase):
    def test_OneHotRound(self) -> None:
        # Strict rounding of [0.1, 0.5, 0.3] is expected to yield [0, 1, 0].
        strict_result = strict_onehot_round(np.array([0.1, 0.5, 0.3]))
        self.assertTrue(np.allclose(strict_result, np.array([0, 1, 0])))

        # One item should be set to one at random.
        random_result = randomized_onehot_round(np.array([0.0, 0.0, 0.0]))
        ones_mask = np.isclose(random_result, np.array([1, 1, 1]))
        self.assertEqual(np.count_nonzero(ones_mask), 1)

        # Negative value is not selected.
        with_negative = randomized_onehot_round(np.array([0.0, -1.0, 0.0]))
        self.assertEqual(with_negative[1], 0.0)
36 |
--------------------------------------------------------------------------------
/ax/utils/common/tests/test_func_enum.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.utils.common.func_enum import FuncEnum
10 | from ax.utils.common.testutils import TestCase
11 |
12 |
def morph_into_salamander(how_soon: int) -> bool:
    """Return True once the countdown ``how_soon`` has reached zero or below."""
    # Zero or negative means now is the time; positive means not yet.
    return how_soon <= 0
18 |
19 |
class AnimalAbilities(FuncEnum):
    """Test-only FuncEnum whose member value is the name of a function in this
    module (resolved to that function by `_get_function_for_value`)."""

    # ꒰(˶• ᴗ •˶)꒱
    # Value names the module-level `morph_into_salamander` function.
    AXOLOTL_MORPH = "morph_into_salamander"
23 |
24 |
class EqualityTest(TestCase):
    def test_basic(self) -> None:
        # Check underlying function correctness.
        resolved = AnimalAbilities.AXOLOTL_MORPH._get_function_for_value()
        self.assertEqual(resolved, morph_into_salamander)

    def test_call(self) -> None:
        morph = AnimalAbilities.AXOLOTL_MORPH
        # Should be too early to morph...
        self.assertFalse(morph(how_soon=1))
        # Should've morphed yesterday!
        self.assertTrue(morph(how_soon=-1))
37 |
--------------------------------------------------------------------------------
/ax/core/map_metric.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from __future__ import annotations
10 |
11 | from ax.core.map_data import MapData, MapKeyInfo
12 | from ax.core.metric import Metric, MetricFetchE
13 | from ax.utils.common.result import Result
14 |
# Result type for `MapMetric` fetches, parameterized with `MapData` and
# `MetricFetchE` (presumably success value and error value respectively —
# confirm against `ax.utils.common.result.Result`).
MapMetricFetchResult = Result[MapData, MetricFetchE]
16 |
17 |
class MapMetric(Metric):
    """Base class for representing metrics that return `MapData`.

    The `fetch_trial_data` method is the essential method to override when
    subclassing; it specifies how to retrieve this metric's data for a given
    trial.

    A MapMetric must return a MapData object, which requires (at minimum) the following:
    https://ax.dev/api/_modules/ax/core/data.html#Data.required_columns

    Attributes:
        lower_is_better: Flag for metrics which should be minimized.
        properties: Properties specific to a particular metric.
        data_constructor: Data class associated with this metric; set to
            ``MapData``.
        map_key_info: Map key and its default value; set to key ``"step"``
            with default value ``0.0``.
    """

    # Data class for this metric type (MapData rather than plain Data).
    data_constructor: type[MapData] = MapData
    # Map key ("step") and its default value (0.0).
    map_key_info: MapKeyInfo[float] = MapKeyInfo(key="step", default_value=0.0)
34 |
--------------------------------------------------------------------------------
/ax/analysis/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | from ax.analysis.utils import choose_analyses
9 | from ax.utils.common.testutils import TestCase
10 | from ax.utils.testing.core_stubs import (
11 | get_branin_experiment,
12 | get_branin_experiment_with_multi_objective,
13 | )
14 |
15 |
class TestUtils(TestCase):
    def test_choose_analyses(self) -> None:
        expected_single_objective = {
            "ParallelCoordinatesPlot",
            "InteractionPlot",
            "Summary",
            "CrossValidationPlot",
        }
        chosen = choose_analyses(experiment=get_branin_experiment())
        self.assertEqual({a.name for a in chosen}, expected_single_objective)

        # Multi-objective case
        expected_multi_objective = {
            "InteractionPlot",
            "ScatterPlot",
            "Summary",
            "CrossValidationPlot",
        }
        chosen = choose_analyses(
            experiment=get_branin_experiment_with_multi_objective()
        )
        self.assertEqual({a.name for a in chosen}, expected_multi_objective)
37 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/3_general_support.yaml:
--------------------------------------------------------------------------------
1 | name: General Support
2 | description: Get advice on an experiment you're currently running, ask questions about methods, and receive general help with Ax.
3 | labels: ["question"]
4 | title: "[GENERAL SUPPORT]: "
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | Thank you for reaching out -- we will do our best to respond to your inquiry promptly.
10 | - type: textarea
11 | id: question
12 | attributes:
13 | label: Question
14 | description: Provide a detailed description of the problem you're facing or the question you would like help answering.
15 | validations:
16 | required: true
17 | - type: textarea
18 | id: snippet
19 | attributes:
20 | label: Please provide any relevant code snippet if applicable.
21 | description: This will be automatically formatted into code, so no need for backticks.
22 | render: shell
23 | - type: checkboxes
24 | id: terms
25 | attributes:
26 | label: Code of Conduct
27 | description: By submitting this issue, you agree to follow Ax's [Code of Conduct](https://github.com/facebook/Ax/blob/main/CODE_OF_CONDUCT.md).
28 | options:
29 | - label: I agree to follow Ax's Code of Conduct
30 | required: true
31 |
--------------------------------------------------------------------------------
/ax/utils/common/mock.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from collections.abc import Callable
10 | from contextlib import contextmanager
11 | from typing import Any, TypeVar
12 |
13 | from unittest.mock import MagicMock, patch
14 |
15 |
T = TypeVar("T")
C = TypeVar("C")


@contextmanager
def mock_patch_method_original(
    mock_path: str,
    original_method: Callable[..., T],
) -> MagicMock:
    """Context manager for patching a method returning type T on class C,
    to track calls to it while still executing the original method. There
    is not a native way to do this with `mock.patch`.

    Args:
        mock_path: Dotted path of the method to patch, as accepted by
            `unittest.mock.patch` (e.g. ``"package.module.Class.method"``).
        original_method: The unpatched method. Every call to the mock is
            forwarded to it and its return value is passed back to the caller.

    Yields:
        The autospecced mock that replaces the method while the context is
        active. The inner side-effect function stores the instance of the most
        recent call as its ``self`` attribute.

    The patch is always undone on exit, including when the body of the
    ``with`` statement raises.
    """

    def side_effect(self: C, *args: Any, **kwargs: Any) -> T:
        # pyre-ignore[16]: Anonymous callable has no attribute `self`
        # (We can ignore because we expect C to be a class).
        side_effect.self = self
        return original_method(self, *args, **kwargs)

    patcher = patch(mock_path, autospec=True, side_effect=side_effect)
    mock = patcher.start()
    try:
        yield mock
    finally:
        # An exception raised inside the `with` block is re-raised at the
        # `yield`; without `finally`, `stop()` would be skipped and the patch
        # would leak into subsequent code and tests.
        patcher.stop()
39 |
--------------------------------------------------------------------------------
/ax/utils/report/tests/test_render.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.utils.common.testutils import TestCase
10 | from ax.utils.report.render import (
11 | h2_html,
12 | h3_html,
13 | link_html,
14 | list_item_html,
15 | p_html,
16 | render_report_elements,
17 | table_cell_html,
18 | table_heading_cell_html,
19 | table_html,
20 | table_row_html,
21 | unordered_list_html,
22 | )
23 |
24 |
class RenderTest(TestCase):
    def test_RenderReportElements(self) -> None:
        # Build one of each HTML helper's output and render them together;
        # the test has no assertions and passes as long as nothing raises.
        text_parts = [
            p_html("foobar"),
            h2_html("foobar"),
            h3_html("foobar"),
        ]
        list_parts = [
            list_item_html("foobar"),
            unordered_list_html(["foo", "bar"]),
            link_html("foo", "bar"),
        ]
        table_parts = [
            table_cell_html("foobar"),
            table_cell_html("foobar", width="100px"),
            table_heading_cell_html("foobar"),
            table_row_html(["foo", "bar"]),
            table_html(["foo", "bar"]),
        ]
        render_report_elements("test", text_parts + list_parts + table_parts)
41 |
--------------------------------------------------------------------------------
/ax/utils/common/typeutils_torch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import json
10 |
11 | import torch
12 | from ax.utils.common.typeutils import checked_cast
13 |
14 |
15 | def torch_type_to_str(value: torch.dtype | torch.device | torch.Size) -> str:
16 | """Converts torch types, commonly used in Ax, to string representations."""
17 | if isinstance(value, torch.dtype):
18 | return str(value)
19 | if isinstance(value, torch.device):
20 | return checked_cast(str, value.type)
21 | if isinstance(value, torch.Size):
22 | return json.dumps(list(value))
23 | raise ValueError(f"Object {value} was of unexpected torch type.")
24 |
25 |
26 | def torch_type_from_str(
27 | identifier: str, type_name: str
28 | ) -> torch.dtype | torch.device | torch.Size:
29 | if type_name == "device":
30 | return torch.device(identifier)
31 | if type_name == "dtype":
32 | return getattr(torch, identifier[6:])
33 | if type_name == "Size":
34 | return torch.Size(json.loads(identifier))
35 | raise ValueError(f"Unexpected type: {type_name} for identifier: {identifier}.")
36 |
--------------------------------------------------------------------------------
/ax/storage/json_store/load.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import json
10 | from collections.abc import Callable
11 | from typing import Any
12 |
13 | from ax.core.experiment import Experiment
14 | from ax.storage.json_store.decoder import object_from_json
15 | from ax.storage.json_store.registry import (
16 | CORE_CLASS_DECODER_REGISTRY,
17 | CORE_DECODER_REGISTRY,
18 | )
19 | from ax.utils.common.serialization import TDecoderRegistry
20 |
21 |
def load_experiment(
    filepath: str,
    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
    class_decoder_registry: dict[
        str, Callable[[dict[str, Any]], Any]
    ] = CORE_CLASS_DECODER_REGISTRY,
) -> Experiment:
    """Load experiment from file.

    1) Read and parse the JSON file.
    2) Convert the resulting dictionary to an Ax experiment instance.

    Args:
        filepath: Path of the JSON file to read.
        decoder_registry: Decoder registry passed to `object_from_json`.
        class_decoder_registry: Class decoder registry passed to
            `object_from_json`.

    Returns:
        The object decoded by `object_from_json` (expected to be an
        `Experiment`).
    """
    # JSON text is UTF-8 (RFC 8259). Pass the encoding explicitly so reading
    # does not depend on the platform's default locale encoding.
    with open(filepath, encoding="utf-8") as file:
        json_experiment = json.load(file)
    return object_from_json(
        json_experiment, decoder_registry, class_decoder_registry
    )
40 |
--------------------------------------------------------------------------------
/.github/workflows/build-and-test.yml:
--------------------------------------------------------------------------------
1 | name: Build and Test Workflow
2 |
3 | on:
4 | workflow_dispatch:
5 | push:
6 | branches: [ main ]
7 | pull_request:
8 | branches: [ main ]
9 |
10 | jobs:
11 | tests-and-coverage:
12 | name: Tests with latest BoTorch
13 | uses: ./.github/workflows/reusable_test.yml
14 | with:
15 | pinned_botorch: false
16 | secrets: inherit
17 |
18 | docs:
19 | runs-on: ubuntu-latest
20 |
21 | steps:
22 | - uses: actions/checkout@v4
23 | - name: Set up Python
24 | uses: actions/setup-python@v5
25 | with:
26 | python-version: "3.10"
27 | - name: Install dependencies
28 | env:
29 | ALLOW_BOTORCH_LATEST: true
30 | ALLOW_LATEST_GPYTORCH_LINOP: true
31 | run: |
32 | # use latest Botorch
33 | pip install git+https://github.com/cornellius-gp/gpytorch.git
34 | pip install git+https://github.com/pytorch/botorch.git
35 | pip install -e ".[unittest]"
36 | - name: Validate Sphinx
37 | run: |
38 | python scripts/validate_sphinx.py -p "${PWD}"  # $PWD is the working dir; lowercase ${pwd} is an unset (empty) variable
39 | - name: Run Sphinx
40 | # run even if previous step (validate Sphinx) failed
41 | if: ${{ always() }}
42 | run: |
43 | # warnings no longer treated as errors.
44 | sphinx-build -T --keep-going sphinx/source sphinx/build
45 |
--------------------------------------------------------------------------------
/ax/benchmark/benchmark_trial_metadata.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | from collections.abc import Mapping
9 | from dataclasses import dataclass
10 |
11 | import pandas as pd
12 |
13 | from ax.utils.testing.backend_simulator import BackendSimulator
14 |
15 |
@dataclass(kw_only=True, frozen=True)
class BenchmarkTrialMetadata:
    """
    Data pertaining to one trial evaluation.

    Args:
        dfs: A mapping from each metric name to a Pandas DataFrame with columns
            ["metric_name", "arm_name", "mean", "sem", "step"]. The "sem" is
            always present in this df even if noise levels are unobserved;
            ``BenchmarkMetric`` and ``BenchmarkMapMetric`` hide that data if it
            should not be observed, and ``BenchmarkMapMetric``s drop data from
            time periods that are not observed based on the (simulated)
            trial progression.
        backend_simulator: Optionally, the backend simulator that is tracking
            the trial's status.
    """

    # Metric name -> DataFrame of that metric's data for this trial.
    dfs: Mapping[str, pd.DataFrame]
    # None when no backend simulator is tracking the trial.
    backend_simulator: BackendSimulator | None = None
35 |
--------------------------------------------------------------------------------
/ax/utils/common/tests/test_docutils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.utils.common.docutils import copy_doc
10 | from ax.utils.common.testutils import TestCase
11 |
12 |
# Fixture: source function for `copy_doc`. Its docstring text is asserted
# verbatim in TestDocUtils.test_transfer_doc, so it must stay unchanged.
def has_doc() -> None:
    """I have a docstring"""
15 |
16 |
# Fixture: deliberately has NO docstring, so that `copy_doc(has_no_doc)` raises
# in TestDocUtils.test_fail_when_no_doc_to_copy. Do not add one.
def has_no_doc() -> None:
    pass
19 |
20 |
class TestDocUtils(TestCase):
    def test_transfer_doc(self) -> None:
        # The source docstring should be copied onto the decorated function.
        @copy_doc(has_doc)
        def inherits_doc() -> None:
            pass

        self.assertEqual(inherits_doc.__doc__, "I have a docstring")

    def test_fail_when_already_has_doc(self) -> None:
        # A target that already carries its own docstring is rejected.
        with self.assertRaises(ValueError):

            @copy_doc(has_doc)
            def inherits_doc() -> None:
                """I already have a doc string"""
                pass

    def test_fail_when_no_doc_to_copy(self) -> None:
        # A source without a docstring is rejected (target has none either).
        with self.assertRaises(ValueError):

            @copy_doc(has_no_doc)
            def f() -> None:
                pass
46 |
--------------------------------------------------------------------------------
/ax/benchmark/benchmark_test_function.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | from abc import ABC, abstractmethod
9 | from collections.abc import Mapping, Sequence
10 | from dataclasses import dataclass
11 |
12 | from ax.core.types import TParamValue
13 | from torch import Tensor
14 |
15 |
@dataclass(kw_only=True)
class BenchmarkTestFunction(ABC):
    """
    The basic Ax class for generating deterministic data to benchmark against.

    (Noise - if desired - is added by the runner.)

    Subclasses must implement ``evaluate_true``.

    Args:
        outcome_names: Names of the outcomes.
        n_steps: Number of data points produced per metric and per evaluation. 1
            if data is not time-series. If data is time-series, this will
            eventually become the number of values on a `MapMetric` for
            evaluations that run to completion.
    """

    # One row of `evaluate_true`'s output per outcome name.
    outcome_names: Sequence[str]
    # Columns of `evaluate_true`'s output; 1 for non-time-series data.
    n_steps: int = 1

    @abstractmethod
    def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
        """
        Evaluate noiselessly.

        Args:
            params: Mapping from parameter name to parameter value.

        Returns:
            A 2d tensor of shape (len(self.outcome_names), self.n_steps).
        """
        ...
43 |
--------------------------------------------------------------------------------
/ax/utils/report/resources/base_template.html:
--------------------------------------------------------------------------------
1 |
4 |
5 |
6 | {{experiment_name}}
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | {% if headfoot %}
16 |
17 |
18 | {% if trial_index is none %}
19 | Report for {{experiment_name}}
20 | {% else %}
21 | Report for {{ batch_noun }} #{{ trial_index }} of
22 | {{experiment_name}}
23 | {% endif %}
24 |
25 | Powered by Ax
26 |
27 | {% endif %}
28 |
29 | {% block content %}
30 | {% endblock %}
31 |
32 | {% if headfoot %}
33 |
34 | {% endif %}
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/website/static/img/dice-solid.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
17 |
--------------------------------------------------------------------------------
/ax/utils/common/tests/test_kwargutils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | from logging import Logger
11 | from unittest.mock import patch
12 |
13 | from ax.utils.common.kwargs import warn_on_kwargs
14 | from ax.utils.common.logger import get_logger
15 | from ax.utils.common.testutils import TestCase
16 |
# Logger with the same name as the one used in `ax.utils.common.kwargs`; the
# tests patch its `warning` method to intercept warnings emitted there
# (assumes `get_logger` returns the same object for the same name — confirm).
logger: Logger = get_logger("ax.utils.common.kwargs")
18 |
19 |
class TestWarnOnKwargs(TestCase):
    def test_it_warns_if_kwargs_are_passed(self) -> None:
        with patch.object(logger, "warning") as mock_warning:

            def target() -> None:
                return

            warn_on_kwargs(callable_with_kwargs=target, foo="")
            expected_message = (
                "Found unexpected kwargs: %s while calling %s "
                "from JSON. These kwargs will be ignored."
            )
            mock_warning.assert_called_once_with(
                expected_message, {"foo": ""}, target
            )

    def test_it_does_not_warn_if_no_kwargs_are_passed(self) -> None:
        with patch.object(logger, "warning") as mock_warning:
            warn_on_kwargs(callable_with_kwargs=lambda: None)
            mock_warning.assert_not_called()
39 |
--------------------------------------------------------------------------------
/ax/plot/tests/test_helper.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | from ax.plot.helper import arm_name_to_sort_key, extend_range
11 | from ax.utils.common.testutils import TestCase
12 |
13 |
class HelperTest(TestCase):
    def test_extend_range(self) -> None:
        # An inverted range (lower > upper) is rejected.
        with self.assertRaises(ValueError):
            extend_range(lower=1, upper=-1)
        cases = [
            ({"lower": -1, "upper": 1}, (-1.2, 1.2)),
            ({"lower": -1, "upper": 0, "percent": 30}, (-1.3, 0.3)),
            ({"lower": 0, "upper": 1, "percent": 50}, (-0.5, 1.5)),
        ]
        for kwargs, expected in cases:
            self.assertEqual(extend_range(**kwargs), expected)

    def test_arm_name_to_sort_key(self) -> None:
        # Each case: (unsorted arm names, expected order under reverse sort).
        cases = [
            (
                ["0_0", "1_10", "1_2", "10_0", "control"],
                ["control", "0_0", "1_2", "1_10", "10_0"],
            ),
            (
                ["0_0", "0", "1_10", "3_2_x", "3_x", "1_2", "control"],
                ["control", "3_x", "3_2_x", "0", "0_0", "1_2", "1_10"],
            ),
        ]
        for arm_names, expected in cases:
            self.assertEqual(
                sorted(arm_names, key=arm_name_to_sort_key, reverse=True),
                expected,
            )
32 |
--------------------------------------------------------------------------------
/ax/utils/common/deprecation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import warnings
10 |
11 |
12 | def _validate_force_random_search(
13 | no_bayesian_optimization: bool | None = None,
14 | force_random_search: bool = False,
15 | exception_cls: type[Exception] = ValueError,
16 | ) -> None:
17 | """Helper function to validate interaction between `force_random_search`
18 | and `no_bayesian_optimization` (supported until deprecation in [T199632397])
19 | """
20 | if no_bayesian_optimization is not None:
21 | # users are effectively permitted to continue using
22 | # `no_bayesian_optimization` so long as it doesn't
23 | # conflict with `force_random_search`
24 | if no_bayesian_optimization != force_random_search:
25 | raise exception_cls(
26 | "Conflicting values for `force_random_search` "
27 | "and `no_bayesian_optimization`! "
28 | "Please only specify `force_random_search`."
29 | )
30 | warnings.warn(
31 | "`no_bayesian_optimization` is deprecated. Please use "
32 | "`force_random_search` in the future.",
33 | DeprecationWarning,
34 | stacklevel=2,
35 | )
36 |
--------------------------------------------------------------------------------
/ax/preview/api/protocols/metric.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 |
9 | from typing import Any, Mapping
10 |
11 | from ax.preview.api.protocols.utils import _APIMetric
12 | from pyre_extensions import override
13 |
14 |
class IMetric(_APIMetric):
    """
    Metrics automate the process of fetching data from external systems. They are used
    in conjunction with Runners in the run_n_trials method to facilitate closed-loop
    experimentation.
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name: Name of the metric, forwarded to the base class constructor.
        """
        super().__init__(name=name)

    @override
    def fetch(
        self,
        trial_index: int,
        trial_metadata: Mapping[str, Any],
    ) -> tuple[int, float | tuple[float, float]]:
        """
        Given trial metadata (the mapping returned from IRunner.run), fetches
        readings for the metric.

        Readings are returned as a pair (progression, outcome), where progression is
        an integer representing the progression of the trial (e.g. number of epochs
        for a training job, timestamp for a time series, etc.), and outcome is either
        a direct reading or a (mean, sem) pair for the metric.

        Args:
            trial_index: Index of the trial to fetch readings for.
            trial_metadata: The metadata mapping returned by ``IRunner.run``
                for this trial.

        Returns:
            A ``(progression, outcome)`` tuple as described above.
        """
        # Stub body (returns None); presumably implemented by user subclasses.
        ...
41 |
--------------------------------------------------------------------------------
/ax/runners/synthetic.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from collections.abc import Iterable
10 | from typing import Any
11 |
12 | from ax.core.base_trial import BaseTrial, TrialStatus
13 | from ax.core.runner import Runner
14 |
15 |
class SyntheticRunner(Runner):
    """Class for synthetic or dummy runner.

    Currently acts as a shell runner, only creating a name.
    """

    def __init__(self, dummy_metadata: str | None = None) -> None:
        """
        Args:
            dummy_metadata: Optional value copied into each trial's run
                metadata under ``"dummy_metadata"`` (only when truthy).
        """
        self.dummy_metadata = dummy_metadata

    def run(self, trial: BaseTrial) -> dict[str, Any]:
        """Return run metadata whose ``"name"`` is ``<experiment>_<index>``,
        or just ``<index>`` when the experiment has no name."""
        if trial.experiment.has_name:
            deployed_name = "_".join([trial.experiment.name, str(trial.index)])
        else:
            deployed_name = str(trial.index)
        metadata: dict[str, Any] = {"name": deployed_name}

        # Add dummy metadata if needed for testing
        if self.dummy_metadata:
            metadata["dummy_metadata"] = self.dummy_metadata
        return metadata

    def poll_trial_status(
        self, trials: Iterable[BaseTrial]
    ) -> dict[TrialStatus, set[int]]:
        """Report every given trial as completed."""
        completed_indices = {trial.index for trial in trials}
        return {TrialStatus.COMPLETED: completed_indices}

    @property
    def run_metadata_report_keys(self) -> list[str]:
        """Keys of the run metadata to report (only the deployed name)."""
        return ["name"]
46 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/2_feature_request.yaml:
--------------------------------------------------------------------------------
1 | name: Feature Request
2 | description: Request a feature or improvement to Ax.
3 | labels: ["enhancement"]
4 | title: "[FEATURE REQUEST]: "
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | Thank you for taking the time to request a feature! We strive to make Ax a useful and rich library for our users.
10 | - type: textarea
11 | id: motivation
12 | attributes:
13 | label: Motivation
14 | description: Provide a detailed description of the problem you would like to solve via this new feature or improvement.
15 | validations:
16 | required: true
17 | - type: textarea
18 | id: pitch
19 | attributes:
20 | label: Describe the solution you'd like to see implemented in Ax.
21 | validations:
22 | required: true
23 | - type: textarea
24 | id: alternatives
25 | attributes:
26 | label: Describe any alternatives you've considered to the above solution.
27 | - type: textarea
28 | id: related
29 | attributes:
30 | label: Is this related to an existing issue in Ax or another repository? If so please include links to those Issues here.
31 | - type: checkboxes
32 | id: terms
33 | attributes:
34 | label: Code of Conduct
35 | description: By submitting this issue, you agree to follow Ax's [Code of Conduct](https://github.com/facebook/Ax/blob/main/CODE_OF_CONDUCT.md).
36 | options:
37 | - label: I agree to follow Ax's Code of Conduct
38 | required: true
39 |
--------------------------------------------------------------------------------
/ax/exceptions/storage.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | from ax.exceptions.core import AxError
11 |
12 |
# Pointer to the storage tutorial, presumably appended to storage-related
# error messages (usage sites are not visible here — confirm).
STORAGE_DOCS_SUFFIX = (
    "Please see our storage tutorial (https://ax.dev/docs/storage.html) "
    "for more details ('Customizing' section will be "
    "relevant for saving Ax object subclasses)."
)
18 |
19 |
class JSONDecodeError(AxError):
    """Raised when something goes wrong while decoding from JSON."""
24 |
25 |
class JSONEncodeError(AxError):
    """Raised when something goes wrong while encoding to JSON."""
30 |
31 |
class SQADecodeError(AxError):
    """Raised when something goes wrong while decoding from SQA storage."""
36 |
37 |
class SQAEncodeError(AxError):
    """Raised when something goes wrong while encoding for SQA storage."""
42 |
43 |
class ImmutabilityError(AxError):
    """Raised on an attempt to update an object that is immutable."""
48 |
49 |
class IncorrectDBConfigurationError(AxError):
    """Raised on an attempt to save or load an object while the current engine
    and session factory are set up incorrectly for the call (for example, the
    current session factory would connect to the wrong database for the call).
    """
58 |
--------------------------------------------------------------------------------
/ax/models/winsorization_config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from dataclasses import dataclass
10 |
11 |
@dataclass
class WinsorizationConfig:
    """Dataclass for storing Winsorization configuration parameters

    Attributes:
        lower_quantile_margin: Winsorization will increase any metric value below this
            quantile to this quantile's value.
        upper_quantile_margin: Winsorization will decrease any metric value above this
            quantile to this quantile's value. NOTE: this quantile will be inverted before
            any operations, e.g., a value of 0.2 will decrease values above the 80th
            percentile to the value of the 80th percentile.
        lower_boundary: If this value is less than the metric value corresponding to
            ``lower_quantile_margin``, set metric values below ``lower_boundary`` to
            ``lower_boundary`` and leave larger values unaffected.
        upper_boundary: If this value is greater than the metric value corresponding to
            ``upper_quantile_margin``, set metric values above ``upper_boundary`` to
            ``upper_boundary`` and leave smaller values unaffected.

    Defaults: both quantile margins are ``0.0`` and both boundaries are ``None``.
    """

    lower_quantile_margin: float = 0.0
    upper_quantile_margin: float = 0.0
    lower_boundary: float | None = None
    upper_boundary: float | None = None
35 |
--------------------------------------------------------------------------------
/ax/core/auxiliary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | from __future__ import annotations
9 |
10 | from enum import Enum, unique
11 | from typing import TYPE_CHECKING
12 |
13 | from ax.core.data import Data
14 | from ax.utils.common.base import SortableBase
15 |
16 |
17 | if TYPE_CHECKING:
18 | # import as module to make sphinx-autodoc-typehints happy
19 | from ax import core # noqa F401
20 |
21 |
class AuxiliaryExperiment(SortableBase):
    """Class for defining an auxiliary experiment."""

    def __init__(
        self,
        experiment: core.experiment.Experiment,
        data: Data | None = None,
    ) -> None:
        """
        Lightweight container of an experiment, and its data,
        that will be used as auxiliary information for another experiment.

        Args:
            experiment: The auxiliary experiment.
            data: Data to associate with ``experiment``. When ``None``, the data
                is obtained from ``experiment.lookup_data()``.
        """
        self.experiment = experiment
        # Explicit ``None`` check instead of ``data or ...``: a caller-supplied
        # ``Data`` object must be kept even if it were ever to evaluate as falsy
        # (e.g. should ``Data`` define ``__len__``/``__bool__``).
        self.data: Data = data if data is not None else experiment.lookup_data()

    def _unique_id(self) -> str:
        # While there can be multiple `AuxiliarySource`-s made from the same
        # experiment (and thus sharing the experiment name), the uniqueness
        # here is only needed w.r.t. parent object ("main experiment", for which
        # this will be an auxiliary source for).
        return self.experiment.name
43 |
44 |
@unique
class AuxiliaryExperimentPurpose(Enum):
    """Enum of purposes an auxiliary experiment can serve.

    Declares no members here; presumably concrete purposes are added by
    subclasses elsewhere (TODO confirm).
    """
48 |
--------------------------------------------------------------------------------
/.github/workflows/reusable_tutorials.yml:
--------------------------------------------------------------------------------
1 | name: Reusable Tutorials Workflow
2 |
3 | on:
4 | workflow_dispatch:
5 | inputs:
6 | smoke_test:
7 | required: true
8 | type: boolean
9 | pinned_botorch:
10 | required: true
11 | type: boolean
12 | workflow_call:
13 | inputs:
14 | smoke_test:
15 | required: true
16 | type: boolean
17 | pinned_botorch:
18 | required: true
19 | type: boolean
20 |
21 | jobs:
22 |
23 | build-tutorials:
24 | name: Tutorials
25 | runs-on: ubuntu-latest
26 | steps:
27 | - uses: actions/checkout@v4
28 | - name: Set up Python
29 | uses: actions/setup-python@v5
30 | with:
31 | python-version: "3.10"
32 |
33 | - if: ${{ inputs.pinned_botorch }}
34 | name: Install dependencies with pinned BoTorch
35 | run: |
36 | pip install -e ".[tutorial]"
37 |
38 | - if: ${{ !inputs.pinned_botorch }}
39 | name: Install dependencies with latest BoTorch
40 | env:
41 | ALLOW_BOTORCH_LATEST: true
42 | ALLOW_LATEST_GPYTORCH_LINOP: true
43 | run: |
44 | pip install git+https://github.com/cornellius-gp/gpytorch.git
45 | pip install git+https://github.com/pytorch/botorch.git
46 | pip install -e ".[tutorial]"
47 |
48 | - if: ${{ inputs.smoke_test }}
49 | name: Build tutorials with smoke test
50 | run: |
51 | python scripts/make_tutorials.py -w $(pwd) -e -s
52 | - if: ${{ !inputs.smoke_test }}
53 | name: Build tutorials without smoke test
54 | run: |
55 | python scripts/make_tutorials.py -w $(pwd) -e
56 |
--------------------------------------------------------------------------------
/scripts/patch_site_config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import re
9 |
10 |
11 | def patch_config(
12 | config_file: str, base_url: str = None, disable_algolia: bool = True
13 | ) -> None:
14 | config = open(config_file).read()
15 |
16 | if base_url is not None:
17 | config = re.sub("baseUrl = '/';", f"baseUrl = '{base_url}';", config)
18 | if disable_algolia is True:
19 | config = re.sub(
20 | "const includeAlgolia = true;", "const includeAlgolia = false;", config
21 | )
22 |
23 | with open(config_file, "w") as outfile:
24 | outfile.write(config)
25 |
26 |
if __name__ == "__main__":
    # Command-line entry point: parse arguments and patch the config in place.
    parser = argparse.ArgumentParser(
        description="Patch Docusaurus siteConfig.js file when building site."
    )
    parser.add_argument(
        "-f",
        "--config_file",
        metavar="path",
        required=True,
        help="Path to configuration file.",
    )
    parser.add_argument(
        "-b",
        "--base_url",
        type=str,
        required=False,
        help="Value for baseUrl.",
        default=None,
    )
    parser.add_argument(
        "--disable_algolia",
        required=False,
        action="store_true",
        help="Disable algolia.",
    )
    args = parser.parse_args()
    patch_config(args.config_file, args.base_url, args.disable_algolia)
54 |
--------------------------------------------------------------------------------
/ax/utils/common/random.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import random
10 | from collections.abc import Generator
11 | from contextlib import contextmanager
12 |
13 | import numpy as np
14 | import torch
15 |
16 |
def set_rng_seed(seed: int) -> None:
    """Seed Python's ``random`` module, NumPy's global RNG and PyTorch.

    Args:
        seed: The value handed to each library's seeding function.
    """
    # Seed in a fixed order: stdlib, then NumPy, then torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
27 |
28 |
29 | @contextmanager
30 | def with_rng_seed(seed: int | None) -> Generator[None, None, None]:
31 | """Context manager that sets the random number generator seeds
32 | to a given value and restores the previous state on exit.
33 |
34 | If the seed is None, the context manager does nothing. This makes
35 | it possible to use the context manager without having to change
36 | the code based on whether the seed is specified.
37 |
38 | Args:
39 | seed: The random number generator seed.
40 | """
41 | if seed is None:
42 | yield
43 | else:
44 | old_state_native = random.getstate()
45 | old_state_numpy = np.random.get_state()
46 | try:
47 | with torch.random.fork_rng():
48 | set_rng_seed(seed)
49 | yield
50 | finally:
51 | random.setstate(old_state_native)
52 | np.random.set_state(old_state_numpy)
53 |
--------------------------------------------------------------------------------
/ax/storage/tests/test_registry_bundle.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | from ax.benchmark.benchmark_metric import BenchmarkMetric
9 | from ax.metrics.branin import BraninMetric
10 | from ax.runners.synthetic import SyntheticRunner
11 | from ax.storage.registry_bundle import RegistryBundle
12 | from ax.utils.common.testutils import TestCase
13 |
14 |
class RegistryBundleTest(TestCase):
    def test_from_registry_bundles(self) -> None:
        branin_bundle = RegistryBundle(
            metric_clss={BraninMetric: None},
            runner_clss={SyntheticRunner: None},
            json_encoder_registry={},
            json_class_encoder_registry={},
            json_decoder_registry={},
            json_class_decoder_registry={},
        )
        benchmark_bundle = RegistryBundle(
            metric_clss={BenchmarkMetric: None},
            runner_clss={SyntheticRunner: None},
            json_encoder_registry={},
            json_class_encoder_registry={},
            json_decoder_registry={},
            json_class_decoder_registry={},
        )

        # The Branin bundle registers BraninMetric but not BenchmarkMetric.
        self.assertIn(BraninMetric, branin_bundle.encoder_registry)
        self.assertNotIn(BenchmarkMetric, branin_bundle.encoder_registry)

        # The merged bundle must know every metric/runner class of both inputs.
        merged = RegistryBundle.from_registry_bundles(branin_bundle, benchmark_bundle)
        for expected_cls in (BraninMetric, SyntheticRunner, BenchmarkMetric):
            self.assertIn(expected_cls, merged.encoder_registry)
43 |
--------------------------------------------------------------------------------
/website/static/js/plotUtils.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | // helper functions used across multiple plots
9 | function rgb(rgb_array) {
10 | return 'rgb(' + rgb_array.join() + ')';
11 | }
12 |
13 | function copy_and_reverse(arr) {
14 | const copy = arr.slice();
15 | copy.reverse();
16 | return copy;
17 | }
18 |
19 | function axis_range(grid, is_log) {
20 | return is_log ?
21 | [Math.log10(Math.min(...grid)), Math.log10(Math.max(...grid))]:
22 | [Math.min(...grid), Math.max(...grid)];
23 | }
24 |
/**
 * Optionally express means and standard errors relative to the status quo arm.
 *
 * @param f        Array of means.
 * @param sd       Array of standard errors, aligned with `f`.
 * @param rel      Only the exact value `true` enables relativization.
 * @param arm_data Object read as
 *                 arm_data.in_sample[arm_data.status_quo_name].{y,se}[metric].
 * @param metric   Metric key used to look up the status-quo mean / SE.
 * @returns [f, sd] unchanged (same arrays) when `rel !== true`; otherwise new
 *          arrays holding the `relativize` results multiplied by 100 (percent).
 */
function relativize_data(f, sd, rel, arm_data, metric) {
  if (rel !== true) {
    return [f, sd];
  }

  const sq_name = arm_data['status_quo_name'];
  const f_sq = arm_data['in_sample'][sq_name]['y'][metric];
  const sd_sq = arm_data['in_sample'][sq_name]['se'][metric];

  const f_final = [];
  const sd_final = [];
  for (let i = 0; i < f.length; i++) {
    // Block-scoped: the original assigned an undeclared `res`, leaking a global.
    const res = relativize(f[i], sd[i], f_sq, sd_sq);
    f_final.push(100 * res[0]);
    sd_final.push(100 * res[1]);
  }
  return [f_final, sd_final];
}
47 |
48 | function relativize(m_t, sem_t, m_c, sem_c) {
49 | r_hat = (
50 | (m_t - m_c) / Math.abs(m_c) -
51 | Math.pow(sem_c, 2) * m_t / Math.pow(Math.abs(m_c), 3)
52 | );
53 | variance = (
54 | (Math.pow(sem_t, 2) + Math.pow((m_t / m_c * sem_c), 2)) /
55 | Math.pow(m_c, 2)
56 | )
57 | return [r_hat, Math.sqrt(variance)];
58 | }
59 |
--------------------------------------------------------------------------------
/ax/utils/common/func_enum.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 | from enum import Enum, unique
8 | from importlib import import_module
9 | from typing import Any, Callable
10 |
11 | from ax.exceptions.core import UnsupportedError
12 |
13 |
@unique
class FuncEnum(Enum):
    """A base class for all enums with the following structure: string values that
    map to names of functions, which reside in the same module as the enum."""

    # pyre-ignore[3]: Input constructors will be used to make different inputs,
    # so we need to allow `Any` return type here.
    def __call__(self, **kwargs: Any) -> Any:
        """Call the function named by this member's value with ``kwargs``.

        E.g. ``MyFunctions.F(**kwargs)`` calls the function, defined in the
        enum's module, whose name equals the value of ``MyFunctions.F``.
        """
        return self._get_function_for_value()(**kwargs)

    # pyre-ignore[31]: Expression `typing.Callable[([...], typing.Any)]`
    # is not a valid type.
    def _get_function_for_value(self) -> Callable[[...], Any]:
        """Return the attribute of the enum's module named by this member's value.

        Raises:
            UnsupportedError: If the module has no attribute with that name.
        """
        module = import_module(self.__module__)
        # Keep the ``try`` to the lookup only, so unrelated errors are not
        # misreported as a missing function.
        try:
            return getattr(module, self.value)
        except AttributeError as err:
            raise UnsupportedError(
                f"{self.value} is not defined as a method in "
                f"`{self.__module__}`. Please add the method "
                "to the file."
            ) from err
40 |
--------------------------------------------------------------------------------
/ax/analysis/healthcheck/tests/test_should_generate_candidates.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-unsafe
7 |
8 | from random import randint
9 |
10 | from ax.analysis.analysis import AnalysisCardLevel
11 | from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckStatus
12 | from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates
13 | from ax.utils.common.testutils import TestCase
14 |
15 |
class TestShouldGenerateCandidates(TestCase):
    def _check_card(
        self,
        should_generate: bool,
        reason: str,
        expected_status: HealthcheckStatus,
    ) -> None:
        # Build the analysis for a random trial index and verify every card field.
        trial_index = randint(0, 10)
        card = ShouldGenerateCandidates(
            should_generate=should_generate,
            reason=reason,
            trial_index=trial_index,
        ).compute()
        self.assertEqual(card.get_status(), expected_status)
        self.assertEqual(card.level, AnalysisCardLevel.CRITICAL)
        self.assertEqual(card.subtitle, reason)
        self.assertEqual(card.attributes["trial_index"], trial_index)

    def test_should(self) -> None:
        self._check_card(
            should_generate=True,
            reason="Something reassuring",
            expected_status=HealthcheckStatus.PASS,
        )

    def test_should_not(self) -> None:
        self._check_card(
            should_generate=False,
            reason="Something concerning",
            expected_status=HealthcheckStatus.WARNING,
        )
40 |
--------------------------------------------------------------------------------
/ax/plot/tests/test_slices.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import plotly.graph_objects as go
10 | from ax.modelbridge.registry import Models
11 | from ax.plot.base import AxPlotConfig
12 | from ax.plot.slice import (
13 | interact_slice,
14 | interact_slice_plotly,
15 | plot_slice,
16 | plot_slice_plotly,
17 | )
18 | from ax.utils.common.testutils import TestCase
19 | from ax.utils.testing.core_stubs import get_branin_experiment
20 | from ax.utils.testing.mock import mock_botorch_optimize
21 |
22 |
class SlicesTest(TestCase):
    @mock_botorch_optimize
    def test_Slices(self) -> None:
        experiment = get_branin_experiment(with_batch=True)
        experiment.trials[0].run()
        model = Models.BOTORCH_MODULAR(
            # Model bridge kwargs
            experiment=experiment,
            data=experiment.fetch_data(),
        )
        # pyre-fixme[16]: `ModelBridge` has no attribute `parameters`.
        param_name = model.parameters[0]
        metric_name = list(model.metric_names)[0]

        # Every slice-plot entry point must build its expected output type.
        self.assertIsInstance(
            plot_slice_plotly(model, param_name, metric_name), go.Figure
        )
        self.assertIsInstance(interact_slice_plotly(model), go.Figure)
        self.assertIsInstance(
            plot_slice(model, param_name, metric_name), AxPlotConfig
        )
        self.assertIsInstance(interact_slice(model), AxPlotConfig)
47 |
--------------------------------------------------------------------------------
/ax/modelbridge/transforms/search_space_to_float.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | from ax.core.arm import Arm
11 | from ax.core.observation import ObservationFeatures
12 | from ax.core.parameter import ParameterType, RangeParameter
13 | from ax.core.search_space import SearchSpace
14 | from ax.modelbridge.transforms.base import Transform
15 |
16 |
class SearchSpaceToFloat(Transform):
    """Replaces the search space with a single range parameter, whose values
    are derived from the signature of the arms.

    NOTE: This will have collisions and so should not be used whenever unique
    observation features need to be preserved. Its purpose is to enable
    forward transforms for any search space regardless of parameterization.

    Transform is done in-place.
    """

    # Name of the single float parameter that replaces all original parameters.
    _HASH_PARAM_NAME: str = "HASH_PARAM"
    # Arm signatures are reduced modulo this value. The range parameter's upper
    # bound is derived from it, so the two values cannot drift apart.
    _HASH_MODULUS: int = 1_000_000_000_000

    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Return a new search space containing only the hash parameter.

        The incoming ``search_space`` is not inspected.
        """
        parameter = RangeParameter(
            name=self._HASH_PARAM_NAME,
            parameter_type=ParameterType.FLOAT,
            lower=0.0,
            upper=float(self._HASH_MODULUS),
        )
        return SearchSpace(parameters=[parameter])

    def transform_observation_features(
        self, observation_features: list[ObservationFeatures]
    ) -> list[ObservationFeatures]:
        """Replace each observation's parameters with a single hashed float.

        The given ``ObservationFeatures`` objects are mutated and then returned.
        """
        for obsf in observation_features:
            sig = Arm(parameters=obsf.parameters).signature
            # The signature is parsed as a base-16 integer, then folded into
            # [0, _HASH_MODULUS).
            val = float(int(sig, 16) % self._HASH_MODULUS)
            obsf.parameters = {self._HASH_PARAM_NAME: val}
        return observation_features
45 |
--------------------------------------------------------------------------------
/ax/utils/notebook/plotting.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from logging import Logger
10 |
11 | from ax.plot.base import AxPlotConfig, AxPlotTypes
12 | from ax.plot.render import _js_requires, _wrap_js, plot_config_to_html
13 | from ax.utils.common.logger import get_logger
14 | from IPython.display import display
15 | from plotly.offline import init_notebook_mode, iplot
16 |
17 | logger: Logger = get_logger(__name__)
18 |
19 |
def init_notebook_plotting(offline: bool = False) -> None:
    """Initialize plotting in notebooks, either in online or offline mode.

    Displays a raw ``text/html`` bundle built from
    ``_wrap_js(_js_requires(offline=offline))``, logs two informational
    messages for the user, and finally calls plotly's ``init_notebook_mode()``.

    Args:
        offline: Forwarded to ``_js_requires``; presumably selects whether the
            plotting JS is embedded for offline use or loaded remotely
            (TODO confirm against ``ax.plot.render``).
    """
    # Inject the JS setup into the current cell's output as raw HTML.
    display_bundle = {"text/html": _wrap_js(_js_requires(offline=offline))}
    display(display_bundle, raw=True)
    logger.info("Injecting Plotly library into cell. Do not overwrite or delete cell.")
    logger.info(
        """
        Please see
        (https://ax.dev/tutorials/visualizations.html#Fix-for-plots-that-are-not-rendering)
        if visualizations are not rendering.
        """.strip()
    )
    init_notebook_mode()
33 |
34 |
def render(plot_config: AxPlotConfig, inject_helpers: bool = False) -> None:
    """Render a plot config in the current notebook.

    * ``GENERIC`` configs are passed to plotly's ``iplot``.
    * ``HTML`` configs must already be a display bundle with a ``"text/html"``
      entry and are displayed as-is.
    * Any other plot type is converted with ``plot_config_to_html``.

    Args:
        plot_config: The plot configuration to render.
        inject_helpers: Forwarded to ``plot_config_to_html`` for non-GENERIC,
            non-HTML plot types.

    Raises:
        ValueError: If an ``HTML`` plot config has no ``"text/html"`` entry.
    """
    if plot_config.plot_type == AxPlotTypes.GENERIC:
        iplot(plot_config.data)
    elif plot_config.plot_type == AxPlotTypes.HTML:
        # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
        if "text/html" not in plot_config.data:
            raise ValueError(
                "HTML plot configs must provide a 'text/html' entry in `data`."
            )
        display(plot_config.data, raw=True)
    else:
        display_bundle = {
            "text/html": plot_config_to_html(plot_config, inject_helpers=inject_helpers)
        }
        display(display_bundle, raw=True)
47 |
--------------------------------------------------------------------------------
/ax/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.core import (
10 | Arm,
11 | BatchTrial,
12 | ChoiceParameter,
13 | ComparisonOp,
14 | Data,
15 | Experiment,
16 | FixedParameter,
17 | GeneratorRun,
18 | Metric,
19 | MultiObjective,
20 | MultiObjectiveOptimizationConfig,
21 | Objective,
22 | ObjectiveThreshold,
23 | OptimizationConfig,
24 | OrderConstraint,
25 | OutcomeConstraint,
26 | Parameter,
27 | ParameterConstraint,
28 | ParameterType,
29 | RangeParameter,
30 | Runner,
31 | SearchSpace,
32 | SumConstraint,
33 | Trial,
34 | )
35 | from ax.modelbridge import Models
36 | from ax.service import OptimizationLoop, optimize
37 | from ax.storage import json_load, json_save
38 |
try:
    # ``ax/version.py`` is presumably generated at build/install time (TODO
    # confirm); the previous ``try: pass`` never defined ``__version__`` at all
    # on the success path, so ``ax.__version__`` raised AttributeError.
    from ax.version import version as __version__
except Exception:  # pragma: no cover
    __version__ = "Unknown"
43 |
44 |
# Public names exposed by ``from ax import *``: the core classes, ``Models``,
# the service entry points ``OptimizationLoop`` / ``optimize``, and the JSON
# storage helpers, all imported at the top of this module.
__all__ = [
    "Arm",
    "BatchTrial",
    "ChoiceParameter",
    "ComparisonOp",
    "Data",
    "Experiment",
    "FixedParameter",
    "GeneratorRun",
    "Metric",
    "Models",
    "MultiObjective",
    "MultiObjectiveOptimizationConfig",
    "Objective",
    "ObjectiveThreshold",
    "OptimizationConfig",
    "OptimizationLoop",
    "OrderConstraint",
    "OutcomeConstraint",
    "Parameter",
    "ParameterConstraint",
    "ParameterType",
    "RangeParameter",
    "Runner",
    "SearchSpace",
    "SumConstraint",
    "Trial",
    "optimize",
    "json_save",
    "json_load",
]
76 |
--------------------------------------------------------------------------------
/ax/modelbridge/transforms/deprecated_transform_mixin.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from logging import Logger
10 | from typing import Any
11 |
12 | from ax.utils.common.logger import get_logger
13 |
14 | logger: Logger = get_logger(__name__)
15 |
16 |
class DeprecatedTransformMixin:
    """
    Mixin class for deprecated transforms.

    This class is used to log warnings when a deprecated transform is used,
    and will construct the new transform that should be used instead.

    The deprecated transform should inherit as follows:

    class DeprecatedTransform(DeprecatedTransformMixin, NewTransform):
        ...

    :meta private:
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Log a warning that the transform is deprecated, and construct the
        new transform.
        """
        replacement = self._replacement_transform_class(type(self))
        warning_msg = self.warn_deprecated_message(
            self.__class__.__name__, replacement.__name__
        )
        logger.warning(warning_msg)

        super().__init__(*args, **kwargs)

    @staticmethod
    def _replacement_transform_class(deprecated_cls: type) -> type:
        """Return the class that follows this mixin in ``deprecated_cls``'s MRO.

        For ``class Old(DeprecatedTransformMixin, New)`` this is ``New`` -- the
        same class as ``Old.__bases__[1]`` -- but unlike ``__bases__[1]`` it
        also works when a further subclass of ``Old`` is instantiated (whose
        ``__bases__`` has only one entry).
        """
        mro = deprecated_cls.__mro__
        return mro[mro.index(DeprecatedTransformMixin) + 1]

    @staticmethod
    def warn_deprecated_message(
        deprecated_transform_name: str, new_transform_name: str
    ) -> str:
        """
        Constructs the warning message.
        """
        return (
            f"`{deprecated_transform_name}` transform has been deprecated "
            "and will be removed in a future release. "
            f"Using `{new_transform_name}` instead."
        )
56 |
--------------------------------------------------------------------------------
/ax/utils/testing/torch_stubs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from __future__ import annotations
10 |
11 | from typing import Any
12 |
13 | import torch
14 |
15 |
16 | def get_torch_test_data(
17 | dtype: torch.dtype = torch.float,
18 | cuda: bool = False,
19 | constant_noise: bool = True,
20 | task_features: list[int] | None = None,
21 | offset: float = 0.0,
22 | ) -> tuple[
23 | list[torch.Tensor],
24 | list[torch.Tensor],
25 | list[torch.Tensor],
26 | list[tuple[float, float]],
27 | list[int],
28 | list[str],
29 | list[str],
30 | ]:
31 | tkwargs: dict[str, Any] = {
32 | "device": torch.device("cuda" if cuda else "cpu"),
33 | "dtype": dtype,
34 | }
35 | Xs = [
36 | torch.tensor(
37 | [
38 | [1.0 + offset, 2.0 + offset, 3.0 + offset],
39 | [2.0 + offset, 3.0 + offset, 4.0 + offset],
40 | ],
41 | **tkwargs,
42 | )
43 | ]
44 | Ys = [torch.tensor([[3.0 + offset], [4.0 + offset]], **tkwargs)]
45 | if constant_noise:
46 | Yvar = torch.ones(2, 1, **tkwargs)
47 | else:
48 | Yvar = torch.tensor([[0.0 + offset], [2.0 + offset]], **tkwargs)
49 | Yvars = [Yvar]
50 |
51 | bounds = [
52 | (0.0 + offset, 1.0 + offset),
53 | (1.0 + offset, 4.0 + offset),
54 | (2.0 + offset, 5.0 + offset),
55 | ]
56 | feature_names = ["x1", "x2", "x3"]
57 | task_features = [] if task_features is None else task_features
58 | metric_names = ["y"]
59 | return Xs, Ys, Yvars, bounds, task_features, feature_names, metric_names
60 |
--------------------------------------------------------------------------------
/ax/storage/sqa_store/reduced_state.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | from ax.storage.sqa_store.sqa_classes import SQAGeneratorRun
11 | from sqlalchemy.orm import defaultload, lazyload, strategy_options
12 | from sqlalchemy.orm.attributes import InstrumentedAttribute
13 |
14 |
# Generator-run columns holding model-related state (model kwargs, bridge
# kwargs, model state after gen, gen metadata). These are the columns that
# ``get_query_options_to_defer_large_model_cols`` defers.
GR_LARGE_MODEL_ATTRS: list[InstrumentedAttribute] = [  # pyre-ignore[9]
    SQAGeneratorRun.model_kwargs,
    SQAGeneratorRun.bridge_kwargs,
    SQAGeneratorRun.model_state_after_gen,
    SQAGeneratorRun.gen_metadata,
]


# Names of generator-run attributes that
# ``get_query_options_to_defer_immutable_duplicates`` lazy-loads, addressed as
# ``generator_runs.<name>``.
GR_PARAMS_METRICS_COLS: list[str] = [
    "parameters",
    "parameter_constraints",
    "metrics",
]
28 |
29 |
def get_query_options_to_defer_immutable_duplicates() -> list[strategy_options.Load]:
    """Build ``lazyload`` options for generator-run attributes that are repeated
    on every trial (search-space attributes and metrics).

    Experiments whose search space and optimization config are immutable do not
    need these attributes loaded up front.
    """
    return [lazyload(f"generator_runs.{column}") for column in GR_PARAMS_METRICS_COLS]
38 |
39 |
def get_query_options_to_defer_large_model_cols() -> list[strategy_options.Load]:
    """Build options that defer the model-state columns of generator runs.

    Those columns can be large and are not needed on every generator run when
    an experiment and generation strategy are loaded in reduced state.
    """
    options: list[strategy_options.Load] = []
    for attr in GR_LARGE_MODEL_ATTRS:
        # Keep the default loader for ``generator_runs`` but defer this column.
        options.append(defaultload("generator_runs").defer(attr.key))
    return options
48 |
--------------------------------------------------------------------------------
/ax/utils/common/timeutils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from collections.abc import Generator
10 | from datetime import datetime, timedelta
11 | from time import time
12 |
13 | import pandas as pd
14 |
15 |
# ``strftime`` / ``strptime`` pattern for DS (date-string) values, e.g. "2024-03-05".
DS_FRMT: str = "%Y-%m-%d"


def to_ds(ts: datetime) -> str:
    """Format ``ts`` as a ``YYYY-MM-DD`` DS string (time of day is dropped)."""
    ds_string = datetime.strftime(ts, DS_FRMT)
    return ds_string


def to_ts(ds: str) -> datetime:
    """Parse a ``YYYY-MM-DD`` DS string into a naive ``datetime`` at midnight."""
    parsed = datetime.strptime(ds, DS_FRMT)
    return parsed
27 |
28 |
29 | def _ts_to_pandas(ts: int) -> pd.Timestamp:
30 | """Convert int timestamp into pandas timestamp."""
31 | return pd.Timestamp(datetime.fromtimestamp(ts))
32 |
33 |
34 | def _pandas_ts_to_int(ts: pd.Timestamp) -> int:
35 | """Convert int timestamp into pandas timestamp."""
36 | # pyre-fixme[7]: Expected `int` but got `float`.
37 | return ts.to_pydatetime().timestamp()
38 |
39 |
def current_timestamp_in_millis() -> int:
    """Return the current wall-clock time as integer milliseconds since epoch."""
    # ``round`` on a float with no ``ndigits`` already yields an ``int``.
    return round(time() * 1000)
43 |
44 |
def timestamps_in_range(
    start: datetime, end: datetime, delta: timedelta
) -> Generator[datetime, None, None]:
    """Generator of timestamps in range [start, end], at intervals
    delta.

    Yields ``start``, ``start + delta``, ``start + 2 * delta``, ... while the
    value is ``<= end``; yields nothing when ``start > end``.

    Raises:
        ValueError: If ``delta`` is not positive while ``start <= end``, in
            which case the loop would otherwise never terminate.
    """
    # Only reject the case that used to hang, so ``start > end`` still simply
    # yields nothing regardless of ``delta``.
    if start <= end and delta <= timedelta(0):
        raise ValueError(f"`delta` must be a positive timedelta, got {delta!r}.")
    curr = start
    while curr <= end:
        yield curr
        curr += delta
55 |
56 |
def unixtime_to_pandas_ts(ts: float) -> pd.Timestamp:
    """Interpret ``ts`` as (possibly fractional) seconds since the Unix epoch and
    return the corresponding tz-naive pandas ``Timestamp`` in UTC wall time."""
    return pd.to_datetime(arg=ts, unit="s")
60 |
--------------------------------------------------------------------------------
/ax/utils/testing/metrics/backend_simulator_map.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from typing import Any
10 |
11 | from ax.core.base_trial import BaseTrial
12 | from ax.core.map_metric import MapMetricFetchResult
13 | from ax.metrics.noisy_function_map import NoisyFunctionMapMetric
14 |
15 |
class BackendSimulatorTimestampMapMetric(NoisyFunctionMapMetric):
    """A metric that interfaces with an underlying ``BackendSimulator`` and
    returns timestamp map data."""

    def fetch_trial_data(
        self, trial: BaseTrial, noisy: bool = True, **kwargs: Any
    ) -> MapMetricFetchResult:
        """Fetch data for one trial, keyed by a ``"timestamp"`` map key.

        The simulated trial's start time and its completion time (or the
        simulator's current time when no completion time is set) are turned
        into timestamps by ``convert_to_timestamps``. These are forwarded to
        ``NoisyFunctionMapMetric.fetch_trial_data`` as
        ``map_keys=["timestamp"]`` and ``timestamp=<timestamps>``, together
        with any extra ``kwargs``.
        """
        simulator = trial.experiment.runner.simulator  # pyre-ignore[16]
        sim_trial = simulator.get_sim_trial_by_index(trial.index)

        end_time = sim_trial.sim_completed_time
        if end_time is None:
            end_time = simulator.time

        timestamps = self.convert_to_timestamps(
            start_time=sim_trial.sim_start_time, end_time=end_time
        )
        # Deliberately call the base implementation directly (not ``super()``).
        return NoisyFunctionMapMetric.fetch_trial_data(
            self,
            trial=trial,
            noisy=noisy,
            **kwargs,
            map_keys=["timestamp"],
            timestamp=timestamps,
        )

    def convert_to_timestamps(self, start_time: float, end_time: float) -> list[float]:
        """Given a starting and current time, return the intermediate timestamps
        at which observations exist.

        Not implemented on this class (always raises ``NotImplementedError``);
        ``fetch_trial_data`` therefore needs an overriding implementation.
        """
        raise NotImplementedError
43 |
--------------------------------------------------------------------------------
/ax/service/tests/test_early_stopping.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 |
9 | from ax.service.utils import early_stopping as early_stopping_utils
10 | from ax.utils.common.testutils import TestCase
11 | from ax.utils.testing.core_stubs import (
12 | DummyEarlyStoppingStrategy,
13 | get_branin_experiment,
14 | )
15 |
16 |
class TestEarlyStoppingUtils(TestCase):
    """Testing the early stopping utilities functionality that is not tested in
    main `AxClient` testing suite (`TestServiceAPI`)."""

    def setUp(self) -> None:
        super().setUp()
        self.branin_experiment = get_branin_experiment()

    def test_should_stop_trials_early(self) -> None:
        expected: dict[int, str | None] = {
            1: "Stopped due to testing.",
            3: "Stopped due to testing.",
        }
        actual = early_stopping_utils.should_stop_trials_early(
            early_stopping_strategy=DummyEarlyStoppingStrategy(expected),
            # Pass a set, matching the parameter's declared ``set[int]`` type.
            trial_indices={1, 2, 3},
            experiment=self.branin_experiment,
        )
        self.assertEqual(actual, expected)

    def test_should_stop_trials_early_no_strategy(self) -> None:
        actual = early_stopping_utils.should_stop_trials_early(
            early_stopping_strategy=None,
            trial_indices={1, 2, 3},
            experiment=self.branin_experiment,
        )
        # Without a strategy, no trial should be flagged for stopping.
        self.assertEqual(actual, {})
47 |
--------------------------------------------------------------------------------
/ax/utils/common/docutils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | """Support functions for Sphinx et al."""
10 |
11 | from collections.abc import Callable
12 | from typing import Any, TypeVar
13 |
14 |
_T = TypeVar("_T")


# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
# pyre-ignore[34]: T77127616
def copy_doc(src: Callable[..., Any]) -> Callable[[_T], _T]:
    """A decorator that copies the docstring of another object.

    Since ``sphinx`` actually loads the python modules to grab the docstrings
    this works with both ``sphinx`` and the ``help`` function.

    .. code:: python

        class Cat(Mammal):

            @property
            @copy_doc(Mammal.is_feline)
            def is_feline(self) -> bool:
                ...

    Args:
        src: The object whose ``__doc__`` is copied. It must have a docstring.

    Returns:
        A decorator that assigns ``src.__doc__`` to the decorated object and
        returns that same object.

    Raises:
        ValueError: Immediately, if ``src`` has no docstring; or when the
            returned decorator is applied, if the decorated object already has
            a docstring (copying would silently discard it).
    """
    # It would be tempting to try to get the doc through the class the method
    # is bound to (via __self__) but decorators are called before __self__ is
    # assigned.
    # One other solution would be to use a decorator on classes that would fill
    # all the missing docstrings but we want to be able to detect syntactically
    # when docstrings are copied to keep things nice and simple

    if src.__doc__ is None:
        # Not every object that can carry a docstring defines ``__qualname__``
        # (``property`` objects, for instance, do not), so fall back to ``repr``
        # rather than letting an AttributeError mask this error.
        raise ValueError(
            f"{getattr(src, '__qualname__', repr(src))} has no docstring to copy"
        )

    def decorator(dst: _T) -> _T:
        if dst.__doc__ is not None:
            raise ValueError(
                f"{getattr(dst, '__qualname__', repr(dst))} already has a docstring"
            )
        dst.__doc__ = src.__doc__
        return dst

    return decorator
53 |
--------------------------------------------------------------------------------
/ax/exceptions/data_provider.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from collections.abc import Iterable
10 | from typing import Any
11 |
12 |
class DataProviderError(Exception):
    """Base Exception for Ax DataProviders.

    The type of the data provider must be included.
    The raw error is stored in the data_provider_error section,
    and an Ax-friendly message is stored as the actual error message.

    Attributes:
        message: Ax-friendly description of the failure.
        data_provider: Name of the data provider that raised the error.
        data_provider_error: The raw error reported by the data provider.
    """

    def __init__(
        self,
        message: str,
        data_provider: str,
        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
        data_provider_error: Any,
    ) -> None:
        # Forward every constructor argument to ``Exception`` so ``self.args``
        # is populated. Exceptions are unpickled by calling ``cls(*self.args)``;
        # with empty ``args`` that call fails with a TypeError about missing
        # arguments (e.g. when the error crosses a process boundary).
        super().__init__(message, data_provider, data_provider_error)
        self.message = message
        self.data_provider = data_provider
        self.data_provider_error = data_provider_error

    def __str__(self) -> str:
        return (
            "{message}. \n Error thrown by: {dp} data provider \n"
            + "Native {dp} data provider error: {dp_error}"
        ).format(
            dp=self.data_provider,
            message=self.message,
            dp_error=self.data_provider_error,
        )
41 |
42 |
class MissingDataError(Exception):
    """Raised when no data can be found for one or more trials.

    Attributes:
        missing_trial_indexes: Indexes of the trials without data.
        message: Human-readable description used as the string form.
    """

    def __init__(self, missing_trial_indexes: Iterable[int]) -> None:
        # Materialize once: the argument may be a one-shot iterator, and the
        # indexes are needed both for the message and for ``args``.
        indexes = list(missing_trial_indexes)
        # Populate ``args`` so unpickling (which calls ``cls(*self.args)``)
        # can rebuild the exception instead of failing with a TypeError.
        super().__init__(indexes)
        self.missing_trial_indexes: list[int] = indexes
        missing_trial_str = ", ".join(str(index) for index in indexes)
        self.message: str = (
            f"Unable to find data for the following trials: {missing_trial_str} "
            "consider updating the data fetching kwargs or manually fetching "
            "data via `refetch_data()`"
        )

    def __str__(self) -> str:
        return self.message
54 |
--------------------------------------------------------------------------------
/website/static/img/ax.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ax/models/tests/test_discrete.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import numpy as np
10 | from ax.models.discrete_base import DiscreteModel
11 | from ax.utils.common.testutils import TestCase
12 |
13 |
class DiscreteModelTest(TestCase):
    """Checks the default behavior of the ``DiscreteModel`` base class."""

    def test_discrete_model_get_state(self) -> None:
        # A freshly constructed model reports an empty state.
        self.assertEqual(DiscreteModel()._get_state(), {})

    def test_discrete_model_feature_importances(self) -> None:
        model = DiscreteModel()
        with self.assertRaises(NotImplementedError):
            model.feature_importances()

    def test_DiscreteModelFit(self) -> None:
        # Only checks that ``fit`` accepts these inputs without raising.
        model = DiscreteModel()
        model.fit(
            Xs=[[[0]]],
            Ys=[[0]],
            Yvars=[[1]],
            parameter_values=[[0, 1]],
            outcome_names=[],
        )

    def test_discreteModelPredict(self) -> None:
        model = DiscreteModel()
        with self.assertRaises(NotImplementedError):
            model.predict([[0]])

    def test_discreteModelGen(self) -> None:
        model = DiscreteModel()
        weights = np.array([1])
        with self.assertRaises(NotImplementedError):
            model.gen(n=1, parameter_values=[[0, 1]], objective_weights=weights)

    def test_discreteModelCrossValidate(self) -> None:
        model = DiscreteModel()
        with self.assertRaises(NotImplementedError):
            model.cross_validate(
                Xs_train=[[[0]]], Ys_train=[[1]], Yvars_train=[[1]], X_test=[[1]]
            )
52 |
--------------------------------------------------------------------------------
/ax/models/tests/test_randomforest.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import torch
10 | from ax.core.search_space import SearchSpaceDigest
11 | from ax.models.torch.randomforest import RandomForest
12 | from ax.utils.common.testutils import TestCase
13 | from botorch.utils.datasets import SupervisedDataset
14 |
15 |
class RandomForestTest(TestCase):
    """Exercises fitting, prediction and cross-validation of ``RandomForest``."""

    def test_RFModel(self) -> None:
        feature_names = ["x1", "x2"]
        # Two single-outcome datasets (outcomes "y0" and "y1") over the same
        # two features, with 10 random observations each.
        datasets = [
            SupervisedDataset(
                X=torch.rand(10, 2),
                Y=torch.rand(10, 1),
                Yvar=torch.rand(10, 1),
                feature_names=feature_names,
                outcome_names=[f"y{idx}"],
            )
            for idx in range(2)
        ]

        forest = RandomForest(num_trees=5)
        digest = SearchSpaceDigest(
            feature_names=["x1", "x2"],
            # pyre-fixme[6]: For 2nd param expected `List[Tuple[Union[float,
            #  int], Union[float, int]]]` but got `List[Tuple[int, int]]`.
            bounds=[(0, 1)] * 2,
        )
        forest.fit(datasets=datasets, search_space_digest=digest)

        # One fitted regressor per outcome, each with `num_trees` estimators.
        self.assertEqual(len(forest.models), 2)
        # pyre-fixme[16]: `RandomForestRegressor` has no attribute `estimators_`.
        self.assertEqual(len(forest.models[0].estimators_), 5)

        mean, covariance = forest.predict(torch.rand(5, 2))
        self.assertEqual(mean.shape, torch.Size((5, 2)))
        self.assertEqual(covariance.shape, torch.Size((5, 2, 2)))

        mean, covariance = forest.cross_validate(
            datasets=datasets, X_test=torch.rand(3, 2)
        )
        self.assertEqual(mean.shape, torch.Size((3, 2)))
        self.assertEqual(covariance.shape, torch.Size((3, 2, 2)))
50 |
--------------------------------------------------------------------------------
/ax/models/discrete/eb_thompson.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import logging
10 |
11 | import numpy as np
12 | from ax.models.discrete.thompson import ThompsonSampler
13 | from ax.utils.common.logger import get_logger
14 | from ax.utils.stats.statstools import positive_part_james_stein
15 |
16 |
logger: logging.Logger = get_logger(__name__)


class EmpiricalBayesThompsonSampler(ThompsonSampler):
    """Generator for Thompson sampling using Empirical Bayes estimates.

    The generator applies positive-part James-Stein Estimator to the data
    passed in via `fit` and then performs Thompson Sampling.
    """

    def _fit_Ys_and_Yvars(
        self, Ys: list[list[float]], Yvars: list[list[float]], outcome_names: list[str]
    ) -> tuple[list[list[float]], list[list[float]]]:
        """Shrink the observations of each outcome independently.

        Args:
            Ys: One list of observed means per outcome.
            Yvars: One list of variances per outcome, aligned with ``Ys``.
            outcome_names: Not used by this implementation.

        Returns:
            The (possibly shrunk) means and variances, nested like the inputs.
        """
        newYs: list[list[float]] = []
        newYvars: list[list[float]] = []
        for i, (Y, Yvar) in enumerate(zip(Ys, Yvars)):
            newY, newYvar = self._apply_shrinkage(Y, Yvar, i)
            newYs.append(newY)
            newYvars.append(newYvar)
        return newYs, newYvars

    def _apply_shrinkage(
        self, Y: list[float], Yvar: list[float], outcome: int
    ) -> tuple[list[float], list[float]]:
        """Apply positive-part James-Stein shrinkage to a single outcome.

        Args:
            Y: Observed means for this outcome.
            Yvar: Variances of the entries of ``Y``.
            outcome: Index of the outcome; only used in the warning message.

        Returns:
            ``(means, variances)`` after shrinkage. If
            ``positive_part_james_stein`` raises ``ValueError``, a warning is
            logged and the unshrunk inputs are returned instead.
        """
        npY = np.array(Y)
        npYsem = np.sqrt(Yvar)
        try:
            shrunkY, shrunkYsem = positive_part_james_stein(means=npY, sems=npYsem)
        except ValueError as e:
            logger.warning(
                str(e) + f" Raw (unshrunk) estimates used for outcome: {outcome}"
            )
            # Return the raw variances directly. Squaring their square roots
            # would perturb them by floating-point rounding.
            return npY.tolist(), np.asarray(Yvar, dtype=float).tolist()
        return shrunkY.tolist(), (shrunkYsem**2).tolist()
54 |
--------------------------------------------------------------------------------
/ax/runners/tests/test_single_running_trial_mixin.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | from ax.core.base_trial import TrialStatus
11 | from ax.runners.single_running_trial_mixin import SingleRunningTrialMixin
12 | from ax.runners.synthetic import SyntheticRunner
13 | from ax.utils.common.testutils import TestCase
14 | from ax.utils.testing.core_stubs import get_branin_experiment
15 |
16 |
class SyntheticRunnerWithSingleRunningTrial(SingleRunningTrialMixin, SyntheticRunner):
    """``SyntheticRunner`` with ``SingleRunningTrialMixin`` mixed in, for tests."""
19 |
20 |
class SingleRunningTrialMixinTest(TestCase):
    """Tests ``poll_trial_status`` as provided by ``SingleRunningTrialMixin``."""

    def test_single_running_trial_mixin(self) -> None:
        runner = SyntheticRunnerWithSingleRunningTrial()
        experiment = get_branin_experiment(with_trial=True, with_batch=True)
        experiment.runner = runner
        all_trials = experiment.trials.values()
        for t in all_trials:
            t.assign_runner()
            t.run()
        # With two trials running, trial 0 is reported completed and trial 1
        # is the one left running.
        statuses = runner.poll_trial_status(trials=all_trials)
        self.assertEqual(statuses[TrialStatus.COMPLETED], {0})
        self.assertEqual(statuses[TrialStatus.RUNNING], {1})

    def test_no_trials(self) -> None:
        statuses = SyntheticRunnerWithSingleRunningTrial().poll_trial_status(
            trials=[]
        )
        self.assertEqual(statuses, {})

    def test_abandoned_trial(self) -> None:
        runner = SyntheticRunnerWithSingleRunningTrial()
        experiment = get_branin_experiment(with_trial=True)
        experiment.runner = runner
        only_trial = experiment.trials[0]
        only_trial.assign_runner()
        only_trial.mark_abandoned()
        # An abandoned trial contributes no status entries.
        self.assertEqual(runner.poll_trial_status(trials=[only_trial]), {})
48 |
--------------------------------------------------------------------------------
/ax/storage/json_store/save.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import json
10 | from collections.abc import Callable
11 | from typing import Any
12 |
13 | from ax.core.experiment import Experiment
14 | from ax.storage.json_store.encoder import object_to_json
15 | from ax.storage.json_store.registry import (
16 | CORE_CLASS_ENCODER_REGISTRY,
17 | CORE_ENCODER_REGISTRY,
18 | )
19 |
20 |
def save_experiment(
    experiment: Experiment,
    filepath: str,
    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
    #  `typing.Type` to avoid runtime subscripting errors.
    encoder_registry: dict[
        type, Callable[[Any], dict[str, Any]]
    ] = CORE_ENCODER_REGISTRY,
    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
    #  `typing.Type` to avoid runtime subscripting errors.
    class_encoder_registry: dict[
        type, Callable[[Any], dict[str, Any]]
    ] = CORE_CLASS_ENCODER_REGISTRY,
) -> None:
    """Save experiment to a JSON file.

    1) Convert Ax experiment to JSON-serializable dictionary.
    2) Write to file.

    Args:
        experiment: The experiment to save.
        filepath: Destination path; must end in ``.json``.
        encoder_registry: Passed through to ``object_to_json``.
        class_encoder_registry: Passed through to ``object_to_json``.

    Raises:
        ValueError: If ``experiment`` is not an ``Experiment`` or ``filepath``
            does not end in ``.json``.
    """
    if not isinstance(experiment, Experiment):
        raise ValueError("Can only save instances of Experiment")

    if not filepath.endswith(".json"):
        raise ValueError("Filepath must end in .json")

    json_experiment = object_to_json(
        experiment,
        encoder_registry=encoder_registry,
        class_encoder_registry=class_encoder_registry,
    )
    # Serialize before opening: opening for writing truncates the file, so a
    # `json.dumps` failure inside the `with` block would wipe out any
    # experiment previously saved at `filepath`.
    serialized = json.dumps(json_experiment)
    # The file is only written, so "w" (not "w+") is sufficient. `json.dumps`
    # escapes non-ASCII by default; the explicit encoding just removes any
    # dependence on the platform locale.
    with open(filepath, "w", encoding="utf-8") as file:
        file.write(serialized)
55 |
--------------------------------------------------------------------------------
/ax/analysis/healthcheck/should_generate_candidates.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-unsafe
7 |
8 | import json
9 |
10 | import pandas as pd
11 | from ax.analysis.analysis import AnalysisCardLevel
12 |
13 | from ax.analysis.healthcheck.healthcheck_analysis import (
14 | HealthcheckAnalysis,
15 | HealthcheckAnalysisCard,
16 | HealthcheckStatus,
17 | )
18 | from ax.core.experiment import Experiment
19 | from ax.core.generation_strategy_interface import GenerationStrategyInterface
20 |
21 |
class ShouldGenerateCandidates(HealthcheckAnalysis):
    """Healthcheck reporting whether candidates should be generated for a trial.

    The verdict and its explanation are supplied at construction time;
    ``compute`` only packages them into a ``HealthcheckAnalysisCard``.
    """

    def __init__(
        self,
        should_generate: bool,
        reason: str,
        trial_index: int,
    ) -> None:
        self.should_generate = should_generate
        self.reason = reason
        self.trial_index = trial_index

    def compute(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategyInterface | None = None,
    ) -> HealthcheckAnalysisCard:
        """Build the healthcheck card from the stored verdict.

        Args:
            experiment: Not used by this analysis.
            generation_strategy: Not used by this analysis.

        Returns:
            A CRITICAL-level card whose status is PASS when candidates should
            be generated and WARNING otherwise.
        """
        if self.should_generate:
            status = HealthcheckStatus.PASS
        else:
            status = HealthcheckStatus.WARNING

        status_blob = json.dumps({"status": status})
        summary_df = pd.DataFrame({"status": [status], "reason": [self.reason]})

        return HealthcheckAnalysisCard(
            name=self.name,
            title=f"Ready to Generate Candidates for Trial {self.trial_index}",
            blob=status_blob,
            subtitle=self.reason,
            df=summary_df,
            level=AnalysisCardLevel.CRITICAL,
            attributes=self.attributes,
        )
61 |
--------------------------------------------------------------------------------
/ax/models/random/uniform.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import numpy as np
11 | import numpy.typing as npt
12 | from ax.models.random.base import RandomModel
13 |
14 |
class UniformGenerator(RandomModel):
    """This class specifies a uniform random generation algorithm.

    As a uniform generator does not make use of a model, it does not implement
    the fit or predict methods.

    See base `RandomModel` for a description of model attributes.
    """

    def __init__(
        self,
        deduplicate: bool = True,
        seed: int | None = None,
        init_position: int = 0,
        generated_points: npt.NDArray | None = None,
        fallback_to_sample_polytope: bool = False,
    ) -> None:
        super().__init__(
            deduplicate=deduplicate,
            seed=seed,
            init_position=init_position,
            generated_points=generated_points,
            fallback_to_sample_polytope=fallback_to_sample_polytope,
        )
        self._rs = np.random.RandomState(seed=self.seed)
        if self.init_position > 0:
            # Fast-forward the random state by drawing and discarding
            # `init_position` scalars. `_gen_samples` advances `init_position`
            # by the number of scalars it draws, so (presumably) a generator
            # rebuilt from a saved position resumes the same random stream.
            self._rs.uniform(size=self.init_position)

    def _gen_samples(self, n: int, tunable_d: int) -> npt.NDArray:
        """Generate samples from the uniform distribution on [0, 1).

        Samples are drawn from this generator's NumPy ``RandomState``.

        Args:
            n: Number of samples to generate.
            tunable_d: Dimension of samples to generate.

        Returns:
            samples: An (n x tunable_d) array of random points.

        """
        # Record how many scalars have been consumed (see `__init__`).
        self.init_position += n * tunable_d
        return self._rs.uniform(size=(n, tunable_d))
57 |
--------------------------------------------------------------------------------
/ax/service/utils/best_point_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.core.experiment import Experiment
10 | from pyre_extensions import none_throws
11 |
BASELINE_ARM_NAME = "baseline_arm"


def select_baseline_name_default_first_trial(
    experiment: Experiment, baseline_arm_name: str | None
) -> tuple[str, bool]:
    """
    Pick the name of a baseline arm among the experiment's arms.

    Candidates are tried in this order:
    1. ``baseline_arm_name``, when given (non-empty); it must name an existing
       arm, otherwise a ``ValueError`` is raised.
    2. The name of ``experiment.status_quo``, when set and present among the
       experiment's arms.
    3. The first arm of trial 0, when that trial exists and has arms.
    4. Otherwise a ``ValueError`` is raised.

    Returns:
        Tuple:
            baseline arm name (str)
            true when baseline selected from first arm of experiment (bool)
        raise ValueError if no valid baseline found
    """
    arms_by_name = experiment.arms_by_name

    # 1. Explicitly requested arm.
    if baseline_arm_name:
        if baseline_arm_name in arms_by_name:
            return baseline_arm_name, False
        raise ValueError(f"Arm by name {baseline_arm_name=} not found.")

    # 2. Status quo arm.
    if experiment.status_quo:
        status_quo_name = none_throws(experiment.status_quo).name
        if status_quo_name in arms_by_name:
            return status_quo_name, False

    # 3. First arm of the first trial.
    if experiment.trials:
        first_trial_arms = experiment.trials[0].arms
        if first_trial_arms and first_trial_arms[0].name in arms_by_name:
            return first_trial_arms[0].name, True

    raise ValueError("Could not find valid baseline arm.")
55 |
--------------------------------------------------------------------------------
/ax/preview/api/protocols/runner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 |
9 | from typing import Any, Mapping
10 |
11 | from ax.core.base_trial import TrialStatus
12 | from ax.preview.api.protocols.utils import _APIRunner
13 | from ax.preview.api.types import TParameterization
14 | from pyre_extensions import override
15 |
16 |
class IRunner(_APIRunner):
    """Interface for running, polling and stopping trials.

    ``poll_trial`` and ``stop_trial`` receive the metadata that ``run_trial``
    returned for the same trial.
    """

    @override
    def run_trial(
        self, trial_index: int, parameterization: TParameterization
    ) -> dict[str, Any]:
        """
        Run the trial with the given index and parameterization, returning a
        dictionary of metadata about it. This metadata identifies the trial when
        later polling its status, stopping it, fetching its data, etc. It might
        hold, for example, the trial's unique identifier on the system it is
        running on, or a directory the trial logs its results to.

        The metadata MUST be JSON-serializable (i.e. dict, list, str, int, float,
        bool, or None) so that Trials can be serialized properly in Ax.
        """
        ...

    @override
    def poll_trial(
        self, trial_index: int, trial_metadata: Mapping[str, Any]
    ) -> TrialStatus:
        """
        Return the current status of the trial identified by the given index and
        metadata.
        """
        ...

    @override
    def stop_trial(
        self, trial_index: int, trial_metadata: Mapping[str, Any]
    ) -> dict[str, Any]:
        """
        Stop the trial identified by the given index and metadata, returning a
        dictionary of any relevant metadata.

        The metadata MUST be JSON-serializable (i.e. dict, list, str, int, float,
        bool, or None) so that Trials can be serialized properly in Ax.
        """
        ...
55 |
--------------------------------------------------------------------------------
/ax/core/tests/test_map_metric.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from ax.core.map_metric import MapMetric
10 | from ax.utils.common.testutils import TestCase
11 | from ax.utils.testing.core_stubs import get_map_data
12 |
13 |
METRIC_STRING = "MapMetric('m1')"


class MapMetricTest(TestCase):
    """Basic behavior of ``MapMetric``: repr, equality, cloning, ordering and
    the wrap/unwrap round-trips of its data helpers."""

    def test_Init(self) -> None:
        self.assertEqual(
            str(MapMetric(name="m1", lower_is_better=False)), METRIC_STRING
        )

    def test_Eq(self) -> None:
        reference = MapMetric(name="m1", lower_is_better=False)
        # Same name and direction compare equal...
        self.assertEqual(reference, MapMetric(name="m1", lower_is_better=False))
        # ...while flipping `lower_is_better` breaks equality.
        self.assertNotEqual(reference, MapMetric(name="m1", lower_is_better=True))

    def test_Clone(self) -> None:
        original = MapMetric(name="m1", lower_is_better=False)
        self.assertEqual(original, original.clone())

    def test_Sortable(self) -> None:
        first = MapMetric(name="m1", lower_is_better=False)
        second = MapMetric(name="m2", lower_is_better=False)
        self.assertTrue(first < second)

    def test_WrapUnwrap(self) -> None:
        data = get_map_data()
        round_trips = (
            (MapMetric._wrap_trial_data_multi, MapMetric._unwrap_trial_data_multi),
            (MapMetric._wrap_experiment_data, MapMetric._unwrap_experiment_data),
            (
                MapMetric._wrap_experiment_data_multi,
                MapMetric._unwrap_experiment_data_multi,
            ),
        )
        # Each unwrap must recover exactly the data given to its wrap.
        for wrap, unwrap in round_trips:
            self.assertEqual(unwrap(results=wrap(data=data)), data)
56 |
--------------------------------------------------------------------------------
/sphinx/source/metrics.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax.metrics
5 | ==========
6 |
7 | .. automodule:: ax.metrics
8 | .. currentmodule:: ax.metrics
9 |
10 |
11 | Branin
12 | ~~~~~~
13 |
14 | .. automodule:: ax.metrics.branin
15 | :members:
16 | :undoc-members:
17 | :show-inheritance:
18 |
19 | Branin Map
20 | ~~~~~~~~~~
21 |
22 | .. automodule:: ax.metrics.branin_map
23 | :members:
24 | :undoc-members:
25 | :show-inheritance:
26 |
27 | Chemistry
28 | ~~~~~~~~~
29 |
30 | .. automodule:: ax.metrics.chemistry
31 | :members:
32 | :undoc-members:
33 | :show-inheritance:
34 |
35 | Curve
36 | ~~~~~
37 |
38 | .. automodule:: ax.metrics.curve
39 | :members:
40 | :undoc-members:
41 | :show-inheritance:
42 |
43 | Factorial
44 | ~~~~~~~~~
45 |
46 | .. automodule:: ax.metrics.factorial
47 | :members:
48 | :undoc-members:
49 | :show-inheritance:
50 |
51 | Hartmann6
52 | ~~~~~~~~~
53 |
54 | .. automodule:: ax.metrics.hartmann6
55 | :members:
56 | :undoc-members:
57 | :show-inheritance:
58 |
59 | L2 Norm
60 | ~~~~~~~
61 |
62 | .. automodule:: ax.metrics.l2norm
63 | :members:
64 | :undoc-members:
65 | :show-inheritance:
66 |
67 | Noisy Functions
68 | ~~~~~~~~~~~~~~~
69 |
70 | .. automodule:: ax.metrics.noisy_function
71 | :members:
72 | :undoc-members:
73 | :show-inheritance:
74 |
75 | Noisy Function Map
76 | ~~~~~~~~~~~~~~~~~~
77 |
78 | .. automodule:: ax.metrics.noisy_function_map
79 | :members:
80 | :undoc-members:
81 | :show-inheritance:
82 |
83 | Sklearn
84 | ~~~~~~~
85 |
86 | .. automodule:: ax.metrics.sklearn
87 | :members:
88 | :undoc-members:
89 | :show-inheritance:
90 |
91 | Tensorboard
92 | ~~~~~~~~~~~
93 |
94 | .. automodule:: ax.metrics.tensorboard
95 | :members:
96 | :undoc-members:
97 | :show-inheritance:
98 |
99 |
100 | TorchX
101 | ~~~~~~
102 |
103 | .. automodule:: ax.metrics.torchx
104 | :members:
105 | :undoc-members:
106 | :show-inheritance:
107 |
--------------------------------------------------------------------------------
/ax/analysis/plotly/plotly_analysis.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 |
9 | import pandas as pd
10 | from ax.analysis.analysis import Analysis, AnalysisCard
11 | from ax.core.experiment import Experiment
12 | from ax.core.generation_strategy_interface import GenerationStrategyInterface
13 | from IPython.display import display, Markdown
14 | from plotly import graph_objects as go, io as pio
15 |
16 |
class PlotlyAnalysisCard(AnalysisCard):
    """AnalysisCard whose ``blob`` holds a Plotly figure serialized as JSON."""

    blob_annotation = "plotly"

    def get_figure(self) -> go.Figure:
        """Deserialize ``blob`` back into a Plotly figure."""
        return pio.from_json(self.blob)

    def _ipython_display_(self) -> None:
        """
        IPython display hook, invoked when the card is displayed in an IPython
        environment (e.g. Jupyter). Shows the title and subtitle as Markdown
        headings, followed by the Plotly figure.
        """
        heading = f"## {self.title}\n\n### {self.subtitle}"
        display(Markdown(heading))
        display(self.get_figure())
30 |
31 |
class PlotlyAnalysis(Analysis):
    """
    An Analysis whose result is a Plotly figure wrapped in a PlotlyAnalysisCard.
    """

    def compute(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategyInterface | None = None,
    ) -> PlotlyAnalysisCard: ...

    def _create_plotly_analysis_card(
        self,
        title: str,
        subtitle: str,
        level: int,
        df: pd.DataFrame,
        fig: go.Figure,
    ) -> PlotlyAnalysisCard:
        """
        Build a PlotlyAnalysisCard from the given fields, filling in this
        Analysis's name and attributes and storing ``fig`` as Plotly JSON.
        """
        figure_json = pio.to_json(fig)
        return PlotlyAnalysisCard(
            name=self.name,
            attributes=self.attributes,
            title=title,
            subtitle=subtitle,
            level=level,
            df=df,
            blob=figure_json,
        )
64 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/1_bug_report.yaml:
--------------------------------------------------------------------------------
1 | name: Bug Report
2 | description: File a bug report.
3 | labels: ["bug"]
4 | title: "[Bug]: "
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | Thank you for taking the time to fill out a bug report. We strive to make Ax a useful and stable library for all our users.
10 | - type: textarea
11 | id: what-happened
12 | attributes:
13 | label: What happened?
14 | description: Provide a detailed description of the bug as well as the expected behavior.
15 | validations:
16 | required: true
17 | - type: textarea
18 | id: repro
19 | attributes:
20 | label: Please provide a minimal, reproducible example of the unexpected behavior.
21 | description: Follow [these guidelines](https://stackoverflow.com/help/minimal-reproducible-example) for writing your example.
22 | validations:
23 | required: true
24 | - type: textarea
25 | id: traceback
26 | attributes:
27 | label: Please paste any relevant traceback/logs produced by the example provided.
28 | description: This will be automatically formatted into code, so no need for backticks.
29 | render: shell
30 | - type: input
31 | id: ax-version
32 | attributes:
33 | label: Ax Version
34 | description: What version of Ax are you using?
35 | validations:
36 | required: true
37 | - type: input
38 | id: python-version
39 | attributes:
40 | label: Python Version
41 | description: What version of Python are you using?
42 | validations:
43 | required: true
44 | - type: input
45 | id: os
46 | attributes:
47 | label: Operating System
48 | description: What operating system are you using?
49 | validations:
50 | required: true
51 | - type: checkboxes
52 | id: terms
53 | attributes:
54 | label: Code of Conduct
55 | description: By submitting this issue, you agree to follow Ax's [Code of Conduct](https://github.com/facebook/Ax/blob/main/CODE_OF_CONDUCT.md).
56 | options:
57 | - label: I agree to follow Ax's Code of Conduct
58 | required: true
59 |
--------------------------------------------------------------------------------
/ax/service/utils/early_stopping.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 |
9 | from ax.core.experiment import Experiment
10 | from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
11 | from pyre_extensions import none_throws
12 |
13 |
def should_stop_trials_early(
    early_stopping_strategy: BaseEarlyStoppingStrategy | None,
    trial_indices: set[int],
    experiment: Experiment,
) -> dict[int, str | None]:
    """Decide which of the given running trials should be stopped early.

    Args:
        early_stopping_strategy: Strategy that decides, given the experiment's
            state, whether a trial should be stopped. ``None`` disables early
            stopping.
        trial_indices: Indices of the trials to consider.
        experiment: The experiment the trials belong to.

    Returns:
        A mapping from the index of each trial to stop to an optional message
        explaining why. Empty when no strategy is given.
    """
    if early_stopping_strategy is None:
        return {}
    return none_throws(early_stopping_strategy).should_stop_trials_early(
        trial_indices=trial_indices, experiment=experiment
    )
38 |
39 |
def get_early_stopping_metrics(
    experiment: Experiment, early_stopping_strategy: BaseEarlyStoppingStrategy | None
) -> list[str]:
    """Return the names of the metrics the given strategy operates on.

    Uses the strategy's explicit ``metric_names`` when set. Otherwise it falls
    back to the strategy's default objective for ``experiment``. Returns an
    empty list when no strategy is given."""
    if early_stopping_strategy is None:
        return []
    configured_names = early_stopping_strategy.metric_names
    if configured_names is None:
        # TODO: generalize this to multi-objective ess
        objective_name, _ = early_stopping_strategy._default_objective_and_direction(
            experiment=experiment
        )
        return [objective_name]
    return list(configured_names)
54 |
--------------------------------------------------------------------------------
/sphinx/source/preview.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax.preview
5 | ==========
6 |
7 | .. automodule:: ax.preview
8 | .. currentmodule:: ax.preview
9 |
10 | A preview of future Ax API
11 | --------------------------
12 |
13 |
14 | IMetric
15 | ~~~~~~~
16 |
17 | .. automodule:: ax.preview.api.protocols.metric
18 | :members:
19 | :undoc-members:
20 | :show-inheritance:
21 |
22 | IRunner
23 | ~~~~~~~
24 |
25 | .. automodule:: ax.preview.api.protocols.runner
26 | :members:
27 | :undoc-members:
28 | :show-inheritance:
29 |
30 |
31 | Utils
32 | ~~~~~~~
33 |
34 | .. automodule:: ax.preview.api.protocols.utils
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 |
40 | Client
41 | ~~~~~~
42 |
43 | .. automodule:: ax.preview.api.client
44 | :members:
45 | :undoc-members:
46 | :show-inheritance:
47 |
48 |
49 | Configs
50 | ~~~~~~~
51 |
52 | .. automodule:: ax.preview.api.configs
53 | :members:
54 | :undoc-members:
55 | :show-inheritance:
56 |
57 | Types
58 | ~~~~~
59 |
60 | .. automodule:: ax.preview.api.types
61 | :members:
62 | :undoc-members:
63 | :show-inheritance:
64 |
65 | From Config
66 | ~~~~~~~~~~~
67 |
68 | .. automodule:: ax.preview.api.utils.instantiation.from_config
69 | :members:
70 | :undoc-members:
71 | :show-inheritance:
72 |
73 | From String
74 | ~~~~~~~~~~~
75 |
76 | .. automodule:: ax.preview.api.utils.instantiation.from_string
77 | :members:
78 | :undoc-members:
79 | :show-inheritance:
80 |
81 |
82 | ModelBridge
83 | ~~~~~~~~~~~
84 |
85 | .. automodule:: ax.preview.modelbridge
86 | :members:
87 | :undoc-members:
88 | :show-inheritance:
89 |
90 | Dispatch Utils
91 | ~~~~~~~~~~~~~~
92 |
93 | .. automodule:: ax.preview.modelbridge.dispatch_utils
94 | :members:
95 | :undoc-members:
96 | :show-inheritance:
97 |
98 | Storage Utils
99 | ~~~~~~~~~~~~~
100 |
101 | .. automodule:: ax.preview.api.utils.storage
102 | :members:
103 | :undoc-members:
104 | :show-inheritance:
105 |
--------------------------------------------------------------------------------
/website/core/Footer.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | const React = require('react');
9 |
10 | class Footer extends React.Component {
11 |
12 | docUrl(doc, language) {
13 | const baseUrl = this.props.config.baseUrl;
14 | const docsUrl = this.props.config.docsUrl;
15 | const docsPart = `${docsUrl ? `${docsUrl}/` : ''}`;
16 | const langPart = `${language ? `${language}/` : ''}`;
17 | return `${baseUrl}${docsPart}${langPart}${doc}`;
18 | }
19 |
20 | pageUrl(doc, language) {
21 | const baseUrl = this.props.config.baseUrl;
22 | return baseUrl + (language ? `${language}/` : '') + doc;
23 | }
24 |
25 | render() {
26 | const currentYear = new Date().getFullYear();
27 |
28 | return (
29 |
63 | );
64 | }
65 | }
66 |
67 | module.exports = Footer;
68 |
--------------------------------------------------------------------------------
/ax/core/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # flake8: noqa F401
10 | from ax.core.arm import Arm
11 | from ax.core.batch_trial import BatchTrial
12 | from ax.core.data import Data
13 | from ax.core.experiment import Experiment
14 | from ax.core.generator_run import GeneratorRun
15 | from ax.core.metric import Metric
16 | from ax.core.objective import MultiObjective, Objective
17 | from ax.core.observation import ObservationFeatures
18 | from ax.core.optimization_config import (
19 | MultiObjectiveOptimizationConfig,
20 | OptimizationConfig,
21 | )
22 | from ax.core.outcome_constraint import (
23 | ComparisonOp,
24 | ObjectiveThreshold,
25 | OutcomeConstraint,
26 | )
27 | from ax.core.parameter import (
28 | ChoiceParameter,
29 | FixedParameter,
30 | Parameter,
31 | ParameterType,
32 | RangeParameter,
33 | )
34 | from ax.core.parameter_constraint import (
35 | OrderConstraint,
36 | ParameterConstraint,
37 | SumConstraint,
38 | )
39 | from ax.core.parameter_distribution import ParameterDistribution
40 | from ax.core.risk_measures import RiskMeasure
41 | from ax.core.runner import Runner
42 | from ax.core.search_space import SearchSpace
43 | from ax.core.trial import Trial
44 |
45 |
# Public API of ``ax.core``. Every name listed here must be imported above:
# ``from ax.core import *`` raises AttributeError for any listed name that the
# module does not actually define.
__all__ = [
    "Arm",
    "BatchTrial",
    "ChoiceParameter",
    "ComparisonOp",
    "Data",
    "Experiment",
    "FixedParameter",
    "GeneratorRun",
    "Metric",
    "MultiObjective",
    "MultiObjectiveOptimizationConfig",
    "Objective",
    "ObjectiveThreshold",
    "ObservationFeatures",
    "OptimizationConfig",
    "OrderConstraint",
    "OutcomeConstraint",
    "Parameter",
    "ParameterConstraint",
    "ParameterDistribution",
    "ParameterType",
    "RangeParameter",
    "RiskMeasure",
    "Runner",
    "SearchSpace",
    "SumConstraint",
    "Trial",
]
76 |
--------------------------------------------------------------------------------
/ax/utils/common/tests/test_typeutils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import numpy as np
11 | from ax.utils.common.testutils import TestCase
12 | from ax.utils.common.typeutils import (
13 | checked_cast,
14 | checked_cast_dict,
15 | checked_cast_list,
16 | checked_cast_optional,
17 | )
18 | from ax.utils.common.typeutils_nonnative import numpy_type_to_python_type
19 |
20 |
class TestTypeUtils(TestCase):
    """Tests for the ``checked_cast*`` helpers and numpy-to-Python conversion."""

    def test_checked_cast(self) -> None:
        self.assertEqual(checked_cast(float, 2.0), 2.0)
        with self.assertRaises(ValueError):
            checked_cast(float, 2)

    def test_checked_cast_with_error_override(self) -> None:
        self.assertEqual(checked_cast(float, 2.0), 2.0)
        # The caller-supplied exception (including its message) should be the
        # one raised, rather than the default ``ValueError``.
        with self.assertRaisesRegex(NotImplementedError, "doesn't support ints"):
            checked_cast(
                float, 2, exception=NotImplementedError("foo() doesn't support ints")
            )

    def test_checked_cast_list(self) -> None:
        self.assertEqual(checked_cast_list(float, [1.0, 2.0]), [1.0, 2.0])
        with self.assertRaises(ValueError):
            checked_cast_list(float, [1.0, 2])

    def test_checked_cast_optional(self) -> None:
        self.assertIsNone(checked_cast_optional(float, None))
        # Non-None values of the right type pass through unchanged.
        self.assertEqual(checked_cast_optional(float, 2.0), 2.0)
        with self.assertRaises(ValueError):
            checked_cast_optional(float, 2)

    def test_checked_cast_dict(self) -> None:
        self.assertEqual(checked_cast_dict(str, int, {"some": 1}), {"some": 1})
        with self.assertRaises(ValueError):
            checked_cast_dict(str, int, {"some": 1.0})
        with self.assertRaises(ValueError):
            checked_cast_dict(str, int, {1: 1})

    def test_numpy_type_to_python_type(self) -> None:
        # Check both the exact builtin type and that the value is preserved.
        as_int = numpy_type_to_python_type(np.int64(2))
        self.assertIs(type(as_int), int)
        self.assertEqual(as_int, 2)
        as_float = numpy_type_to_python_type(np.float64(2))
        self.assertIs(type(as_float), float)
        self.assertEqual(as_float, 2.0)
54 |
--------------------------------------------------------------------------------
/website/static/img/ax_lockup.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/website/static/img/ax_logo_lockup.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/website/static/img/ax_lockup_white.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/why-ax.md:
--------------------------------------------------------------------------------
1 | ---
2 | id: why-ax
3 | title: Why Ax?
4 | sidebar_label: Why Ax?
5 | ---
6 |
Developers and researchers alike face problems which confront them with a large space of possible ways to configure something — whether those are "magic numbers" used for infrastructure or compiler flags, learning rates or other hyperparameters in machine learning, or images and calls-to-action used in marketing promotions. Selecting and tuning these configurations can take significant time and resources, and can affect the quality of user experiences. Ax is a machine learning system to help automate this process, so that researchers and developers can determine how to get the most out of their software in an optimally efficient way.
8 |
Ax is a platform for optimizing any kind of experiment, including machine learning experiments, A/B tests, and simulations. Ax can optimize discrete configurations (e.g., variants of an A/B test) using multi-armed bandit optimization, and numeric (e.g., integer- or floating-point-valued) configurations using Bayesian optimization. This makes it suitable for a wide range of applications.
10 |
11 | Ax has been successfully applied to a variety of product, infrastructure, ML, and research applications at Facebook.
12 |
13 | # Unique capabilities
14 | - **Support for noisy functions**. Results of A/B tests and simulations with reinforcement learning agents often exhibit high amounts of noise. Ax supports [state-of-the-art algorithms](https://research.facebook.com/blog/2018/09/efficient-tuning-of-online-systems-using-bayesian-optimization/) which work better than traditional Bayesian optimization in high-noise settings.
15 | - **Customization**. Ax's developer API makes it easy to integrate custom data modeling and decision algorithms. This allows developers to build their own custom optimization services with minimal overhead.
16 | - **Multi-modal experimentation**. Ax has first-class support for running and combining data from different types of experiments, such as "offline" simulation data and "online" data from real-world experiments.
- **Multi-objective optimization**. Ax supports multi-objective and constrained optimization, both of which are common in real-world problems such as improving load time without increasing data use.
18 |
--------------------------------------------------------------------------------
/ax/analysis/plotly/surface/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 | import math
8 |
9 | import numpy as np
10 | from ax.core.parameter import (
11 | ChoiceParameter,
12 | FixedParameter,
13 | Parameter,
14 | RangeParameter,
15 | TParamValue,
16 | )
17 |
18 |
def get_parameter_values(parameter: Parameter, density: int = 100) -> list[TParamValue]:
    """
    Get a list of parameter values to predict over for a given parameter.

    Args:
        parameter: The parameter to sweep. Must be a ``RangeParameter`` or a
            ``ChoiceParameter`` with ``is_ordered=True``.
        density: Number of points to generate for a ``RangeParameter``
            (log-spaced if ``parameter.log_scale``, otherwise evenly spaced).
            Ignored for a ``ChoiceParameter``.

    Returns:
        For a ``RangeParameter``, ``density`` values within
        ``[lower, upper]``; for an ordered ``ChoiceParameter``, a new list
        containing its values.

    Raises:
        ValueError: If ``parameter`` is any other type, or is a
            ``ChoiceParameter`` that is not ordered.
    """

    # For RangeParameter use linspace/logspace over the range of the parameter
    if isinstance(parameter, RangeParameter):
        if parameter.log_scale:
            values = np.logspace(
                math.log10(parameter.lower), math.log10(parameter.upper), density
            )
            # ``10 ** log10(x)`` can round to a value just outside
            # ``[lower, upper]``; clip so every point stays within the bounds.
            return np.clip(values, parameter.lower, parameter.upper).tolist()

        return np.linspace(parameter.lower, parameter.upper, density).tolist()

    # For ChoiceParameter use the values of the parameter directly. Return a
    # copy so that mutating the result cannot affect the parameter, in case
    # ``values`` exposes the parameter's internal list.
    if isinstance(parameter, ChoiceParameter) and parameter.is_ordered:
        return list(parameter.values)

    raise ValueError(
        f"Parameter {parameter.name} must be a RangeParameter or "
        "ChoiceParameter with is_ordered=True to be used in surface plot."
    )
41 |
42 |
def select_fixed_value(parameter: Parameter) -> TParamValue:
    """
    Pick one representative value at which to hold a parameter fixed.

    - ``RangeParameter``: the midpoint, computed as ``(lower * 1.0 + upper) / 2``.
    - ``ChoiceParameter``: the element at index ``len(values) // 2``.
    - ``FixedParameter``: its ``value``.

    Raises:
        ValueError: For any other parameter type.
    """
    if isinstance(parameter, RangeParameter):
        low, high = parameter.lower, parameter.upper
        return (low * 1.0 + high) / 2
    if isinstance(parameter, ChoiceParameter):
        choices = parameter.values
        return choices[len(choices) // 2]
    if isinstance(parameter, FixedParameter):
        return parameter.value
    raise ValueError(f"Got unexpected parameter type {parameter}.")
56 |
57 |
def is_axis_log_scale(parameter: Parameter) -> bool:
    """
    Return whether the plot axis for ``parameter`` should use a log scale.

    Only a ``RangeParameter`` can be log-scaled. Any other parameter type
    yields ``False``; a ``RangeParameter`` yields its ``log_scale`` flag.
    """
    if not isinstance(parameter, RangeParameter):
        return False
    return parameter.log_scale
63 |
--------------------------------------------------------------------------------
/ax/benchmark/problems/hd_embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | import copy
9 | from typing import TypeVar
10 |
11 | from ax.benchmark.benchmark_problem import BenchmarkProblem
12 | from ax.core.parameter import ParameterType, RangeParameter
13 | from ax.core.search_space import SearchSpace
14 |
15 | TProblem = TypeVar("TProblem", bound=BenchmarkProblem)
16 |
17 |
def embed_higher_dimension(problem: TProblem, total_dimensionality: int) -> TProblem:
    """
    Return a new `BenchmarkProblem` with enough `RangeParameter`s added to the
    search space to make its total dimensionality equal to `total_dimensionality`
    and add `total_dimensionality` to its name.

    The search space of the original `problem` is within the search space of the
    new problem, and the original problem's parameter constraints are passed to
    the new search space.

    Args:
        problem: The problem to embed. It is shallow-copied; the original's
            ``name`` and ``search_space`` attributes are not reassigned.
        total_dimensionality: Desired number of parameters in the new search
            space. Must be at least the number of parameters in ``problem``.

    Returns:
        A shallow copy of ``problem`` whose search space has
        ``total_dimensionality`` parameters (the originals plus dummy
        ``[0, 1]`` float parameters) and whose name ends in
        ``_{total_dimensionality}d``.

    Raises:
        ValueError: If ``total_dimensionality`` is smaller than the number of
            parameters already in ``problem``'s search space.
    """
    existing_parameters = problem.search_space.parameters
    num_dummy_dimensions = total_dimensionality - len(existing_parameters)
    if num_dummy_dimensions < 0:
        # Otherwise no dimensions would be added, yet the name would claim
        # ``total_dimensionality`` dimensions.
        raise ValueError(
            f"Cannot embed problem {problem.name!r} with "
            f"{len(existing_parameters)} parameters into only "
            f"{total_dimensionality} dimensions."
        )

    # Choose dummy names that do not clash with existing parameters, so that
    # re-embedding an already-embedded problem adds fresh dimensions instead of
    # reusing ``embedding_dummy_0``, ``embedding_dummy_1``, ... again.
    dummy_names: list[str] = []
    index = 0
    while len(dummy_names) < num_dummy_dimensions:
        candidate = f"embedding_dummy_{index}"
        if candidate not in existing_parameters:
            dummy_names.append(candidate)
        index += 1

    search_space = SearchSpace(
        parameters=[
            *existing_parameters.values(),
            *[
                RangeParameter(
                    name=name,
                    parameter_type=ParameterType.FLOAT,
                    lower=0,
                    upper=1,
                )
                for name in dummy_names
            ],
        ],
        parameter_constraints=problem.search_space.parameter_constraints,
    )

    # if problem name already has dimensionality in it, strip it
    def _is_dim_suffix(s: str) -> bool:
        # A dimensionality suffix is one or more ASCII digits followed by "d"
        # (e.g. "6d"). The length check rejects a bare "d" and avoids an
        # IndexError on empty segments produced by "__" in the name.
        return (
            len(s) >= 2
            and s.endswith("d")
            and all(char in "0123456789" for char in s[:-1])
        )

    orig_name_without_dimensionality = "_".join(
        [substr for substr in problem.name.split("_") if not _is_dim_suffix(substr)]
    )
    new_name = f"{orig_name_without_dimensionality}_{total_dimensionality}d"

    new_problem = copy.copy(problem)
    new_problem.name = new_name
    new_problem.search_space = search_space
    return new_problem
58 |
--------------------------------------------------------------------------------
/ax/utils/common/base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from __future__ import annotations
10 |
11 | import abc
12 |
13 | from ax.utils.common.equality import equality_typechecker, object_attribute_dicts_equal
14 |
15 |
class Base:
    """Base class (not a metaclass) for core Ax classes. Provides an equality
    check based on instance attributes and a `db_id` property for SQA storage.

    Note: because ``__eq__`` is defined here without ``__hash__``, Python sets
    ``__hash__`` to ``None`` for this class, so instances are unhashable unless
    a subclass defines ``__hash__``.
    """

    # Identifier of the corresponding SQA storage object; ``None`` by default
    # (presumably until the object is saved or loaded — confirm in storage code).
    _db_id: int | None = None

    @property
    def db_id(self) -> int | None:
        """The SQA storage identifier of this object, or ``None`` if unset."""
        return self._db_id

    @db_id.setter
    def db_id(self, db_id: int) -> None:
        self._db_id = db_id

    @equality_typechecker
    def __eq__(self, other: Base) -> bool:
        """Compare two objects by their instance ``__dict__`` attributes."""
        return object_attribute_dicts_equal(
            one_dict=self.__dict__, other_dict=other.__dict__
        )

    @equality_typechecker
    def _eq_skip_db_id_check(self, other: Base) -> bool:
        """Like ``__eq__``, but passes ``skip_db_id_check=True`` so that
        differences in ``db_id`` are presumably ignored in the comparison."""
        return object_attribute_dicts_equal(
            one_dict=self.__dict__, other_dict=other.__dict__, skip_db_id_check=True
        )
42 |
43 |
class SortableBase(Base, metaclass=abc.ABCMeta):
    """Extension of ``Base`` that also supports ordering: instances compare
    with ``<`` (and can therefore be passed to ``sorted``) by their
    ``_unique_id``.
    """

    @property
    @abc.abstractmethod
    def _unique_id(self) -> str:
        """Returns an identification string that can be used to uniquely
        identify this instance from others attached to the same parent
        object. For example, for ``Trials`` this can be their index,
        since that is unique w.r.t. the parent ``Experiment`` object.
        For ``GenerationNode``-s attached to a ``GenerationStrategy``,
        this can be their name since we ensure uniqueness of it upon
        ``GenerationStrategy`` instantiation.

        This method is needed to correctly update SQLAlchemy objects
        that appear as children of other objects, in lists or other
        sortable collections or containers.
        """

    def __lt__(self, other: SortableBase) -> bool:
        # Compares the string identifiers, so ordering is lexicographic
        # (e.g. "10" sorts before "9").
        return self._unique_id < other._unique_id
67 |
--------------------------------------------------------------------------------
/.github/workflows/deploy.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Deploy
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | tests-and-coverage-latest:
12 | name: Tests with latest BoTorch
13 | uses: ./.github/workflows/reusable_test.yml
14 | with:
15 | pinned_botorch: false
16 | secrets: inherit
17 |
18 | tests-and-coverage-pinned:
19 | name: Tests with pinned BoTorch
20 | uses: ./.github/workflows/reusable_test.yml
21 | with:
22 | pinned_botorch: true
23 | secrets: inherit
24 |
25 | publish-stable-website:
26 |
27 | needs: tests-and-coverage-pinned # only run if test step succeeds
28 | runs-on: ubuntu-latest
29 |
30 | steps:
31 | - uses: actions/checkout@v4
32 | - name: Set up Python
33 | uses: actions/setup-python@v5
34 | with:
35 | python-version: "3.10"
36 | - name: Install dependencies
37 | run: |
38 | # use stable Botorch
39 | pip install -e ".[tutorial]"
40 | - name: Publish latest website
41 | env:
42 | DOCUSAURUS_PUBLISH_TOKEN: ${{ secrets.DOCUSAURUS_PUBLISH_TOKEN }}
43 | run: |
44 | bash scripts/publish_site.sh -d -v ${{ github.event.release.tag_name }}
45 |
46 | deploy:
47 |
48 | needs: tests-and-coverage-pinned # only run if test step succeeds
49 | runs-on: ubuntu-latest
50 |
51 | steps:
52 | - uses: actions/checkout@v4
53 | - name: Set up Python
54 | uses: actions/setup-python@v5
55 | with:
56 | python-version: "3.10"
57 | - name: Install dependencies
58 | run: |
59 | # use stable Botorch
60 | pip install -e ".[dev,mysql,notebook]"
61 | pip install --upgrade build setuptools setuptools_scm wheel
62 | - name: Fetch all history for all tags and branches
63 | run: git fetch --prune --unshallow
64 | - name: Build wheel
65 | run: |
66 | python -m build --sdist --wheel
67 | - name: Deploy to PyPI
68 | uses: pypa/gh-action-pypi-publish@release/v1
69 | with:
70 | user: __token__
71 | password: ${{ secrets.PYPI_TOKEN }}
72 | verbose: true
73 |
--------------------------------------------------------------------------------
/ax/plot/tests/test_diagnostic.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import plotly.graph_objects as go
10 | from ax.modelbridge.cross_validation import cross_validate
11 | from ax.modelbridge.registry import Models
12 | from ax.plot.base import AxPlotConfig
13 | from ax.plot.diagnostic import (
14 | interact_cross_validation,
15 | interact_cross_validation_plotly,
16 | )
17 | from ax.utils.common.testutils import TestCase
18 | from ax.utils.testing.core_stubs import get_branin_experiment
19 | from ax.utils.testing.mock import mock_botorch_optimize
20 |
21 |
class DiagnosticTest(TestCase):
    """Smoke tests for the cross-validation diagnostic plots."""

    @mock_botorch_optimize
    def setUp(self) -> None:
        super().setUp()
        exp = get_branin_experiment(with_batch=True)
        exp.trials[0].run()
        self.model = Models.BOTORCH_MODULAR(
            # Model bridge kwargs
            experiment=exp,
            data=exp.fetch_data(),
        )

    def test_cross_validation(self) -> None:
        # Cross-validation refits the model for every fold, and its results do
        # not depend on the plotting options, so compute it only once.
        cv = cross_validate(self.model)
        label_dict = {"branin": "BrAnIn"}
        for autoset_axis_limits in [False, True]:
            with self.subTest(autoset_axis_limits=autoset_axis_limits):
                # Assert that each type of plot can be constructed successfully
                plot = interact_cross_validation_plotly(
                    cv, label_dict=label_dict, autoset_axis_limits=autoset_axis_limits
                )
                self.assertIsInstance(plot, go.Figure)
                range_args = plot.layout.updatemenus[0].buttons[0].args[1]
                x_range = range_args["xaxis.range"]
                y_range = range_args["yaxis.range"]
                if autoset_axis_limits:
                    # Each axis range should be an increasing [min, max] pair.
                    for axis_range in (x_range, y_range):
                        self.assertEqual(len(axis_range), 2)
                        self.assertLess(axis_range[0], axis_range[1])
                else:
                    self.assertIsNone(x_range)
                    self.assertIsNone(y_range)

                config = interact_cross_validation(
                    cv, label_dict=label_dict, autoset_axis_limits=autoset_axis_limits
                )
                self.assertIsInstance(config, AxPlotConfig)
56 |
--------------------------------------------------------------------------------
/website/tutorials.json:
--------------------------------------------------------------------------------
1 | {
2 | "API Comparison": [
3 | {
4 | "id": "gpei_hartmann_service",
5 | "title": "[RECOMMENDED] Service API"
6 | },
7 | {
8 | "id": "gpei_hartmann_loop",
9 | "title": "Loop API"
10 | },
11 | {
12 | "id": "gpei_hartmann_developer",
13 | "title": "Developer API"
14 | }
15 | ],
16 | "Deep Dives": [
17 | {
18 | "id": "visualizations",
19 | "title": "Visualizations"
20 | },
21 | {
22 | "id": "generation_strategy",
23 | "title": "Generation Strategy"
24 | },
25 | {
26 | "id": "scheduler",
27 | "title": "Scheduler"
28 | },
29 | {
30 | "id": "modular_botax",
31 | "title": "Modular `BoTorchModel`"
32 | }
33 | ],
34 | "Bayesian Optimization": [
35 | {
36 | "id": "tune_cnn_service",
37 | "title": "Hyperparameter Optimization for PyTorch"
38 | },
39 | {
40 | "id": "submitit",
41 | "title": "Hyperparameter Optimization on SLURM via SubmitIt"
42 | },
43 | {
44 | "id": "multi_task",
45 | "title": "Multi-Task Modeling"
46 | },
47 | {
48 | "id": "multiobjective_optimization",
49 | "title": "Multi-Objective Optimization"
50 | },
51 | {
52 | "id": "saasbo",
53 | "title": "High-Dimensional Bayesian Optimization with Sparse Axis-Aligned Subspaces (SAASBO)"
54 | },
55 | {
56 | "id": "saasbo_nehvi",
57 | "title": "Fully Bayesian, High-Dimensional, Multi-Objective Optimization"
58 | },
59 | {
60 | "id": "sebo",
61 | "title": "Sparsity Exploration Bayesian Optimization (SEBO)"
62 | },
63 | {
64 | "dir": "early_stopping",
65 | "id": "early_stopping",
66 | "title": "Trial-Level Early Stopping"
67 | },
68 | {
69 | "id": "gss",
70 | "title": "Global Stopping (Experiment-Level Early Stopping)"
71 | }
72 | ],
73 | "Field Experiments": [
74 | {
75 | "id": "factorial",
76 | "title": "Bandit Optimization"
77 | },
78 | {
79 | "dir": "human_in_the_loop",
80 | "id": "human_in_the_loop",
81 | "title": "Human-in-the-Loop Optimization"
82 | }
83 | ],
84 | "Integrating External Strategies": [
85 | {
86 | "id": "external_generation_node",
87 | "title": "RandomForest with ExternalGenerationNode"
88 | }
89 | ]
90 | }
91 |
--------------------------------------------------------------------------------
/ax/storage/tests/test_botorch_modular_registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # pyre-strict
7 |
8 | from ax.models.torch.botorch_modular.acquisition import Acquisition
9 | from ax.storage.botorch_modular_registry import (
10 | ACQUISITION_FUNCTION_REGISTRY,
11 | ACQUISITION_REGISTRY,
12 | MODEL_REGISTRY,
13 | register_acquisition,
14 | register_acquisition_function,
15 | register_model,
16 | REVERSE_ACQUISITION_FUNCTION_REGISTRY,
17 | REVERSE_ACQUISITION_REGISTRY,
18 | REVERSE_MODEL_REGISTRY,
19 | )
20 | from ax.utils.common.testutils import TestCase
21 | from botorch.acquisition.acquisition import AcquisitionFunction
22 | from botorch.models.model import Model
23 |
24 |
class NewModel(Model):
    """Placeholder BoTorch ``Model`` subclass used only to exercise
    ``register_model``; the tests expect it to be absent from the model
    registries until registered."""

    pass
27 |
28 |
class NewAcquisition(Acquisition):
    """Placeholder ``Acquisition`` subclass used only to exercise
    ``register_acquisition``; the tests expect it to be absent from the
    acquisition registries until registered."""

    pass
31 |
32 |
class NewAcquisitionFunction(AcquisitionFunction):
    """Placeholder BoTorch ``AcquisitionFunction`` subclass used only to
    exercise ``register_acquisition_function``; the tests expect it to be
    absent from the acquisition-function registries until registered."""

    pass
35 |
36 |
class RegisterNewClassTest(TestCase):
    """Tests that the ``register_*`` helpers add a class to both the forward
    (class -> key) and reverse (key -> class) registries.

    The registries are module-level state shared by the whole test process, so
    each test undoes its registration via ``addCleanup``. Otherwise the class
    would remain registered and the ``assertNotIn`` preconditions would fail on
    a re-run within the same process.
    """

    def test_register_model(self) -> None:
        self.assertNotIn(NewModel, MODEL_REGISTRY)
        self.assertNotIn(NewModel, REVERSE_MODEL_REGISTRY.values())
        register_model(NewModel)
        self.assertIn(NewModel, MODEL_REGISTRY)
        # The forward registry's value is the reverse registry's key.
        self.addCleanup(REVERSE_MODEL_REGISTRY.pop, MODEL_REGISTRY[NewModel], None)
        self.addCleanup(MODEL_REGISTRY.pop, NewModel, None)
        self.assertIn(NewModel, REVERSE_MODEL_REGISTRY.values())

    def test_register_acquisition(self) -> None:
        self.assertNotIn(NewAcquisition, ACQUISITION_REGISTRY)
        self.assertNotIn(NewAcquisition, REVERSE_ACQUISITION_REGISTRY.values())
        register_acquisition(NewAcquisition)
        self.assertIn(NewAcquisition, ACQUISITION_REGISTRY)
        self.addCleanup(
            REVERSE_ACQUISITION_REGISTRY.pop,
            ACQUISITION_REGISTRY[NewAcquisition],
            None,
        )
        self.addCleanup(ACQUISITION_REGISTRY.pop, NewAcquisition, None)
        self.assertIn(NewAcquisition, REVERSE_ACQUISITION_REGISTRY.values())

    def test_register_acquisition_function(self) -> None:
        self.assertNotIn(NewAcquisitionFunction, ACQUISITION_FUNCTION_REGISTRY)
        self.assertNotIn(
            NewAcquisitionFunction, REVERSE_ACQUISITION_FUNCTION_REGISTRY.values()
        )
        register_acquisition_function(NewAcquisitionFunction)
        self.assertIn(NewAcquisitionFunction, ACQUISITION_FUNCTION_REGISTRY)
        self.addCleanup(
            REVERSE_ACQUISITION_FUNCTION_REGISTRY.pop,
            ACQUISITION_FUNCTION_REGISTRY[NewAcquisitionFunction],
            None,
        )
        self.addCleanup(
            ACQUISITION_FUNCTION_REGISTRY.pop, NewAcquisitionFunction, None
        )
        self.assertIn(
            NewAcquisitionFunction, REVERSE_ACQUISITION_FUNCTION_REGISTRY.values()
        )
62 |
--------------------------------------------------------------------------------
/sphinx/source/service.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | ax.service
5 | ==========
6 |
7 | .. automodule:: ax.service
8 | .. currentmodule:: ax.service
9 |
10 |
11 | Ax Client
12 | ~~~~~~~~~
13 |
14 | .. automodule:: ax.service.ax_client
15 | :members:
16 | :undoc-members:
17 | :show-inheritance:
18 |
19 | Managed Loop
20 | ~~~~~~~~~~~~
21 |
22 | .. automodule:: ax.service.managed_loop
23 | :members:
24 | :undoc-members:
25 | :show-inheritance:
26 |
27 | Interactive Loop
28 | ~~~~~~~~~~~~~~~~
29 |
30 | .. automodule:: ax.service.interactive_loop
31 | :members:
32 | :undoc-members:
33 | :show-inheritance:
34 |
35 |
36 | Scheduler
37 | ~~~~~~~~~
38 |
39 | .. automodule:: ax.service.scheduler
40 | :members:
41 | :undoc-members:
42 | :show-inheritance:
43 |
44 | .. automodule:: ax.service.utils.scheduler_options
45 | :members:
46 | :undoc-members:
47 | :show-inheritance:
48 |
49 | Utils
50 | -----
51 |
52 | Analysis
53 | ~~~~~~~~
54 |
55 | .. automodule:: ax.service.utils.analysis_base
56 | :members:
57 | :undoc-members:
58 | :show-inheritance:
59 |
60 |
61 | Best Point Identification
62 | ~~~~~~~~~~~~~~~~~~~~~~~~~
63 |
64 | .. automodule:: ax.service.utils.best_point_mixin
65 | :members:
66 | :undoc-members:
67 | :show-inheritance:
68 |
69 |
70 | .. automodule:: ax.service.utils.best_point
71 | :members:
72 | :undoc-members:
73 | :show-inheritance:
74 |
75 |
76 | .. automodule:: ax.service.utils.best_point_utils
77 | :members:
78 | :undoc-members:
79 | :show-inheritance:
80 |
81 |
82 | Instantiation
83 | ~~~~~~~~~~~~~
84 |
85 | .. automodule:: ax.service.utils.instantiation
86 | :members:
87 | :undoc-members:
88 | :show-inheritance:
89 |
90 |
91 | Reporting
92 | ~~~~~~~~~
93 |
94 | .. automodule:: ax.service.utils.report_utils
95 | :members:
96 | :undoc-members:
97 | :show-inheritance:
98 |
99 |
100 | WithDBSettingsBase
101 | ~~~~~~~~~~~~~~~~~~
102 |
103 | .. automodule:: ax.service.utils.with_db_settings_base
104 | :members:
105 | :undoc-members:
106 | :show-inheritance:
107 |
108 |
109 | EarlyStopping
110 | ~~~~~~~~~~~~~
111 |
112 | .. automodule:: ax.service.utils.early_stopping
113 | :members:
114 | :undoc-members:
115 | :show-inheritance:
116 |
--------------------------------------------------------------------------------