├── .DS_Store ├── requirements.txt ├── benchmarks ├── .DS_Store ├── results │ └── .DS_Store ├── __pycache__ │ ├── benchmark_gs.cpython-38.pyc │ ├── sebo_gs_factory.cpython-38.pyc │ ├── baseline_gs_factory.cpython-38.pyc │ ├── external_l1_gs_factory.cpython-38.pyc │ └── internal_l1_gs_factory.cpython-38.pyc ├── baseline_gs_factory.py ├── sebo_gs_factory.py ├── ir_er_l0_gs_factory.py ├── internal_l1_gs_factory.py ├── benchmark_gs.py ├── external_l1_gs_factory.py ├── run_synthetic_benchmark.py └── regularized_bo.py ├── README.md ├── LICENSE ├── CONTRIBUTING.md ├── .gitignore ├── CODE_OF_CONDUCT.md └── sebo.ipynb /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SparseBO/HEAD/.DS_Store -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | matplotlib 3 | scipy 4 | botorch==0.9.2 5 | ax-platform==0.3.4 6 | -------------------------------------------------------------------------------- /benchmarks/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SparseBO/HEAD/benchmarks/.DS_Store -------------------------------------------------------------------------------- /benchmarks/results/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SparseBO/HEAD/benchmarks/results/.DS_Store -------------------------------------------------------------------------------- /benchmarks/__pycache__/benchmark_gs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SparseBO/HEAD/benchmarks/__pycache__/benchmark_gs.cpython-38.pyc 
-------------------------------------------------------------------------------- /benchmarks/__pycache__/sebo_gs_factory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SparseBO/HEAD/benchmarks/__pycache__/sebo_gs_factory.cpython-38.pyc -------------------------------------------------------------------------------- /benchmarks/__pycache__/baseline_gs_factory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SparseBO/HEAD/benchmarks/__pycache__/baseline_gs_factory.cpython-38.pyc -------------------------------------------------------------------------------- /benchmarks/__pycache__/external_l1_gs_factory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SparseBO/HEAD/benchmarks/__pycache__/external_l1_gs_factory.cpython-38.pyc -------------------------------------------------------------------------------- /benchmarks/__pycache__/internal_l1_gs_factory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SparseBO/HEAD/benchmarks/__pycache__/internal_l1_gs_factory.cpython-38.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SparseBO 2 | Code associated with paper ["Sparse Bayesian Optimization"](https://arxiv.org/abs/2203.01900) 3 | 4 | ## Installation 5 | To install the code clone the repo and install the dependencies as 6 | 7 | ``` 8 | git clone https://github.com/facebookresearch/SparseBO.git 9 | cd SparseBO 10 | python3 -m pip install -r requirements.txt 11 | ``` 12 | 13 | ## Reproducing the experiments 14 | This repository contains the code 
required to run the numerical experiments and the contextual Adaptive Bitrate (ABR) video playback experiment in the paper. 15 | 16 | ## License 17 | This code is MIT Licensed, as found in the LICENSE file. 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to SparseBO 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 
9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: https://code.facebook.com/cla 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to SparseBO, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | # Pycharm 92 | .idea/ 93 | 94 | .DS_Store 95 | benchmarks/.DS_Store 96 | 97 | # Projects 98 | benchmarks/results/*.json 99 | -------------------------------------------------------------------------------- /benchmarks/baseline_gs_factory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env fbpython 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Optional 8 | 9 | import torch 10 | from ax.core.data import Data 11 | from ax.core.experiment import Experiment 12 | from ax.modelbridge.factory import get_sobol 13 | from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy 14 | from ax.modelbridge.registry import Cont_X_trans, Models, Y_trans 15 | from ax.modelbridge.torch import TorchModelBridge 16 | 17 | TORCH_DEVICE = ( 18 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 19 | ) 20 | 21 | 22 | def get_saasbo_gs( 23 | num_sobol_trials: int, torch_device: Optional[torch.device] = TORCH_DEVICE 24 | ) -> GenerationStrategy: 25 | gs = GenerationStrategy( 26 | name="SAASBO", 27 | steps=[ 28 | GenerationStep(model=get_sobol, num_trials=num_sobol_trials), 29 | GenerationStep( 30 | model=get_SAASBO, 31 | num_trials=-1, 32 | model_kwargs={"torch_device": torch_device}, 33 | ), 34 | ], 35 | ) 36 | return gs 37 | 38 | 39 | def get_SAASBO( 40 | experiment: Experiment, 41 | data: Data, 42 | torch_device: Optional[torch.device] = TORCH_DEVICE, 43 | ) -> TorchModelBridge: 44 | """Instantiates a SAASBO model for single objective optimization.""" 45 | return Models.FULLYBAYESIAN( 46 | num_samples=256, 47 | warmup_steps=512, 48 | disable_progbar=True, 49 | experiment=experiment, 50 | data=data, 51 | search_space=experiment.search_space, 52 | transforms=Cont_X_trans + Y_trans, 53 | torch_dtype=torch.double, 54 | torch_device=torch_device, 55 | ) 56 | -------------------------------------------------------------------------------- /benchmarks/sebo_gs_factory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env fbpython 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 

from typing import Optional

import torch
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.sebo import SEBOAcquisition
from ax.models.torch.botorch_modular.surrogate import Surrogate
from botorch.acquisition.multi_objective.monte_carlo import (
    qNoisyExpectedHypervolumeImprovement,
)
from botorch.models import FixedNoiseGP, SaasFullyBayesianSingleTaskGP
from baseline_gs_factory import TORCH_DEVICE
from torch import Tensor


def get_sebo_gs(
    sparse_point: Tensor,
    penalty_name: str,
    num_sobol_trials: int,
    gp_model_name: str,
    sparsity_threshold: Optional[float] = None,
    torch_device: Optional[torch.device] = TORCH_DEVICE,
) -> GenerationStrategy:
    """Build a Sobol -> SEBO (qNEHVI + sparsity penalty) generation strategy.

    Args:
        sparse_point: Target point handed to ``SEBOAcquisition`` as
            ``target_point``. A tensor or a plain sequence of numbers is
            accepted; it is copied to ``torch_device`` as ``torch.double``.
        penalty_name: Value forwarded as the ``penalty`` acquisition option
            (the benchmark driver passes ``"L0_norm"`` or ``"L1_norm"``).
        num_sobol_trials: Number of trials of the initial Sobol step.
        gp_model_name: ``"SAAS"`` selects ``SaasFullyBayesianSingleTaskGP``;
            ``"GP"`` selects ``FixedNoiseGP`` with batched models disabled.
        sparsity_threshold: Forwarded as the ``sparsity_threshold``
            acquisition option. Defaults to the last dimension of
            ``sparse_point``.
        torch_device: Device used by the model bridge and the target point.

    Returns:
        The two-step ``GenerationStrategy``.

    Raises:
        ValueError: If ``gp_model_name`` is neither ``"SAAS"`` nor ``"GP"``.
    """
    if gp_model_name == "SAAS":
        surrogate = Surrogate(SaasFullyBayesianSingleTaskGP)
    elif gp_model_name == "GP":
        surrogate = Surrogate(
            botorch_model_class=FixedNoiseGP, allow_batched_models=False
        )
    else:
        # Previously this fell through and failed later with an opaque
        # UnboundLocalError on ``surrogate``.
        raise ValueError(
            f"Unknown gp_model_name {gp_model_name!r}; expected 'SAAS' or 'GP'."
        )

    # ``torch.tensor(existing_tensor)`` copy-constructs with a UserWarning;
    # ``as_tensor`` + ``detach().clone()`` yields the same independent copy
    # without the warning and also accepts plain sequences.
    target_point = (
        torch.as_tensor(sparse_point, device=torch_device, dtype=torch.double)
        .detach()
        .clone()
    )

    if sparsity_threshold is None:
        sparsity_threshold = target_point.shape[-1]

    return GenerationStrategy(
        # Name reflects the surrogate actually used (it used to be hard-coded
        # to ``..._SAAS`` even for the "GP" surrogate).
        name=f"NEHVI_{penalty_name}_MOO_{gp_model_name}",
        steps=[
            GenerationStep(  # Initialization step
                model=Models.SOBOL,
                num_trials=num_sobol_trials,
            ),
            GenerationStep(  # BayesOpt step
                model=Models.BOTORCH_MODULAR,
                num_trials=-1,
                model_kwargs={
                    "surrogate": surrogate,
                    "acquisition_class": SEBOAcquisition,
                    "botorch_acqf_class": qNoisyExpectedHypervolumeImprovement,
                    "torch_device": torch_device,
                    "acquisition_options": {
                        "penalty": penalty_name,
                        "target_point": target_point,
                        "sparsity_threshold": sparsity_threshold,
                        "prune_baseline": False,
                    },
                },
            ),
        ],
    )
-------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /benchmarks/ir_er_l0_gs_factory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env fbpython 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional 8 | 9 | import torch 10 | from regularized_bo import ExternalRegularizedL0, InternalRegularizedL0 11 | from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy 12 | from ax.modelbridge.registry import Models 13 | from ax.models.torch.botorch_modular.surrogate import Surrogate 14 | from botorch.acquisition import qNoisyExpectedImprovement 15 | from botorch.models import FixedNoiseGP, SaasFullyBayesianSingleTaskGP 16 | from baseline_gs_factory import TORCH_DEVICE 17 | from torch import Tensor 18 | 19 | 20 | def get_ir_l0_gs( 21 | sparse_point: Tensor, 22 | num_sobol_trials: int, 23 | gp_model_name: str, 24 | regularization_parameter: float, 25 | torch_device: Optional[torch.device] = TORCH_DEVICE, 26 | ): 27 | if gp_model_name == "SAAS": 28 | surrogate = Surrogate(SaasFullyBayesianSingleTaskGP) 29 | elif gp_model_name == "GP": 30 | surrogate = Surrogate(botorch_model_class=FixedNoiseGP) 31 | 32 | gs = GenerationStrategy( 33 | name=f"L0_internal_{gp_model_name}", 34 | steps=[ 35 | GenerationStep( # Initialization step 36 | model=Models.SOBOL, 37 | num_trials=num_sobol_trials, 38 | ), 39 | GenerationStep( # 
BayesOpt step 40 | model=Models.BOTORCH_MODULAR, 41 | num_trials=-1, 42 | model_kwargs={ 43 | "surrogate": surrogate, 44 | "acquisition_class": InternalRegularizedL0, 45 | "botorch_acqf_class": qNoisyExpectedImprovement, 46 | "torch_device": torch_device, 47 | "acquisition_options": { 48 | "target_point": torch.tensor( 49 | sparse_point, device=torch_device, dtype=torch.double 50 | ), 51 | "regularization_parameter": regularization_parameter, 52 | }, 53 | }, 54 | ), 55 | ], 56 | ) 57 | return gs 58 | 59 | 60 | def get_er_l0_gs( 61 | sparse_point: Tensor, 62 | num_sobol_trials: int, 63 | gp_model_name: str, 64 | regularization_parameter: float, 65 | torch_device: Optional[torch.device] = TORCH_DEVICE, 66 | ): 67 | if gp_model_name == "SAAS": 68 | surrogate = Surrogate(SaasFullyBayesianSingleTaskGP) 69 | elif gp_model_name == "GP": 70 | surrogate = Surrogate(botorch_model_class=FixedNoiseGP) 71 | 72 | gs = GenerationStrategy( 73 | name=f"L0_external_{gp_model_name}", 74 | steps=[ 75 | GenerationStep( # Initialization step 76 | model=Models.SOBOL, 77 | num_trials=num_sobol_trials, 78 | ), 79 | GenerationStep( # BayesOpt step 80 | model=Models.BOTORCH_MODULAR, 81 | num_trials=-1, 82 | model_kwargs={ 83 | "surrogate": surrogate, 84 | "acquisition_class": ExternalRegularizedL0, 85 | "botorch_acqf_class": qNoisyExpectedImprovement, 86 | "torch_device": torch_device, 87 | "acquisition_options": { 88 | "target_point": torch.tensor( 89 | sparse_point, device=torch_device, dtype=torch.double 90 | ), 91 | "regularization_parameter": regularization_parameter, 92 | }, 93 | }, 94 | ), 95 | ], 96 | ) 97 | return gs 98 | -------------------------------------------------------------------------------- /benchmarks/internal_l1_gs_factory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env fbpython 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
from typing import Any, Optional, Tuple

import torch
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Cont_X_trans, Models, Y_trans
from ax.modelbridge.torch import TorchModelBridge
from ax.models.torch.botorch_defaults import _get_acquisition_func
from ax.models.torch.fully_bayesian import get_fully_bayesian_acqf
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.penalized import L1PenaltyObjective, PenalizedMCObjective
from botorch.models.model import Model
from baseline_gs_factory import TORCH_DEVICE
from torch import Tensor


def _ir_l1_generation_strategy(
    name, model_factory, num_sobol_trials, sparse_point,
    regularization_parameter, torch_device,
) -> GenerationStrategy:
    # Shared builder: a Sobol step followed by an open-ended IR-L1 model step.
    sobol_step = GenerationStep(model=get_sobol, num_trials=num_sobol_trials)
    penalized_step = GenerationStep(
        model=functools.partial(
            model_factory,
            sparse_point=sparse_point,
            regularization_parameter=regularization_parameter,
        ),
        num_trials=-1,
        model_kwargs={"torch_device": torch_device},
    )
    return GenerationStrategy(name=name, steps=[sobol_step, penalized_step])


def _ir_l1_model_gen_options(sparse_point, regularization_parameter):
    # Shared builder for the ``default_model_gen_options`` of both bridges.
    return {
        "acquisition_function_kwargs": {
            "chebyshev_scalarization": False,
            "sequential": True,
            "regularization_parameter": regularization_parameter,
            "sparse_point": sparse_point,
        },
    }


def get_ir_l1_saas_gs(
    num_sobol_trials: int,
    sparse_point: Tensor,
    regularization_parameter: float,
    torch_device: Optional[torch.device] = TORCH_DEVICE,
) -> GenerationStrategy:
    """Return the "L1_internal_SAAS" strategy.

    The strategy runs ``num_sobol_trials`` Sobol trials and then switches to
    ``get_L1_SAAS_internal_penalized`` with ``sparse_point`` and
    ``regularization_parameter`` pre-bound; ``torch_device`` is passed through
    ``model_kwargs``.
    """
    return _ir_l1_generation_strategy(
        "L1_internal_SAAS",
        get_L1_SAAS_internal_penalized,
        num_sobol_trials,
        sparse_point,
        regularization_parameter,
        torch_device,
    )


def get_ir_l1_gp_gs(
    num_sobol_trials: int,
    sparse_point: Tensor,
    regularization_parameter: float,
    torch_device: Optional[torch.device] = TORCH_DEVICE,
) -> GenerationStrategy:
    """Return the "L1_internal_GP" strategy.

    Same layout as ``get_ir_l1_saas_gs`` but the second step uses
    ``get_L1_internal_penalized`` as its model factory.
    """
    return _ir_l1_generation_strategy(
        "L1_internal_GP",
        get_L1_internal_penalized,
        num_sobol_trials,
        sparse_point,
        regularization_parameter,
        torch_device,
    )


def get_L1_SAAS_internal_penalized(
    experiment: Experiment,
    data: Data,
    sparse_point: Tensor,
    # NOTE(review): annotated Optional[Tensor] but the default is a float and
    # the sibling factory annotates this as float -- confirm intended type.
    regularization_parameter: Optional[Tensor] = 0.01,
    torch_device: Optional[torch.device] = TORCH_DEVICE,
) -> TorchModelBridge:
    """Instantiates a model using SAAS GP with IR-L1 ACQF for single objective optimization."""
    gen_options = _ir_l1_model_gen_options(sparse_point, regularization_parameter)
    return Models.FULLYBAYESIAN(
        experiment=experiment,
        search_space=experiment.search_space,
        data=data,
        acqf_constructor=get_fully_bayesian_acqf_nei_l1_internal,
        use_saas=True,
        num_samples=256,
        warmup_steps=512,
        transforms=Cont_X_trans + Y_trans,
        default_model_gen_options=gen_options,
        disable_progbar=True,
        torch_dtype=torch.double,
        torch_device=torch_device,
    )


def get_L1_internal_penalized(
    experiment: Experiment,
    data: Data,
    sparse_point: Tensor,
    regularization_parameter: float = 0.01,
    torch_device: Optional[torch.device] = TORCH_DEVICE,
) -> TorchModelBridge:
    """Instantiates a model using standard GP with IR-L1 ACQF for single objective optimization."""
    gen_options = _ir_l1_model_gen_options(sparse_point, regularization_parameter)
    return Models.BOTORCH(
        acqf_constructor=get_NEI_internal_L1_penalized,  # pyre-ignore
        default_model_gen_options=gen_options,
        experiment=experiment,
        data=data,
        search_space=experiment.search_space,
        transforms=Cont_X_trans + Y_trans,
        torch_dtype=torch.double,
        torch_device=torch_device,
    )


def get_fully_bayesian_acqf_nei_l1_internal(
    model: Model,
    objective_weights: Tensor,
    outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
    X_observed: Optional[Tensor] = None,
    X_pending: Optional[Tensor] = None,
    **kwargs: Any,
) -> AcquisitionFunction:
    """Delegate to ``get_fully_bayesian_acqf`` with the L1-penalized qNEI constructor.

    All arguments, including ``**kwargs``, are forwarded unchanged; only
    ``acqf_constructor`` is fixed to ``get_NEI_internal_L1_penalized``.
    """
    forwarded = dict(
        model=model,
        objective_weights=objective_weights,
        outcome_constraints=outcome_constraints,
        X_observed=X_observed,
        X_pending=X_pending,
        acqf_constructor=get_NEI_internal_L1_penalized,
    )
    return get_fully_bayesian_acqf(**forwarded, **kwargs)


def get_NEI_internal_L1_penalized(
    model: Model,
    objective_weights: Tensor,
    sparse_point: Tensor,
    regularization_parameter: float,
    outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
    X_observed: Optional[Tensor] = None,
    X_pending: Optional[Tensor] = None,
    **kwargs: Any,
) -> AcquisitionFunction:
    r"""Instantiates a qNoisyExpectedImprovement acquisition function,
    in which the objective function is penalized by L1 penalty.

    The penalty is an ``L1PenaltyObjective`` built with
    ``init_point=sparse_point`` and applied through ``PenalizedMCObjective``
    with weight ``regularization_parameter``.
    """
    objective_kwargs = {
        "penalty_objective": L1PenaltyObjective(init_point=sparse_point),
        "regularization_parameter": regularization_parameter,
    }
    return _get_acquisition_func(
        model=model,
        acquisition_function_name="qNEI",
        objective_weights=objective_weights,
        outcome_constraints=outcome_constraints,
        X_observed=X_observed,
        X_pending=X_pending,
        mc_objective=PenalizedMCObjective,
        constrained_mc_objective=None,
        mc_objective_kwargs=objective_kwargs,
        **kwargs,
    )
-------------------------------------------------------------------------------- /benchmarks/benchmark_gs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env fbpython 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from logging import Logger 8 | from typing import Any, Dict, List, Optional 9 | 10 | import torch 11 | from ax.modelbridge.factory import get_GPEI, get_sobol 12 | from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy 13 | from ax.modelbridge.strategies.rembo import REMBOStrategy 14 | from ax.utils.common.logger import get_logger 15 | from baseline_gs_factory import ( 16 | get_saasbo_gs, 17 | TORCH_DEVICE, 18 | ) 19 | from external_l1_gs_factory import get_er_l1_gp_gs, get_er_l1_saas_gs 20 | from internal_l1_gs_factory import get_ir_l1_gp_gs, get_ir_l1_saas_gs 21 | from ir_er_l0_gs_factory import get_er_l0_gs, get_ir_l0_gs 22 | from sebo_gs_factory import get_sebo_gs 23 | 24 | logger: Logger = get_logger(__name__) 25 | 26 | 27 | def get_generation_strategy( 28 | strategy_name: str, 29 | num_sobol_trials: int, 30 | sparse_point: List, 31 | strategies_args: Optional[Dict[str, Any]] = None, 32 | ) -> GenerationStrategy: 33 | 34 | if strategies_args.get("torch_device", None) is not None: 35 | torch_device = torch.device(strategies_args.get("torch_device")) 36 | else: 37 | torch_device = TORCH_DEVICE 38 | sparse_point = torch.tensor(sparse_point, device=torch_device, dtype=torch.double) 39 | 40 | logger.info(f"torch device: {torch_device}") 41 | 42 | if strategy_name == "Sobol": 43 | gs = GenerationStrategy( 44 | name="Sobol", steps=[GenerationStep(model=get_sobol, num_trials=-1)] 45 | ) 46 | elif strategy_name == "GPEI": 47 | gs = GenerationStrategy( 48 | name="GPEI", 49 | steps=[ 50 | GenerationStep(model=get_sobol, 
num_trials=num_sobol_trials), 51 | GenerationStep(model=get_GPEI, num_trials=-1), 52 | ], 53 | ) 54 | elif strategy_name == "SAASBO": 55 | gs = get_saasbo_gs(num_sobol_trials=num_sobol_trials, torch_device=torch_device) 56 | elif strategy_name == "REMBO": 57 | gs = REMBOStrategy( 58 | D=len(sparse_point), 59 | d=strategies_args.get(strategy_name, {}).get("d", 8), 60 | init_per_proj=num_sobol_trials, 61 | ) 62 | elif "L1_internal_SAAS" in strategy_name: 63 | gs = get_ir_l1_saas_gs( 64 | num_sobol_trials=num_sobol_trials, 65 | sparse_point=sparse_point, 66 | regularization_parameter=strategies_args.get( 67 | "regularization_parameter", 0.1 68 | ), 69 | torch_device=torch_device, 70 | ) 71 | elif "L1_internal_GP" in strategy_name: 72 | gs = get_ir_l1_gp_gs( 73 | num_sobol_trials=num_sobol_trials, 74 | sparse_point=sparse_point, 75 | regularization_parameter=strategies_args.get( 76 | "regularization_parameter", 0.1 77 | ), 78 | torch_device=torch_device, 79 | ) 80 | elif "L0_internal_SAAS" in strategy_name: 81 | gs = get_ir_l0_gs( 82 | sparse_point=sparse_point, 83 | num_sobol_trials=num_sobol_trials, 84 | gp_model_name="SAAS", 85 | regularization_parameter=strategies_args.get( 86 | "regularization_parameter", 0.1 87 | ), 88 | torch_device=torch_device, 89 | ) 90 | elif "L0_internal_GP" in strategy_name: 91 | gs = get_ir_l0_gs( 92 | sparse_point=sparse_point, 93 | num_sobol_trials=num_sobol_trials, 94 | gp_model_name="GP", 95 | regularization_parameter=strategies_args.get( 96 | "regularization_parameter", 0.1 97 | ), 98 | torch_device=torch_device, 99 | ) 100 | elif "L1_external_SAAS" in strategy_name: 101 | gs = get_er_l1_saas_gs( 102 | num_sobol_trials=num_sobol_trials, 103 | sparse_point=sparse_point, 104 | regularization_parameter=strategies_args.get( 105 | "regularization_parameter", 0.1 106 | ), 107 | torch_device=torch_device, 108 | ) 109 | elif "L1_external_GP" in strategy_name: 110 | gs = get_er_l1_gp_gs( 111 | num_sobol_trials=num_sobol_trials, 112 | 
sparse_point=sparse_point, 113 | regularization_parameter=strategies_args.get( 114 | "regularization_parameter", 0.1 115 | ), 116 | torch_device=torch_device, 117 | ) 118 | elif "L0_external_SAAS" in strategy_name: 119 | gs = get_er_l0_gs( 120 | sparse_point=sparse_point, 121 | num_sobol_trials=num_sobol_trials, 122 | gp_model_name="SAAS", 123 | regularization_parameter=strategies_args.get( 124 | "regularization_parameter", 0.1 125 | ), 126 | torch_device=torch_device, 127 | ) 128 | elif "L0_external_GP" in strategy_name: 129 | gs = get_er_l0_gs( 130 | sparse_point=sparse_point, 131 | num_sobol_trials=num_sobol_trials, 132 | gp_model_name="GP", 133 | regularization_parameter=strategies_args.get( 134 | "regularization_parameter", 0.1 135 | ), 136 | torch_device=torch_device, 137 | ) 138 | elif "NEHVI_L1_MOO_SAAS" in strategy_name: 139 | gs = get_sebo_gs( 140 | sparse_point=sparse_point, 141 | penalty_name="L1_norm", 142 | num_sobol_trials=num_sobol_trials, 143 | gp_model_name="SAAS", 144 | sparsity_threshold=strategies_args.get( 145 | "sparsity_threshold", 146 | sparse_point.shape[-1], 147 | ), 148 | torch_device=torch_device, 149 | ) 150 | elif "NEHVI_L0_MOO_SAAS" in strategy_name: 151 | gs = get_sebo_gs( 152 | sparse_point=sparse_point, 153 | penalty_name="L0_norm", 154 | num_sobol_trials=num_sobol_trials, 155 | gp_model_name="SAAS", 156 | sparsity_threshold=strategies_args.get( 157 | "sparsity_threshold", 158 | sparse_point.shape[-1], 159 | ), 160 | torch_device=torch_device, 161 | ) 162 | elif "NEHVI_L1_MOO_GP" in strategy_name: 163 | gs = get_sebo_gs( 164 | sparse_point=sparse_point, 165 | penalty_name="L1_norm", 166 | num_sobol_trials=num_sobol_trials, 167 | gp_model_name="GP", 168 | sparsity_threshold=strategies_args.get( 169 | "sparsity_threshold", 170 | sparse_point.shape[-1], 171 | ), 172 | torch_device=torch_device, 173 | ) 174 | elif "NEHVI_L0_MOO_GP" in strategy_name: 175 | gs = get_sebo_gs( 176 | sparse_point=sparse_point, 177 | penalty_name="L0_norm", 
def _build_er_l1_gs(
    name: str,
    model_factory: Any,
    num_sobol_trials: int,
    sparse_point: Tensor,
    regularization_parameter: float,
    torch_device: Optional[torch.device],
) -> GenerationStrategy:
    """Assemble a two-step generation strategy: Sobol, then an ER-L1 model.

    Args:
        name: Name given to the resulting ``GenerationStrategy``.
        model_factory: Callable building the model bridge for the second step
            (one of the ``get_L1_*_external_penalized`` functions).
        num_sobol_trials: Number of quasi-random initialization trials.
        sparse_point: Point towards which the L1 penalty shrinks candidates.
        regularization_parameter: Weight of the L1 penalty.
        torch_device: Device forwarded to the model factory via ``model_kwargs``.
    """
    return GenerationStrategy(
        name=name,
        steps=[
            GenerationStep(model=get_sobol, num_trials=num_sobol_trials),
            GenerationStep(
                model=functools.partial(
                    model_factory,
                    sparse_point=sparse_point,
                    regularization_parameter=regularization_parameter,
                ),
                # -1 means "no limit": stay on this step for all later trials.
                num_trials=-1,
                model_kwargs={"torch_device": torch_device},
            ),
        ],
    )


def _add_external_l1_penalty(
    raw_acqf: AcquisitionFunction,
    sparse_point: Tensor,
    regularization_parameter: float,
    X_observed: Optional[Tensor] = None,
) -> AcquisitionFunction:
    """Wrap ``raw_acqf`` in a ``PenalizedAcquisitionFunction`` with an L1 penalty.

    ``sparse_point`` is coerced to a tensor first, so that a plain sequence is
    accepted. When ``X_observed`` is available the tensor is also placed on the
    same dtype/device, so the penalty can be evaluated against candidates
    without a device mismatch.
    """
    if X_observed is not None:
        init_point = torch.as_tensor(
            sparse_point, dtype=X_observed.dtype, device=X_observed.device
        )
    else:
        init_point = torch.as_tensor(sparse_point)
    return PenalizedAcquisitionFunction(
        raw_acqf=raw_acqf,
        penalty_func=L1Penalty(init_point=init_point),
        regularization_parameter=regularization_parameter,
    )


def get_er_l1_saas_gs(
    num_sobol_trials: int,
    sparse_point: Tensor,
    regularization_parameter: float,
    torch_device: Optional[torch.device] = TORCH_DEVICE,
) -> GenerationStrategy:
    """Return the "L1_external_SAAS" strategy: Sobol, then SAAS GP with ER-L1."""
    return _build_er_l1_gs(
        name="L1_external_SAAS",
        model_factory=get_L1_SAAS_external_penalized,
        num_sobol_trials=num_sobol_trials,
        sparse_point=sparse_point,
        regularization_parameter=regularization_parameter,
        torch_device=torch_device,
    )


def get_er_l1_gp_gs(
    num_sobol_trials: int,
    sparse_point: Tensor,
    regularization_parameter: float,
    torch_device: Optional[torch.device] = TORCH_DEVICE,
) -> GenerationStrategy:
    """Return the "L1_external_GP" strategy: Sobol, then standard GP with ER-L1."""
    return _build_er_l1_gs(
        name="L1_external_GP",
        model_factory=get_L1_external_penalized,
        num_sobol_trials=num_sobol_trials,
        sparse_point=sparse_point,
        regularization_parameter=regularization_parameter,
        torch_device=torch_device,
    )


def get_L1_SAAS_external_penalized(
    experiment: Experiment,
    data: Data,
    sparse_point: Tensor,
    regularization_parameter: float = 0.01,
    torch_device: Optional[torch.device] = TORCH_DEVICE,
) -> TorchModelBridge:
    """Instantiates a model using SAAS GP with ER-L1 ACQF for single objective optimization.

    ``sparse_point`` and ``regularization_parameter`` are forwarded to
    ``get_fully_bayesian_acqf_nei_l1_external`` through
    ``acquisition_function_kwargs``.
    """
    return Models.FULLYBAYESIAN(
        experiment=experiment,
        search_space=experiment.search_space,
        data=data,
        acqf_constructor=get_fully_bayesian_acqf_nei_l1_external,
        num_samples=256,
        warmup_steps=512,
        disable_progbar=True,
        transforms=Cont_X_trans + Y_trans,
        default_model_gen_options={
            "acquisition_function_kwargs": {
                "chebyshev_scalarization": False,
                "sequential": True,
                "regularization_parameter": regularization_parameter,
                "sparse_point": sparse_point,
            },
        },
        torch_dtype=torch.double,
        torch_device=torch_device,
    )


def get_L1_external_penalized(
    experiment: Experiment,
    data: Data,
    sparse_point: Tensor,  # TODO: make this ObservationFeatures
    regularization_parameter: float = 0.01,
    torch_device: Optional[torch.device] = TORCH_DEVICE,
) -> TorchModelBridge:
    """Instantiates a model using standard GP with ER-L1 ACQF for single objective optimization.

    ``sparse_point`` and ``regularization_parameter`` are forwarded to
    ``get_NEI_external_L1_penalized`` through ``acquisition_function_kwargs``.
    """
    return Models.BOTORCH(
        acqf_constructor=get_NEI_external_L1_penalized,
        default_model_gen_options={
            "acquisition_function_kwargs": {
                "chebyshev_scalarization": False,
                "sequential": True,
                "regularization_parameter": regularization_parameter,
                "sparse_point": sparse_point,
            },
        },
        experiment=experiment,
        data=data,
        search_space=experiment.search_space,
        transforms=Cont_X_trans + Y_trans,
        torch_dtype=torch.double,
        torch_device=torch_device,
    )


def get_fully_bayesian_acqf_nei_l1_external(
    model: Model,
    objective_weights: Tensor,
    sparse_point: Tensor,
    regularization_parameter: float,
    outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
    X_observed: Optional[Tensor] = None,
    X_pending: Optional[Tensor] = None,
    **kwargs: Any,
) -> AcquisitionFunction:
    r"""Instantiates a penalized qNoisyExpectedImprovement acquisition function,
    in which the penalty term is a L1 penalty added externally to the ACQF.

    The raw acquisition function comes from ``get_fully_bayesian_acqf``; extra
    ``kwargs`` are passed through to it unchanged.
    """
    # default acqf_constructor is get_NEI
    raw_acqf = get_fully_bayesian_acqf(
        model=model,
        objective_weights=objective_weights,
        outcome_constraints=outcome_constraints,
        X_observed=X_observed,
        X_pending=X_pending,
        **kwargs,
    )
    return _add_external_l1_penalty(
        raw_acqf=raw_acqf,
        sparse_point=sparse_point,
        regularization_parameter=regularization_parameter,
        X_observed=X_observed,
    )


def get_NEI_external_L1_penalized(
    model: Model,
    objective_weights: Tensor,
    sparse_point: Tensor,
    regularization_parameter: float,
    outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
    X_observed: Optional[Tensor] = None,
    X_pending: Optional[Tensor] = None,
    **kwargs: Any,
) -> AcquisitionFunction:
    r"""Instantiates a PenalizedAcquisitionFunction that adds L1 penalty
    to qNoisyExpectedImprovement acquisition function. External means the L1 penalty
    is added outside of the NEI acquisition function.

    Extra ``kwargs`` are passed through to ``_get_acquisition_func`` unchanged.
    """
    # NOTE: this is the same function as get_NEI_external_L1_penalized in
    # ax.fb.models.torch.parego_l1_penalized
    raw_acqf = _get_acquisition_func(
        model=model,
        acquisition_function_name="qNEI",
        objective_weights=objective_weights,
        outcome_constraints=outcome_constraints,
        X_observed=X_observed,
        X_pending=X_pending,
        **kwargs,
    )
    return _add_external_l1_penalty(
        raw_acqf=raw_acqf,
        sparse_point=sparse_point,
        regularization_parameter=regularization_parameter,
        X_observed=X_observed,
    )
def branin_augment(x_vec, augment_dim):
    """Evaluate the Branin function embedded in an ``augment_dim``-dimensional unit cube.

    Args:
        x_vec: Sequence of length ``augment_dim`` with entries in ``[0, 1]``.
        augment_dim: Expected length of ``x_vec``.

    Returns:
        The Branin value computed from the first two coordinates.
    """
    assert len(x_vec) == augment_dim
    # Only dimensions 0 and 1 affect the value of the function; they are
    # rescaled from [0, 1] to the usual Branin domain [-5, 10] x [0, 15].
    x1 = 15 * x_vec[0] - 5
    x2 = 15 * x_vec[1]
    t1 = x2 - 5.1 / (4 * math.pi**2) * x1**2 + 5 / math.pi * x1 - 6
    t2 = 10 * (1 - 1 / (8 * math.pi)) * np.cos(x1)
    return t1**2 + t2 + 10


def hartmann6_augment(x_vec, augment_dim):
    """Evaluate Hartmann6 on the first six coordinates of an augmented input.

    Args:
        x_vec: Sequence of length ``augment_dim``.
        augment_dim: Expected length of ``x_vec``.
    """
    assert len(x_vec) == augment_dim
    return hartmann6.f(np.array(x_vec[:6]))


def nnz_exact(x: List[float], sparse_point: List[float]):
    """Count the coordinates of ``x`` that differ exactly from ``sparse_point``."""
    return len(x) - (np.array(x) == np.array(sparse_point)).sum()


def _write_json(path: str, payload: Any) -> None:
    """Serialize ``payload`` as JSON to ``path``, creating the parent directory."""
    import os  # local import: only needed for the output-directory check

    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(path, "w") as fout:
        json.dump(payload, fout)


def _run_single_objective_benchmark(
    problem_name: str,
    experiment_name: str,
    objective_fn: Any,
    objective_threshold: float,
    strategy_name: str,
    irep: int,
    num_sobol_trials: int,
    num_trials: int,
    augment_dim: int,
    strategies_args: Optional[Dict[str, Any]],
) -> str:
    """Run one replicate of a synthetic benchmark and persist the experiment.

    Args:
        problem_name: Tag used in the output file name (e.g. ``"branin"``).
        experiment_name: Name passed to ``AxClient.create_experiment``.
        objective_fn: Callable ``(x_vec=..., augment_dim=...) -> float`` to minimize.
        objective_threshold: Objective threshold used for "MOO" strategies.
        strategy_name: Name understood by ``get_generation_strategy``.
        irep: Replicate index, used in the output file name.
        num_sobol_trials: Number of Sobol initialization trials.
        num_trials: Total number of trials to run.
        augment_dim: Dimensionality of the search space.
        strategies_args: Optional extra strategy arguments; ``None`` is treated
            as an empty dict.

    Returns:
        The JSON-encoded experiment as a string.
    """
    # The default is None, but the MOO branch below calls .get on it.
    strategies_args = strategies_args or {}

    # set zero as the baseline to shrink towards
    sparse_point = [0 for _ in range(augment_dim)]
    gs = get_generation_strategy(
        strategy_name=strategy_name,
        sparse_point=sparse_point,
        num_sobol_trials=num_sobol_trials,
        strategies_args=strategies_args,
    )
    penalty_name = "L0_norm" if "L0" in strategy_name else "L1_norm"

    axc = AxClient(generation_strategy=gs)

    experiment_parameters = [
        {
            "name": f"parameter_{i}",
            "type": "range",
            "bounds": [0, 1],
            "value_type": "float",
            "log_scale": False,
        }
        for i in range(augment_dim)
    ]

    if "MOO" in strategy_name:
        sparsity_threshold = strategies_args.get("sparsity_threshold", augment_dim)
        objective_metrics = {
            "objective": ObjectiveProperties(
                minimize=False, threshold=objective_threshold
            ),
            penalty_name: ObjectiveProperties(
                minimize=True, threshold=sparsity_threshold
            ),
        }
    else:
        objective_metrics = {
            "objective": ObjectiveProperties(minimize=False),
        }

    axc.create_experiment(
        name=experiment_name,
        parameters=experiment_parameters,
        objectives=objective_metrics,
    )

    def evaluation(parameters):
        # put parameters into 1-D array
        x = [parameters.get(param["name"]) for param in experiment_parameters]
        res = objective_fn(x_vec=x, augment_dim=augment_dim)
        if penalty_name == "L0_norm":
            penalty_value = nnz_exact(x, sparse_point)
        else:
            # L1 distance to the sparse point (identical to ||x||_1 for the
            # all-zero sparse point used here).
            penalty_value = np.linalg.norm(
                np.array(x) - np.array(sparse_point), ord=1
            )
        return {
            # flip the sign to maximize
            "objective": (-res, 0.0),
            penalty_name: (penalty_value, 0.0),
        }

    for _ in range(num_trials):
        parameters, trial_index = axc.get_next_trial()
        axc.complete_trial(trial_index=trial_index, raw_data=evaluation(parameters))

    res = json.dumps(object_to_json(axc.experiment))
    # NOTE: ``res`` is already a JSON string, so the file holds a JSON-encoded
    # string (double encoding). Kept as-is so existing result readers keep working.
    _write_json(
        f"results/synthetic_{problem_name}_{strategy_name}_rep_{irep}.json", res
    )
    return res


def _run_benchmark_reps(
    problem_name: str,
    run_single: Any,
    strategy: str,
    augment_dim: int,
    num_sobol_trials: int,
    num_trials: int,
    reps: int,
    strategy_args: Optional[Dict[str, Any]],
) -> None:
    """Run ``reps`` replicates via ``run_single`` and store all results in one file."""
    res = {
        strategy: [
            run_single(
                strategy_name=strategy,
                irep=irep,
                num_sobol_trials=num_sobol_trials,
                num_trials=num_trials,
                augment_dim=augment_dim,
                strategies_args=strategy_args,
            )
            for irep in range(reps)
        ]
    }
    _write_json(f"results/synthetic_{problem_name}_{strategy}.json", res)


def run_single_objective_branin_benchmark(
    strategy_name: str,
    irep: int,
    num_sobol_trials: int = 8,
    num_trials: int = 100,
    augment_dim: int = 10,
    strategies_args: Optional[Dict[str, Any]] = None,
) -> str:
    """Run one replicate of the augmented Branin benchmark.

    Writes ``results/synthetic_branin_{strategy_name}_rep_{irep}.json`` and
    returns the JSON-encoded experiment.
    """
    return _run_single_objective_benchmark(
        problem_name="branin",
        experiment_name="sourcing_experiment",
        objective_fn=branin_augment,
        objective_threshold=-10,
        strategy_name=strategy_name,
        irep=irep,
        num_sobol_trials=num_sobol_trials,
        num_trials=num_trials,
        augment_dim=augment_dim,
        strategies_args=strategies_args,
    )


def run_single_objective_branin_benchmark_reps(
    strategy: str,
    augment_dim: int = 10,
    num_sobol_trials: int = 8,
    num_trials: int = 50,
    reps: int = 20,
    strategy_args: Optional[Dict[str, Any]] = None,
):
    """Run ``reps`` Branin replicates and write ``results/synthetic_branin_{strategy}.json``."""
    _run_benchmark_reps(
        problem_name="branin",
        run_single=run_single_objective_branin_benchmark,
        strategy=strategy,
        augment_dim=augment_dim,
        num_sobol_trials=num_sobol_trials,
        num_trials=num_trials,
        reps=reps,
        strategy_args=strategy_args,
    )


def run_single_objective_hartmann6_benchmark(
    strategy_name: str,
    irep: int,
    num_sobol_trials: int = 8,
    num_trials: int = 100,
    augment_dim: int = 20,
    strategies_args: Optional[Dict[str, Any]] = None,
) -> str:
    """Run one replicate of the augmented Hartmann6 benchmark.

    Writes ``results/synthetic_hartmann6_{strategy_name}_rep_{irep}.json`` and
    returns the JSON-encoded experiment.
    """
    return _run_single_objective_benchmark(
        problem_name="hartmann6",
        experiment_name="hartmann6_augment_experiment",
        objective_fn=hartmann6_augment,
        objective_threshold=0,
        strategy_name=strategy_name,
        irep=irep,
        num_sobol_trials=num_sobol_trials,
        num_trials=num_trials,
        augment_dim=augment_dim,
        strategies_args=strategies_args,
    )


def run_single_objective_hartmann6_benchmark_reps(
    strategy: str,
    augment_dim: int = 10,
    num_sobol_trials: int = 8,
    num_trials: int = 50,
    reps: int = 20,
    strategy_args: Optional[Dict[str, Any]] = None,
):
    """Run ``reps`` Hartmann6 replicates and write ``results/synthetic_hartmann6_{strategy}.json``."""
    _run_benchmark_reps(
        problem_name="hartmann6",
        run_single=run_single_objective_hartmann6_benchmark,
        strategy=strategy,
        augment_dim=augment_dim,
        num_sobol_trials=num_sobol_trials,
        num_trials=num_trials,
        reps=reps,
        strategy_args=strategy_args,
    )


if __name__ == '__main__':
    # Run all of the benchmark replicates.

    run_single_objective_branin_benchmark_reps(
        strategy="Sobol",
        augment_dim=10,
        num_sobol_trials=8,
        num_trials=20,
        reps=1,
        strategy_args={},
    )
logger: Logger = get_logger(__name__)


class ExternalRegularizedL0(Acquisition):
    """
    Implement the external regularized acquisition function with L0 norm.

    The ER-L0 takes a regularization parameter and adds the L0 norm directly to the
    acquisition function. The regularization parameter controls the target sparsity
    level. It uses the same optimization method as SEBO-L0 i.e. a differentiable
    relaxation based on homotopy continuation to efficiently optimize for sparsity.
    """

    def __init__(
        self,
        surrogates: Dict[str, Surrogate],
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
        botorch_acqf_class: Type[AcquisitionFunction],
        options: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Build the raw acquisition function and wrap it with the L0 penalty.

        Args:
            surrogates: Mapping holding exactly one surrogate.
            search_space_digest: Search space properties forwarded to ``Acquisition``.
            torch_opt_config: Optimization config forwarded to ``Acquisition``.
            botorch_acqf_class: Class of the raw acquisition function.
            options: Must contain ``"target_point"``; may contain
                ``"regularization_parameter"`` (default ``0.0``).

        Raises:
            ValueError: If several surrogates are given or no target point is set.
        """
        surrogate, options, tkwargs = self._init_regularization(
            surrogates=surrogates, options=options, method_label="ER-L0"
        )

        # instantiate botorch_acqf_class
        super().__init__(
            surrogates={"regularized_bo": surrogate},
            search_space_digest=search_space_digest,
            torch_opt_config=torch_opt_config,
            botorch_acqf_class=botorch_acqf_class,
            options=options,
        )
        raw_acqf = self.acqf
        self.acqf = PenalizedAcquisitionFunction(
            raw_acqf=raw_acqf,
            penalty_func=self.penalty_term,
            regularization_parameter=self.regularization_parameter,
        )
        # create X_pareto for gen batch initial
        Xs = self.surrogates["regularized_bo"].Xs
        self._init_X_pareto(
            penalty_values=self.penalty_term(Xs[0].unsqueeze(1)).unsqueeze(-1),
            tkwargs=tkwargs,
        )

    def _init_regularization(
        self,
        surrogates: Dict[str, Surrogate],
        options: Optional[Dict[str, Any]],
        method_label: str,
    ) -> Tuple[Surrogate, Dict[str, Any], Dict[str, Any]]:
        """Validate inputs and set the penalty-related attributes.

        Must run before ``Acquisition.__init__`` so that ``penalty_term`` and
        ``regularization_parameter`` exist when the acquisition is constructed.

        Returns:
            The single surrogate, the normalized ``options`` dict, and the
            ``{"dtype", "device"}`` kwargs taken from the surrogate.
        """
        if len(surrogates) > 1:
            raise ValueError(
                f"{method_label} does not support multiple surrogates."
            )
        surrogate = surrogates[Keys.ONLY_SURROGATE]

        tkwargs = {"dtype": surrogate.dtype, "device": surrogate.device}
        options = options or {}
        target_point = options.get("target_point", None)
        if target_point is None:
            raise ValueError("please provide target point.")
        # Tensor.to is not in-place: the converted tensor has to be stored,
        # otherwise the target point stays on its original dtype/device.
        self.target_point: Tensor = torch.as_tensor(
            target_point, **tkwargs  # pyre-ignore
        )

        self.regularization_parameter: float = options.get(
            "regularization_parameter", 0.0
        )

        # construct deterministic model for penalty term
        self.penalty_name = "L0_norm"
        # pyre-fixme[4]: Attribute must be annotated.
        self.penalty_term = self._construct_penalty()
        return surrogate, options, tkwargs

    def _init_X_pareto(self, penalty_values: Tensor, tkwargs: Dict[str, Any]) -> None:
        """Set ``self.X_pareto`` from training outcomes augmented by the penalty."""
        all_Y = torch.cat(
            # pyre-ignore
            [d.Y.values for d in self.surrogates["regularized_bo"].training_data],
            dim=-1,
        )
        Y_pareto = torch.cat([all_Y, penalty_values], dim=-1)
        # pyre-ignore
        self.X_pareto = self._obtain_X_pareto(Y_pareto=Y_pareto, **tkwargs)

    def _obtain_X_pareto(self, Y_pareto: Tensor, **tkwargs: Any) -> Tensor:
        """Return the training inputs that are non-dominated in ``Y_pareto``.

        The penalty column is given weight ``-1`` so that it is minimized.
        """
        ow = torch.cat([self._full_objective_weights, torch.tensor([-1], **tkwargs)])
        ind_pareto = is_non_dominated(Y_pareto * ow)
        X_pareto = self.surrogates["regularized_bo"].Xs[0][ind_pareto].clone()
        return X_pareto

    def _construct_penalty(self) -> GenericDeterministicModel:
        """Construct the penalty term to be added to the ER-L0 acquisition function.

        Returns:
            A ``GenericDeterministicModel`` wrapping an ``L0PenaltyApprox``
            centered at ``self.target_point``.
        """
        L0 = L0PenaltyApprox(target_point=self.target_point, a=1e-3)
        return GenericDeterministicModel(f=L0)

    def optimize(
        self,
        n: int,
        search_space_digest: SearchSpaceDigest,
        inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
        fixed_features: Optional[Dict[int, float]] = None,
        rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
        optimizer_options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Generate a set of candidates via multi-start optimization. Obtains
        candidates and their associated acquisition function values.

        Args:
            n: The number of candidates to generate.
            search_space_digest: A ``SearchSpaceDigest`` object containing search space
                properties, e.g. ``bounds`` for optimization.
            inequality_constraints: A list of tuples (indices, coefficients, rhs),
                with each tuple encoding an inequality constraint of the form
                ``sum_i (X[indices[i]] * coefficients[i]) >= rhs``.
            fixed_features: A map `{feature_index: value}` for features that
                should be fixed to a particular value during generation.
            rounding_func: A function that post-processes an optimization
                result appropriately (i.e., according to `round-trip`
                transformations).
            optimizer_options: Options for the optimizer function, e.g. ``sequential``
                or ``raw_samples``.
        """
        # Reuse SEBO's optimize implementation with this instance as ``self``.
        candidates, expected_acquisition_value = SEBOAcquisition.optimize(
            self=self,  # pyre-ignore
            n=n,
            search_space_digest=search_space_digest,
            inequality_constraints=inequality_constraints,
            fixed_features=fixed_features,
            rounding_func=rounding_func,
            optimizer_options=optimizer_options,
        )
        return candidates, expected_acquisition_value

    def _run_homotopy_optimization(
        self,
        n: int,
        search_space_digest: SearchSpaceDigest,
        callback: Callable[[], None],
        fixed_features: Optional[Dict[int, float]] = None,
        rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
        optimizer_options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Optimize ``self.acqf`` with homotopy continuation on the L0 parameter ``a``.

        Args:
            n: The number of candidates to generate.
            search_space_digest: Provides the optimization ``bounds``.
            callback: Registered as the homotopy callback; it is expected to
                re-initialize ``self.acqf`` for the updated penalty parameter.
            fixed_features: Features to hold fixed during generation.
            rounding_func: Post-processing applied to the candidates.
            optimizer_options: Options merged with the optimizer defaults.
        """
        _tensorize = partial(torch.tensor, dtype=self.dtype, device=self.device)
        bounds = _tensorize(search_space_digest.bounds).t()

        homotopy_schedule = LogLinearHomotopySchedule(start=0.1, end=1e-3, num_steps=30)

        # Prepare arguments for optimizer
        optimizer_options_with_defaults = optimizer_argparse(
            self.acqf,
            bounds=bounds,
            q=n,
            optimizer_options=optimizer_options,
        )
        logger.debug("optimizer options: %s", optimizer_options_with_defaults)

        homotopy = Homotopy(
            homotopy_parameters=[
                HomotopyParameter(
                    parameter=self.penalty_term._f.a,
                    schedule=homotopy_schedule,
                )
            ],
            callbacks=[callback],
        )
        # initial conditions are seeded from X_pareto and the target point
        batch_initial_conditions = get_batch_initial_conditions(
            acq_function=self.acqf,
            raw_samples=optimizer_options_with_defaults["raw_samples"],
            X_pareto=self.X_pareto,
            target_point=self.target_point,
            num_restarts=optimizer_options_with_defaults["num_restarts"],
            **{"device": self.device, "dtype": self.dtype},
        )
        candidates, expected_acquisition_value = optimize_acqf_homotopy(
            q=n,
            acq_function=self.acqf,
            bounds=bounds,
            homotopy=homotopy,
            num_restarts=optimizer_options_with_defaults["num_restarts"],
            raw_samples=optimizer_options_with_defaults["raw_samples"],
            post_processing_func=rounding_func,
            fixed_features=fixed_features,
            batch_initial_conditions=batch_initial_conditions,
        )
        return candidates, expected_acquisition_value

    def _optimize_with_homotopy(
        self,
        n: int,
        search_space_digest: SearchSpaceDigest,
        fixed_features: Optional[Dict[int, float]] = None,
        rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
        optimizer_options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Optimize ER ACQF with L0 norm using homotopy."""

        def callback():  # pyre-ignore
            # Rebuild the penalized wrapper in place, keeping pending points.
            X_pending = self.acqf.X_pending
            self.acqf.__init__(
                raw_acqf=self.acqf.raw_acqf,
                penalty_func=self.penalty_term,
                regularization_parameter=self.regularization_parameter,
            )
            self.acqf.model = self.surrogates["regularized_bo"].model
            self.acqf.set_X_pending(X_pending)

        return self._run_homotopy_optimization(
            n=n,
            search_space_digest=search_space_digest,
            callback=callback,
            fixed_features=fixed_features,
            rounding_func=rounding_func,
            optimizer_options=optimizer_options,
        )


class InternalRegularizedL0(ExternalRegularizedL0):
    """
    Implement the internal regularized acquisition function with L0 norm.

    The IR-L0 takes a regularization parameter and adds the L0 norm directly to the
    objective function. The regularization parameter controls the target sparsity
    level. It uses the same optimization method as SEBO-L0 i.e. a differentiable
    relaxation based on homotopy continuation to efficiently optimize for sparsity.
    """

    def __init__(
        self,
        surrogates: Dict[str, Surrogate],
        search_space_digest: SearchSpaceDigest,
        torch_opt_config: TorchOptConfig,
        botorch_acqf_class: Type[AcquisitionFunction],
        options: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Build the acquisition function on top of the penalized objective.

        Args:
            surrogates: Mapping holding exactly one surrogate.
            search_space_digest: Search space properties forwarded to ``Acquisition``.
            torch_opt_config: Optimization config forwarded to ``Acquisition``.
            botorch_acqf_class: Class of the acquisition function.
            options: Must contain ``"target_point"``; may contain
                ``"regularization_parameter"`` (default ``0.0``).

        Raises:
            ValueError: If several surrogates are given or no target point is set.
        """
        surrogate, options, tkwargs = self._init_regularization(
            surrogates=surrogates, options=options, method_label="IR-L0"
        )
        # if fully-bayesian model is used,
        # decide dim to expand penalty term to match dimension of objective
        self.expand_dim: Optional[int] = None
        if is_fully_bayesian(surrogate.model):
            self.expand_dim = MCMC_DIM + 1

        # instantiate botorch_acqf_class; the parent __init__ is bypassed on
        # purpose because the penalty lives in the objective, not in a wrapper.
        Acquisition.__init__(
            self=self,
            surrogates={"regularized_bo": surrogate},
            search_space_digest=search_space_digest,
            torch_opt_config=torch_opt_config,
            botorch_acqf_class=botorch_acqf_class,
            options=options,
        )
        # create X_pareto for gen batch initial
        Xs = self.surrogates["regularized_bo"].Xs
        self._init_X_pareto(
            penalty_values=self.penalty_term(Xs[0]).transpose(1, 0),
            tkwargs=tkwargs,
        )

    def _construct_penalty(self) -> GenericDeterministicModel:
        """Construct the penalty term added to the objective function in IR-L0.

        Returns:
            A ``GenericDeterministicModel`` wrapping an ``L0PenaltyApproxObjective``
            centered at ``self.target_point``.
        """
        L0 = L0PenaltyApproxObjective(target_point=self.target_point, a=1e-3)
        return GenericDeterministicModel(f=L0)

    def get_botorch_objective_and_transform(
        self,
        botorch_acqf_class: Type[AcquisitionFunction],
        model: Model,
        objective_weights: Tensor,
        objective_thresholds: Optional[Tensor] = None,
        outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        X_observed: Optional[Tensor] = None,
        risk_measure: Optional[RiskMeasureMCObjective] = None,
    ) -> Tuple[Optional[MCAcquisitionObjective], Optional[PosteriorTransform]]:
        """Construct the penalized objective by adding the penalty term to
        the original objective.

        Raises:
            RuntimeError: If ``outcome_constraints`` are provided.
        """
        if outcome_constraints is not None:
            raise RuntimeError(
                "Outcome constraints are not supported for PenalizedMCObjective in "
                + "InternalRegularized Acqf."
            )

        obj_tf: Callable[
            [Tensor, Optional[Tensor]], Tensor
        ] = get_objective_weights_transform(objective_weights)

        def objective(samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
            return obj_tf(samples)  # pyre-ignore

        mc_objective_kwargs = {
            "penalty_objective": self.penalty_term,
            "regularization_parameter": self.regularization_parameter,
            "expand_dim": self.expand_dim,
        }
        penalized_objective = PenalizedMCObjective(
            objective=objective, **mc_objective_kwargs
        )
        return penalized_objective, None

    def _optimize_with_homotopy(
        self,
        n: int,
        search_space_digest: SearchSpaceDigest,
        fixed_features: Optional[Dict[int, float]] = None,
        rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
        optimizer_options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Optimize IR ACQF with L0 norm using homotopy."""

        def callback():  # pyre-ignore
            # Re-initialize the acquisition function in place, keeping pending points.
            X_pending = self.acqf.X_pending
            self.acqf.__init__(  # pyre-ignore
                X_baseline=self.X_observed,
                model=self.surrogates["regularized_bo"].model,
                objective=self.acqf.objective,
                posterior_transform=self.acqf.posterior_transform,
                prune_baseline=self.options.get("prune_baseline", True),
                cache_root=self.options.get("cache_root", True),
            )
            self.acqf.set_X_pending(X_pending)

        return self._run_homotopy_optimization(
            n=n,
            search_space_digest=search_space_digest,
            callback=callback,
            fixed_features=fixed_features,
            rounding_func=rounding_func,
            optimizer_options=optimizer_options,
        )
false, 18 | "metadata": { 19 | "kernel_name": "bento_kernel_default", 20 | "nightly_builds": true, 21 | "fbpkg_supported": true, 22 | "cinder_runtime": true, 23 | "ipyflow_runtime": false, 24 | "is_prebuilt": true 25 | } 26 | }, 27 | "language_info": { 28 | "codemirror_mode": { 29 | "name": "ipython", 30 | "version": 3 31 | }, 32 | "file_extension": ".py", 33 | "mimetype": "text/x-python", 34 | "name": "python", 35 | "nbconvert_exporter": "python", 36 | "pygments_lexer": "ipython3" 37 | }, 38 | "last_server_session_id": "a0df38bf-8094-420c-b2b3-9649b202e414", 39 | "last_kernel_id": "392629ed-f75b-424c-b72c-741d535cb33e", 40 | "last_base_url": "https://devvm15546.prn0.facebook.com:8090/", 41 | "last_msg_id": "112aee66-acb0fdce70dcbeda5cbc6ff7_286", 42 | "captumWidgetMessage": {}, 43 | "outputWidgetContext": {} 44 | }, 45 | "nbformat": 4, 46 | "nbformat_minor": 2, 47 | "cells": [ 48 | { 49 | "cell_type": "markdown", 50 | "metadata": { 51 | "collapsed": true, 52 | "originalKey": "d3a0136e-94fa-477c-a839-20e5b7f1cdd2", 53 | "showInput": false, 54 | "customInput": null 55 | }, 56 | "source": [ 57 | "# Sparsity Exploration Bayesian Optimization (SEBO) Ax API \n", 58 | "\n", 59 | "This tutorial introduces the Sparsity Exploration Bayesian Optimization (SEBO) method and demonstrates how to utilize it using the Ax API. SEBO is designed to enhance Bayesian Optimization (BO) by taking the interpretability and simplicity of configurations into consideration. In essence, SEBO incorporates sparsity, modeled as the $L_0$ norm, as an additional objective in BO. By employing multi-objective optimization techniques such as Expected Hyper-Volume Improvement, SEBO enables the joint optimization of objectives while simultaneously incorporating feature-level sparsity. 
This allows users to efficiently explore different trade-offs between objectives and sparsity.\n", 60 | "\n", 61 | "\n", 62 | "For a more detailed understanding of the SEBO algorithm, please refer to the following publication:\n", 63 | "\n", 64 | "[1] [S. Liu, Q. Feng, D. Eriksson, B. Letham and E. Bakshy. Sparse Bayesian Optimization. International Conference on Artificial Intelligence and Statistics, 2023.](https://proceedings.mlr.press/v206/liu23b/liu23b.pdf)\n", 65 | "\n", 66 | "By following this tutorial, you will learn how to leverage the SEBO method through the Ax API, empowering you to effectively balance objectives and sparsity in your optimization tasks. Let's get started!" 67 | ], 68 | "attachments": {} 69 | }, 70 | { 71 | "cell_type": "code", 72 | "metadata": { 73 | "originalKey": "cea96143-019a-41c1-a388-545f48992db9", 74 | "showInput": true, 75 | "collapsed": false, 76 | "requestMsgId": "c2c22a5d-aee0-4a1e-98d9-b360aa1851ff", 77 | "executionStartTime": 1689117385062, 78 | "executionStopTime": 1689117389874, 79 | "customOutput": null 80 | }, 81 | "source": [ 82 | "import os\n", 83 | "\n", 84 | "from ax import Data, Experiment, ParameterType, RangeParameter, SearchSpace\n", 85 | "from ax.modelbridge.registry import Models\n", 86 | "from ax.runners.synthetic import SyntheticRunner\n", 87 | "\n", 88 | "import warnings\n", 89 | "warnings.filterwarnings('ignore')" 90 | ], 91 | "execution_count": 1, 92 | "outputs": [ 93 | { 94 | "output_type": "stream", 95 | "name": "stderr", 96 | "text": [ 97 | "I0711 161625.198 _utils_internal.py:199] NCCL_DEBUG env var is set to None\n" 98 | ] 99 | }, 100 | { 101 | "output_type": "stream", 102 | "name": "stderr", 103 | "text": [ 104 | "I0711 161625.200 _utils_internal.py:217] NCCL_DEBUG is forced to WARN from None\n" 105 | ] 106 | } 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "metadata": { 112 | "originalKey": "89cb2c13-8484-4bf9-82e0-3bed87ceb838", 113 | "showInput": true, 114 | "customInput": null, 115 | 
"collapsed": false, 116 | "requestMsgId": "abc49ffd-df0a-4f2a-b460-73a89d73b361", 117 | "executionStartTime": 1689117389896, 118 | "executionStopTime": 1689117389898, 119 | "customOutput": null 120 | }, 121 | "source": [ 122 | "SMOKE_TEST = os.environ.get(\"SMOKE_TEST\")" 123 | ], 124 | "execution_count": 2, 125 | "outputs": [] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "metadata": { 130 | "originalKey": "1f13d0a1-accf-4faf-b40e-fbc21aeb94d9", 131 | "showInput": true, 132 | "customInput": null, 133 | "collapsed": false, 134 | "requestMsgId": "b360f1fd-9b8e-43c1-ab93-48df1580a9fb", 135 | "executionStartTime": 1689117389905, 136 | "executionStopTime": 1689117389913, 137 | "customOutput": null 138 | }, 139 | "source": [ 140 | "import torch\n", 141 | "\n", 142 | "\n", 143 | "torch.manual_seed(12345) # To always get the same Sobol points\n", 144 | "tkwargs = {\n", 145 | " \"dtype\": torch.double,\n", 146 | " \"device\": torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n", 147 | "}" 148 | ], 149 | "execution_count": 3, 150 | "outputs": [] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": { 155 | "originalKey": "7f07af01-ad58-4cfb-beca-f624310d278d", 156 | "showInput": false, 157 | "customInput": null 158 | }, 159 | "source": [ 160 | "# Demo of using Developer API" 161 | ], 162 | "attachments": {} 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": { 167 | "originalKey": "c8a27a2f-1120-4894-9302-48bfde402268", 168 | "showInput": false, 169 | "customInput": null 170 | }, 171 | "source": [ 172 | "## Problem Setup \n", 173 | "\n", 174 | "In this simple experiment we use the Branin function embedded in a 10-dimensional space. 
Additional resources:\n", 175 | "- To set up a custom metric for your problem, refer to the dedicated section of the Developer API tutorial: https://ax.dev/tutorials/gpei_hartmann_developer.html#8.-Defining-custom-metrics.\n", 176 | "- To avoid needing to set up custom metrics, use the Ax Service API: https://ax.dev/tutorials/gpei_hartmann_service.html." 177 | ], 178 | "attachments": {} 179 | }, 180 | { 181 | "cell_type": "code", 182 | "metadata": { 183 | "originalKey": "e91fc838-9f47-44f1-99ac-4477df208566", 184 | "showInput": true, 185 | "customInput": null, 186 | "collapsed": false, 187 | "requestMsgId": "1591e6b0-fa9b-4b9f-be72-683dccbe923a", 188 | "executionStartTime": 1689117390036, 189 | "executionStopTime": 1689117390038 190 | }, 191 | "source": [ 192 | "import math \n", 193 | "import numpy as np\n", 194 | "\n", 195 | "\n", 196 | "aug_dim = 8 \n", 197 | "\n", 198 | "# evaluation function \n", 199 | "def branin_augment(x_vec, augment_dim):\n", 200 | " assert len(x_vec) == augment_dim\n", 201 | " x1, x2 = (\n", 202 | " 15 * x_vec[0] - 5,\n", 203 | " 15 * x_vec[1],\n", 204 | " ) # Only dimensions 0 and 1 affect the value of the function\n", 205 | " t1 = x2 - 5.1 / (4 * math.pi**2) * x1**2 + 5 / math.pi * x1 - 6\n", 206 | " t2 = 10 * (1 - 1 / (8 * math.pi)) * np.cos(x1)\n", 207 | " return t1**2 + t2 + 10" 208 | ], 209 | "execution_count": 4, 210 | "outputs": [] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "metadata": { 215 | "originalKey": "850830c6-509f-4087-bce8-da0be4fd48ef", 216 | "showInput": true, 217 | "customInput": null, 218 | "collapsed": false, 219 | "requestMsgId": "56726053-205d-4d7e-b1b5-1a76324188ee", 220 | "executionStartTime": 1689117390518, 221 | "executionStopTime": 1689117390540, 222 | "customOutput": null 223 | }, 224 | "source": [ 225 | "from ax.core.objective import Objective\n", 226 | "from ax.core.optimization_config import OptimizationConfig\n", 227 | "from ax.metrics.noisy_function import NoisyFunctionMetric\n", 228 | 
"from ax.utils.common.typeutils import checked_cast\n", 229 | "\n", 230 | "\n", 231 | "class AugBraninMetric(NoisyFunctionMetric):\n", 232 | " def f(self, x: np.ndarray) -> float:\n", 233 | " return checked_cast(float, branin_augment(x_vec=x, augment_dim=aug_dim))\n", 234 | "\n", 235 | "\n", 236 | "# Create search space in Ax \n", 237 | "search_space = SearchSpace(\n", 238 | " parameters=[\n", 239 | " RangeParameter(\n", 240 | " name=f\"x{i}\",\n", 241 | " parameter_type=ParameterType.FLOAT, \n", 242 | " lower=0.0, upper=1.0\n", 243 | " )\n", 244 | " for i in range(aug_dim)\n", 245 | " ]\n", 246 | ")" 247 | ], 248 | "execution_count": 5, 249 | "outputs": [] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "metadata": { 254 | "originalKey": "d039b709-67c6-475a-96ce-290f869e0f88", 255 | "showInput": true, 256 | "customInput": null, 257 | "collapsed": false, 258 | "requestMsgId": "3e23ed64-7d10-430b-b790-91a0c7cf72fe", 259 | "executionStartTime": 1689117391899, 260 | "executionStopTime": 1689117391915 261 | }, 262 | "source": [ 263 | "# Create optimization goals \n", 264 | "optimization_config = OptimizationConfig(\n", 265 | " objective=Objective(\n", 266 | " metric=AugBraninMetric(\n", 267 | " name=\"objective\",\n", 268 | " param_names=[f\"x{i}\" for i in range(aug_dim)],\n", 269 | " noise_sd=None, # Set noise_sd=None if you want to learn the noise, otherwise it defaults to 1e-6\n", 270 | " ),\n", 271 | " minimize=True,\n", 272 | " )\n", 273 | ")\n", 274 | "\n", 275 | "# Experiment\n", 276 | "experiment = Experiment(\n", 277 | " name=\"sebo_experiment\",\n", 278 | " search_space=search_space,\n", 279 | " optimization_config=optimization_config,\n", 280 | " runner=SyntheticRunner(),\n", 281 | ")\n", 282 | "\n", 283 | "# target sparse point to regularize towards to. Here we set target sparse value being zero for all the parameters. 
\n", 284 | "target_point = torch.tensor([0 for _ in range(aug_dim)], **tkwargs)" 285 | ], 286 | "execution_count": 6, 287 | "outputs": [] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "metadata": { 292 | "originalKey": "e57edb00-eafc-4d07-bdb9-e8cf073b4caa", 293 | "showInput": false, 294 | "customInput": null 295 | }, 296 | "source": [ 297 | "## Run optimization loop" 298 | ], 299 | "attachments": {} 300 | }, 301 | { 302 | "cell_type": "code", 303 | "metadata": { 304 | "originalKey": "d0f279d5-da98-44da-9a4e-c30553e4d95a", 305 | "showInput": true, 306 | "customInput": null, 307 | "collapsed": false, 308 | "requestMsgId": "20d42853-0502-4a5c-8749-7fc1dcbc9879", 309 | "executionStartTime": 1689117393959, 310 | "executionStopTime": 1689117393962, 311 | "customOutput": null 312 | }, 313 | "source": [ 314 | "import torch \n", 315 | "from ax.models.torch.botorch_modular.surrogate import Surrogate\n", 316 | "from botorch.models import SingleTaskGP, FixedNoiseGP, SaasFullyBayesianSingleTaskGP\n", 317 | "from ax.models.torch.botorch_modular.sebo import SEBOAcquisition\n", 318 | "from botorch.acquisition.multi_objective import qNoisyExpectedHypervolumeImprovement" 319 | ], 320 | "execution_count": 7, 321 | "outputs": [] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "metadata": { 326 | "originalKey": "c4848148-bff5-44a7-9ad5-41e78ccb413c", 327 | "showInput": true, 328 | "customInput": null, 329 | "collapsed": false, 330 | "requestMsgId": "8aa87d22-bf89-471f-be9f-7c31f7b8bd62", 331 | "executionStartTime": 1689117395051, 332 | "executionStopTime": 1689117395069, 333 | "customOutput": null 334 | }, 335 | "source": [ 336 | "N_INIT = 10\n", 337 | "BATCH_SIZE = 1\n", 338 | "\n", 339 | "if SMOKE_TEST:\n", 340 | " N_BATCHES = 1\n", 341 | " SURROGATE_CLASS = SingleTaskGP\n", 342 | "else:\n", 343 | " N_BATCHES = 40\n", 344 | " SURROGATE_CLASS = SaasFullyBayesianSingleTaskGP\n", 345 | "\n", 346 | "print(f\"Doing {N_INIT + N_BATCHES * BATCH_SIZE} evaluations\")" 347 | ], 348 
| "execution_count": 8, 349 | "outputs": [ 350 | { 351 | "output_type": "stream", 352 | "name": "stdout", 353 | "text": [ 354 | "Doing 50 evaluations\n" 355 | ] 356 | } 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "metadata": { 362 | "originalKey": "b260d85f-2797-44e3-840a-86587534b589", 363 | "showInput": true, 364 | "customInput": null, 365 | "collapsed": false, 366 | "requestMsgId": "2cc516e3-b16e-40ca-805f-dcd792c92fa6", 367 | "executionStartTime": 1689117396326, 368 | "executionStopTime": 1689117396376, 369 | "customOutput": null 370 | }, 371 | "source": [ 372 | "# Initial Sobol points\n", 373 | "sobol = Models.SOBOL(search_space=experiment.search_space)\n", 374 | "for _ in range(N_INIT):\n", 375 | " experiment.new_trial(sobol.gen(1)).run()" 376 | ], 377 | "execution_count": 9, 378 | "outputs": [] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "metadata": { 383 | "originalKey": "7c198035-add2-4717-be27-4fb67c4d1782", 384 | "showInput": true, 385 | "customInput": null, 386 | "collapsed": false, 387 | "requestMsgId": "d844fa20-0adf-4ba3-ace5-7253ba678db2", 388 | "executionStartTime": 1689117396900, 389 | "executionStopTime": 1689124188959, 390 | "customOutput": null 391 | }, 392 | "source": [ 393 | "data = experiment.fetch_data()\n", 394 | "\n", 395 | "for i in range(N_BATCHES):\n", 396 | "\n", 397 | " model = Models.BOTORCH_MODULAR(\n", 398 | " experiment=experiment, \n", 399 | " data=data,\n", 400 | " surrogate=Surrogate(botorch_model_class=SURROGATE_CLASS), # can use SAASGP (i.e. SaasFullyBayesianSingleTaskGP) for high-dim cases\n", 401 | " search_space=experiment.search_space,\n", 402 | " botorch_acqf_class=qNoisyExpectedHypervolumeImprovement,\n", 403 | " acquisition_class=SEBOAcquisition,\n", 404 | " acquisition_options={\n", 405 | " \"penalty\": \"L0_norm\", # it can be L0_norm or L1_norm. 
\n", 406 | " \"target_point\": target_point, \n", 407 | " \"sparsity_threshold\": aug_dim,\n", 408 | " },\n", 409 | " torch_device=tkwargs['device'],\n", 410 | " )\n", 411 | "\n", 412 | " generator_run = model.gen(BATCH_SIZE)\n", 413 | " trial = experiment.new_batch_trial(generator_run=generator_run)\n", 414 | " trial.run()\n", 415 | "\n", 416 | " new_data = trial.fetch_data(metrics=list(experiment.metrics.values()))\n", 417 | " data = Data.from_multiple_data([data, new_data])\n", 418 | " print(f\"Iteration: {i}, Best so far: {data.df['mean'].min():.3f}\")" 419 | ], 420 | "execution_count": 10, 421 | "outputs": [ 422 | { 423 | "output_type": "stream", 424 | "name": "stdout", 425 | "text": [ 426 | "Iteration: 0, Best so far: 2.494\n" 427 | ] 428 | }, 429 | { 430 | "output_type": "stream", 431 | "name": "stdout", 432 | "text": [ 433 | "Iteration: 1, Best so far: 2.494\n" 434 | ] 435 | }, 436 | { 437 | "output_type": "stream", 438 | "name": "stdout", 439 | "text": [ 440 | "Iteration: 2, Best so far: 2.494\n" 441 | ] 442 | }, 443 | { 444 | "output_type": "stream", 445 | "name": "stdout", 446 | "text": [ 447 | "Iteration: 3, Best so far: 2.494\n" 448 | ] 449 | }, 450 | { 451 | "output_type": "stream", 452 | "name": "stdout", 453 | "text": [ 454 | "Iteration: 4, Best so far: 2.494\n" 455 | ] 456 | }, 457 | { 458 | "output_type": "stream", 459 | "name": "stdout", 460 | "text": [ 461 | "Iteration: 5, Best so far: 2.494\n" 462 | ] 463 | }, 464 | { 465 | "output_type": "stream", 466 | "name": "stdout", 467 | "text": [ 468 | "Iteration: 6, Best so far: 2.494\n" 469 | ] 470 | }, 471 | { 472 | "output_type": "stream", 473 | "name": "stdout", 474 | "text": [ 475 | "Iteration: 7, Best so far: 2.494\n" 476 | ] 477 | }, 478 | { 479 | "output_type": "stream", 480 | "name": "stdout", 481 | "text": [ 482 | "Iteration: 8, Best so far: 2.494\n" 483 | ] 484 | }, 485 | { 486 | "output_type": "stream", 487 | "name": "stdout", 488 | "text": [ 489 | "Iteration: 9, Best so far: 2.494\n" 490 | 
] 491 | }, 492 | { 493 | "output_type": "stream", 494 | "name": "stdout", 495 | "text": [ 496 | "Iteration: 10, Best so far: 2.494\n" 497 | ] 498 | }, 499 | { 500 | "output_type": "stream", 501 | "name": "stdout", 502 | "text": [ 503 | "Iteration: 11, Best so far: 1.990\n" 504 | ] 505 | }, 506 | { 507 | "output_type": "stream", 508 | "name": "stdout", 509 | "text": [ 510 | "Iteration: 12, Best so far: 1.990\n" 511 | ] 512 | }, 513 | { 514 | "output_type": "stream", 515 | "name": "stdout", 516 | "text": [ 517 | "Iteration: 13, Best so far: 1.990\n" 518 | ] 519 | }, 520 | { 521 | "output_type": "stream", 522 | "name": "stdout", 523 | "text": [ 524 | "Iteration: 14, Best so far: 1.990\n" 525 | ] 526 | }, 527 | { 528 | "output_type": "stream", 529 | "name": "stdout", 530 | "text": [ 531 | "Iteration: 15, Best so far: 1.990\n" 532 | ] 533 | }, 534 | { 535 | "output_type": "stream", 536 | "name": "stdout", 537 | "text": [ 538 | "Iteration: 16, Best so far: 0.662\n" 539 | ] 540 | }, 541 | { 542 | "output_type": "stream", 543 | "name": "stdout", 544 | "text": [ 545 | "Iteration: 17, Best so far: 0.662\n" 546 | ] 547 | }, 548 | { 549 | "output_type": "stream", 550 | "name": "stdout", 551 | "text": [ 552 | "Iteration: 18, Best so far: 0.453\n" 553 | ] 554 | }, 555 | { 556 | "output_type": "stream", 557 | "name": "stdout", 558 | "text": [ 559 | "Iteration: 19, Best so far: 0.453\n" 560 | ] 561 | }, 562 | { 563 | "output_type": "stream", 564 | "name": "stdout", 565 | "text": [ 566 | "Iteration: 20, Best so far: 0.453\n" 567 | ] 568 | }, 569 | { 570 | "output_type": "stream", 571 | "name": "stdout", 572 | "text": [ 573 | "Iteration: 21, Best so far: 0.424\n" 574 | ] 575 | }, 576 | { 577 | "output_type": "stream", 578 | "name": "stdout", 579 | "text": [ 580 | "Iteration: 22, Best so far: 0.424\n" 581 | ] 582 | }, 583 | { 584 | "output_type": "stream", 585 | "name": "stdout", 586 | "text": [ 587 | "Iteration: 23, Best so far: 0.424\n" 588 | ] 589 | }, 590 | { 591 | "output_type": 
"stream", 592 | "name": "stdout", 593 | "text": [ 594 | "Iteration: 24, Best so far: 0.424\n" 595 | ] 596 | }, 597 | { 598 | "output_type": "stream", 599 | "name": "stdout", 600 | "text": [ 601 | "Iteration: 25, Best so far: 0.424\n" 602 | ] 603 | }, 604 | { 605 | "output_type": "stream", 606 | "name": "stdout", 607 | "text": [ 608 | "Iteration: 26, Best so far: 0.424\n" 609 | ] 610 | }, 611 | { 612 | "output_type": "stream", 613 | "name": "stdout", 614 | "text": [ 615 | "Iteration: 27, Best so far: 0.424\n" 616 | ] 617 | }, 618 | { 619 | "output_type": "stream", 620 | "name": "stdout", 621 | "text": [ 622 | "Iteration: 28, Best so far: 0.424\n" 623 | ] 624 | }, 625 | { 626 | "output_type": "stream", 627 | "name": "stdout", 628 | "text": [ 629 | "Iteration: 29, Best so far: 0.424\n" 630 | ] 631 | }, 632 | { 633 | "output_type": "stream", 634 | "name": "stdout", 635 | "text": [ 636 | "Iteration: 30, Best so far: 0.416\n" 637 | ] 638 | }, 639 | { 640 | "output_type": "stream", 641 | "name": "stdout", 642 | "text": [ 643 | "Iteration: 31, Best so far: 0.416\n" 644 | ] 645 | }, 646 | { 647 | "output_type": "stream", 648 | "name": "stdout", 649 | "text": [ 650 | "Iteration: 32, Best so far: 0.408\n" 651 | ] 652 | }, 653 | { 654 | "output_type": "stream", 655 | "name": "stdout", 656 | "text": [ 657 | "Iteration: 33, Best so far: 0.408\n" 658 | ] 659 | }, 660 | { 661 | "output_type": "stream", 662 | "name": "stdout", 663 | "text": [ 664 | "Iteration: 34, Best so far: 0.408\n" 665 | ] 666 | }, 667 | { 668 | "output_type": "stream", 669 | "name": "stdout", 670 | "text": [ 671 | "Iteration: 35, Best so far: 0.408\n" 672 | ] 673 | }, 674 | { 675 | "output_type": "stream", 676 | "name": "stdout", 677 | "text": [ 678 | "Iteration: 36, Best so far: 0.408\n" 679 | ] 680 | }, 681 | { 682 | "output_type": "stream", 683 | "name": "stdout", 684 | "text": [ 685 | "Iteration: 37, Best so far: 0.408\n" 686 | ] 687 | }, 688 | { 689 | "output_type": "stream", 690 | "name": "stdout", 691 | 
"text": [ 692 | "Iteration: 38, Best so far: 0.408\n" 693 | ] 694 | }, 695 | { 696 | "output_type": "stream", 697 | "name": "stdout", 698 | "text": [ 699 | "Iteration: 39, Best so far: 0.408\n" 700 | ] 701 | } 702 | ] 703 | }, 704 | { 705 | "cell_type": "markdown", 706 | "metadata": { 707 | "originalKey": "7998635d-6750-4825-b93d-c7b61f74c3c5", 708 | "showInput": false, 709 | "customInput": null 710 | }, 711 | "source": [ 712 | "## Plot sparsity vs objective \n", 713 | "\n", 714 | "Visualize the objective and sparsity trade-offs using SEBO. Each point represents a design along the Pareto frontier found by SEBO. The x-axis corresponds to the number of active parameters used, i.e.\n", 715 | "non-sparse parameters, and the y-axis corresponds to the best identified objective values. Based on this, decision-makers balance both simplicity/interpretability of generated policies and optimization performance when deciding which configuration to use." 716 | ], 717 | "attachments": {} 718 | }, 719 | { 720 | "cell_type": "code", 721 | "metadata": { 722 | "originalKey": "416ccd12-51a1-4bfe-9e10-436cd88ec6be", 723 | "showInput": true, 724 | "customInput": null, 725 | "collapsed": false, 726 | "requestMsgId": "5143ae57-1d0d-4f9d-bc9d-9d151f3e9af0", 727 | "executionStartTime": 1689124189044, 728 | "executionStopTime": 1689124189182, 729 | "customOutput": null 730 | }, 731 | "source": [ 732 | "def nnz_exact(x, sparse_point):\n", 733 | " return len(x) - (np.array(x) == np.array(sparse_point)).sum()\n", 734 | "\n", 735 | " \n", 736 | "df = data.df\n", 737 | "df['L0_norm'] = df['arm_name'].apply(lambda d: nnz_exact(list(experiment.arms_by_name[d].parameters.values()), [0 for _ in range(aug_dim)]) )" 738 | ], 739 | "execution_count": 11, 740 | "outputs": [] 741 | }, 742 | { 743 | "cell_type": "code", 744 | "metadata": { 745 | "originalKey": "97b96822-7d7f-4a5d-8458-01ff890d2fde", 746 | "showInput": true, 747 | "customInput": null, 748 | "collapsed": false, 749 | "requestMsgId": 
"34abdf8d-6f0c-48a1-8700-8e2c3075a085", 750 | "executionStartTime": 1689124189219, 751 | "executionStopTime": 1689124189321, 752 | "customOutput": null 753 | }, 754 | "source": [ 755 | "result_by_sparsity = {l: df[df.L0_norm <= l]['mean'].min() for l in range(1, aug_dim+1)}\n", 756 | "result_by_sparsity" 757 | ], 758 | "execution_count": 12, 759 | "outputs": [ 760 | { 761 | "output_type": "execute_result", 762 | "data": { 763 | "text/plain": "{1: 5.915850721937628,\n 2: 0.41574213444366315,\n 3: 0.41574213444366315,\n 4: 0.40790508387544655,\n 5: 0.40790508387544655,\n 6: 0.40790508387544655,\n 7: 0.40790508387544655,\n 8: 0.40790508387544655}" 764 | }, 765 | "metadata": {}, 766 | "execution_count": 12 767 | } 768 | ] 769 | }, 770 | { 771 | "cell_type": "code", 772 | "metadata": { 773 | "originalKey": "7193e2b0-e192-439a-b0d0-08a2029f64ca", 774 | "showInput": true, 775 | "customInput": null, 776 | "collapsed": false, 777 | "requestMsgId": "f095d820-55e0-4201-8e3a-77f17b2155f1", 778 | "executionStartTime": 1689134836494, 779 | "executionStopTime": 1689134837813, 780 | "customOutput": null 781 | }, 782 | "source": [ 783 | "import matplotlib\n", 784 | "import matplotlib.pyplot as plt\n", 785 | "import numpy as np\n", 786 | "\n", 787 | "%matplotlib inline\n", 788 | "matplotlib.rcParams.update({\"font.size\": 16})\n", 789 | "\n", 790 | "fig, ax = plt.subplots(figsize=(8, 6))\n", 791 | "ax.plot(list(result_by_sparsity.keys()), list(result_by_sparsity.values()), '.b-', label=\"sebo\", markersize=10)\n", 792 | "ax.grid(True)\n", 793 | "ax.set_title(f\"Branin, D={aug_dim}\", fontsize=20)\n", 794 | "ax.set_xlabel(\"Number of active parameters\", fontsize=20)\n", 795 | "ax.set_ylabel(\"Best value found\", fontsize=20)\n", 796 | "# ax.legend(fontsize=18)\n", 797 | "plt.show()" 798 | ], 799 | "execution_count": 20, 800 | "outputs": [ 801 | { 802 | "output_type": "display_data", 803 | "data": { 804 | "text/plain": "
", 805 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfEAAAGWCAYAAABow7qfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deZwcdZnH8U8xgQQyEAhHOAMEAoQESCZcAkIkQowil/rIiqyguOt6IigYYBUFVETFRV0F8UYXH1DkMrA6SzgUkSMhgFxyH0FIIIYEyDGp/WOelqaZo3ume6qr+vt+vebV09XVVc+vM+mnfkf9fkmapoiIiEj+rJF1ACIiIjIwSuIiIiI5pSQuIiKSU0riIiIiOaUkLiIiklNK4iIiIjmlJC4iIpJTSuIi8k9JkuyYJMnjSZKclHUsItI/JXGRIZIkSWeSJHclSfJ0kiTLkyR5LkmSW5Ik+WCSJEnW8YUNgM2BbRp5kiRJNkuS5JEkSRbHZ7EqSZJXkiSZlyTJz5Mk2XmQxx+VJMl5SZI8nCTJsiRJ7kmS5BNJkug7Twol0YxtIkMjSZIUWAFcAbwMjAYOAkYAZ6Zp+vmsYyQSYJqm/2jwObYGHgOWANcDy4AEmAjsGp/T6WmanjvA418HHAzcCTwEHAhsDHw2TdOv179EItlQEhcZIpHEF6ZpunHZtp0j0awC1kvTdHW2UQ6NsiR+e5qme1S89m7gZ8BwYJ80TW+t8di7AncBNwBvSdM0TZJkC+AB4Mk0TSfUvUAiGVHTkkiG0jT9K3AvMBIYW9qeJMmUJEn+liTJ6UmSnJwkyRNJktxT9vrBSZL8IfqvX06S5MEkSS5LkuSAynNEM/7fkiSZkCTJb5MkeTZJkhei2Xr9in0PSJIkTZLkCxXbj0+SZGmSJDOTJDkrSZL7o5n69iRJ3lznz+Qy4Fvx/fSdARxix3i8PI1aSpqmT0eNfGzfbxXJFyVxkQxFH+0Y4BXg72UvrQdsB3w6fm6NZviS4cCGsf038d5DgTlRyy03LpLXndGk3Ak8B7wfOKfKULeIC41LgaOAe4CbgA7g10mSrFeHj6PcucBqYPckSbaq8b1PxeP40oYkSTYEJgG31zdMkWwNyzoAkVaVJMlo4NRIkKekafpKD7vNAw5N03RZ+cY0Ta8Crqo43o+BY3s53UvAvmma3h/7rgMsBGbUGPYZwDdKNdwkSX4K/Gsk8zk1HqtXaZq+mCTJw5GIJwJPRiL+ah9v+1yapovSNL0lSZIrgQ8nSXIXcF00zz8PfKJeMYo0AyVxkaG1UfSNlzsjTdOv9bL/jZUJfICWlBI43Uny5SRJ7gT2qfE4t6WvH0hzUyTxRjRTL4wkvkE8bweO72P/s4BF8fu7gEuAC4E0fvZP03R+A+IUyYySuMjQWg78PH5fF9gf+HySJOulaVr1vdnRDH8EMAXYJEZ271tjLCvjfYOxshTSII/Tk3Xi8e90X3g8XsN5DotWhr8A1wAfBWYnSXJEmqadDYhVJBNK4iJD66U0TT9cehLN2jcBJyZJcn2aplf3d4BI4L+P26aKbIt4fKiWNyVJMh74RYwhODpN01VJkvwgRqx7kiTbpGn6UmNCFhlaSuIiGYpm7YujT/lAoN8kDrwt9r0JeF+apk/Rf594riRJsjmwEfAq8DSvDU7rt08cOCYG/n0zTdNVdH/OC5IkceBjwH7A7CErjEgDKYmLZG9EPFbb9126z/mSUgIvoI/H44/K7p2vtk+81Ide+XmuiEd970lh6I9ZJENRu/xQPP1LlW8rzaa2S8X20sVAW53Cy0SSJEfHbXVLgS+VttfQJ35fPL47boUrdVu8I7ZrcJsUhpK4yNAalSTJZVFLXA84IGqOV1bZlE7c5/1q3EI1IW6d2i3uNwe4MEmSG9I0PbOB5ehVkiTbAI/GbXO9jbov2SFJksvj8xgGTI7JWhYCx6Rp+
vd+3t8TB04HvpAkyTRgQYzCHwt8Py4GRApBk72IDJ07o7n3ndFvewDwSNy7/N60yjmQ0zR9NG6huhfYI5rXLwC2jnu19yhrOs5CaRa4VX3ssxy4Oy5G3g78S4woXwL8N7BbmqbXDuTkaZouBKYDv44LgncCLwIn6D5xKRrNnS4idZUkyX9EIj4wTdPrs45HpMiUxEWkbpIkWStaCFYCE6ttXRCRgVGfuIjU0zmxFvk+SuAijac+cRGpiyRJ1o0JWj6fpultWccj0grUnC4iIpJTuWtO7+zs1FWHiIi0nOnTp79hnoTcJXG6C1K3Y82ZM4dp06bV7XjNSuUsFpWzWFTOYmlEOTs7e163R33iIiIiOaUkLiIiklNK4iIiIjmlJC4iIpJTTTGwzcwmxyQRe0ZM33P3k7OOS0REpJllnsTNbBfgRmA1cDOwZiwKISIiIn3IPIkDJ0YcO7v7Y1kHIyIikheZJnEzWwN4b9TE9zezU4F5wAXu3pVlbCIiIs0u02lXzWxTYAGwFGgve+kCd/9IT+/p7OxM29ra6hbD0qVLaW9vr2LPfFM5i0XlLBaVs1gaUc6urq6mnLFt83jsAiYBS4AbgOPN7GR3X9LTm+oxE05XF8yeDb/5zaMceeS2zJwJdbw2aDqaKalYVM5iUTmLZShnbMs6iZfS5v+5+710186vAE4AtgPmNuKkXV0wYwbceissW7YNl14Ke+0F111X7EQuIiLFkvV94n+PxxVl21b1sK2uZs/uTuBLl0KaJixd2v189uxGnVFERKT+sq6JL4gm9N3NbFg0q+8LrAQebdRJ586FZctev23ZMpg3Dw45pFFnFRERqa9Ma+LuvhL4aTSd3wb8BXgT8EN3f7lR550yBUaOfP22kSNh8uRGnVFERKT+sm5OBzgFuAjYDNgS+C7w6UaecObM7j7w7sGDKcOHdz+fObORZxUREamvrJvTcfdXgA8P5Tnb2roHsc2eDWedtYgNNtiIq6/WoDYREcmXZqiJZ6Ktrbv/++ijn2TRIiVwERHJn5ZN4iXbb7+Ue+6BlSuzjkRERKQ2LZ/E1167i623hvvuyzoSERGR2rR8Egfo6IA77sg6ChERkdooiUcSv/POrKMQERGpjZK4kriIiOSUknhM/nLXXd1zqouIiOSFkjiw/vowZgw8+GDWkYiIiFRPSTyoSV1ERPJGSTxMnaokLiIi+aIkHlQTFxGRvFESD1OmdCfx1auzjkRERKQ6SuJh441h1Ch45JGsIxEREamOkngZNamLiEieKImXURIXEZE8URIvoyQuIiJ5oiReppTE0zTrSERERPqnJF5m881hzTXhySezjkRERKR/SuIV1KQuIiJ5oSReQWuLi4hIXiiJV1BNXERE8kJJvIKSuIiI5IWSeIWxY2HFCliwIOtIRERE+qYkXiFJVBsXEZF8UBLvgZK4iIjkgZJ4D7S2uIiI5IGSeA9UExcRkTxQEu/BuHGweDEsXJh1JCIiIr1TEu/BGmvAlCmqjYuISHNTEu+FmtRFRKTZKYn3QklcRESanZJ4L5TERUSk2SmJ92LHHeHZZ7sHuImIiDQjJfFetLXBbrvBvHlZRyIiItIzJfE+qEldRESamZJ4H5TERUSkmSmJ96GjA+64I+soREREeqYk3oedd4bHH4elS7OORERE5I2UxPuw5powcSLcdVfWkYiIiLyRkng/1C8uIiLNSkm8H0riIiLSrJTE+6G1xUVEpFkpifdj0iR46CF49dWsIxEREXk9JfF+jBgBO+wAd9+ddSQiIiKvNyzLk5vZDODais0vu/vIjELqUel+8T32yDoSERGR12SaxIGt4rETeD5+X5RhPD3S4DYREWlGSZqmmZ3czL4IfB7Yzd3nV/Oezs7OtK2trW4xLF26lPb29j73ueee9fj2t8dzwQX5nb6tmnIWgcpZLCpnsaicA9fV1cX06dOTyu1Z18Q3jMeaat/Tpk2rWwBz5szp93h77AEnnwz77DONt
daq26mHVDXlLAKVs1hUzmJROQeus7Ozx+1ZD2wr9X1/2cxON7MdMo6nRyNHwjbbwF//mnUkIiIir2mWJP6vwJnA/WZ2XMYx9Uj3i4uISLPJOokfEzGMBN4LdAHnm1n9Or3rRIPbRESk2WTaJ+7uy+PXlwE3s1nAZGAc8FCWsVXq6IBLL806ChERkddkXRP/JzPbANgZWAk8kXU8lSZPhvnzYdWqrCMRERHplvVkL98D9gOeBPYF1gLOKauhN41Ro2CzzeCBB7qXJxUREcla1jXxJXEhcQDwFHACcFrGMfVK/eIiItJMsu4TPwU4JcsYalFK4scck3UkIiIi2dfEc0U1cRERaSZK4jXo6IC5c2H16qwjERERURKvyYYbwujR8PDDWUciIiKiJF4zNamLiEizUBKvUWltcRERkawpiddINXEREWkWSuI1KiXxDJdhFxERASXx2m26KYwYAY8/nnUkIiLS6pTEB0BN6iIi0gyUxAdAa4uLiEgzUBIfANXERUSkGSiJD0DpNjMNbhMRkSwpiQ/Allt2T736zDNZRyIiIq1MSXwAkkRN6iIikj0l8QFSEhcRkawpiQ+QkriIiGRNSXyAlMRFRCRrw3p7wcyWASMGcMzU3Xs9blGMGwcvvQTPPQebbJJ1NCIi0or6SrazgdFlzycCGwM3AV097D869rm5AXE2ndLgtrlzYcaMrKMREZFW1GsSd/d3lz83s5uBV939gJ72N7ORwOPAtxsRaDMqNakriYuISBZq6ROfCNza24vuvixq4Z+rT2jNT2uLi4hIlmpJ4i9WNK/3ZAWw0yBjyg0NbhMRkSzVksTvAKaZ2eE9vWhmOwBvBZ6qX3jNbfx4eP55ePHFrCMREZFWVMso8pOBGcCvzexu4G5gITAS2AaYBrQBZzcw3qbS1ga77dY9uO3AA7OORkREWk3VNXF3fxTYL0at7wocDXwKOD5q4IuB09z9G40NubmoSV1ERLJS0/3c7j4fOMTMNgK2BcYCy4EngfvdfXnjQm1OU6fCdddlHYWIiLSiAU3K4u4Loyn9tvqHlC8dHfCVr2QdhYiItCJNuzpIEybAk092z94mIiIylGqqiZvZ1sAB0Yye9LJb6u5n1Se85jdsGEyaBPPmwZvfnHU0IiLSSqpO4nFr2S+B4X0kcIAUaJkkTtngNiVxEREZSrXUxM+OW8jOjPnTVzYwrlzp6ICbW2LGeBERaSa1JPGtgWvc/YwGxpNLHR1w/vlZRyEiIq2mloFtjzQwjlybNAkefhhefjnrSEREpJXUksR/DOxuZm0NjCeXhg+HnXaCu+/OOhIREWkltTSn/wNYHzjXzOb2taO7/3zwoeVLaXDbXntlHYmIiLSKWpL414F24IQYgd6TJF5r2SQuIiIyVGpJ4kcDGzcwllzr6IAf/jDrKEREpJVUncTdfXZjQ8m3XXeF++6D5cu7+8hFREQaTdOu1sk668C4cXDvvVlHIiIirUJJvI7ULy4iIkOplmlXL4wZ2/qTuvvxgwsrn5TERURkKNUysO2wKge2pUBLJvGpU+GSS7KOQkREWkUtSbyjn5r4boADA1oGxMymANcC57v72QM5RtYmT+6e8GXVqu7VzURERBqpltHpT/ezyxNm9gRwHHD7AGI5G9gEuH4A720K664LW24J99/fPRWriIhIIyVp2tu8LbUzs6uBHd19fI3v2wP4C9Dp7m/ta9/Ozs60ra1+M78uXbqU9vb2uh3vzDMnsOeeLzBjxt/rdsx6qHc5m5XKWSwqZ7GonAPX1dXF9OnT37AMeL0bfbcGNh3A+06Nx6rWIZ82bdoATtGzOXPm1PV4t90GTz01hmnTJtTtmPVQ73I2K5WzWFTOYlE5B66zs7PH7bWMTj+tj1vShgNvBXYGbqwlMDPbDjgU+COwyMxOiSVP76nlOM2iowOuvDLrKEREpBXUUhM/Adiwn30eBD5eYwyfiIuDM2Pw3FeBZ4FcJvEpU2DePFi9GtbQXfgiI
tJAtSTxQ6LG3ZMUeAp4zN2r7mQ3s3VjINxt7n6dmX2ghnia0ujRsNFG8NBDsOOOWUcjIiJFVsvo9FsbcP5pwHrAGDO7CRgT208zs2XuflkDztlwU6d2T/qiJC4iIo2UdYNv6fxjgf2A0qj28cBaGcY1KJq5TUREhkLNo9PNbKMYxDYO2Dya0R8Cfufur9RyLHe/ItYgLx37C8AZwLHu/staY2sWHR1w7rlZRyEiIkVXUxI3sxOBzwPrliff6BN/xsxOdPdL6x9mvkyZ0l0TT1NI3nBXn4iISH1U3ZxuZocDXwe6gG8D/wa8Kx6/DYwEfmFmew4injlRE583iGNkbswYGDkSHn0060hERKTIaqmJfwZ4Eehw98crXzSzbwFzgS8CMwcSjLvfANwwkPc2m1K/+LhxWUciIiJFVcvAtp2B63tK4HQn4MeA/4t7vVueBreJiEij1ZLEX4wm876sBawcZEyFoCQuIiKNVksSvwWYbmaH9PSimR0Qo9bvql94+VW6V7yO68uIiIi8Ti194p8FDgauMLO7YlrURcD6wE7AnsAKYFYD482NzTfvHpn+9NPdy5OKiIjUW9U1cXdfEBOyXAlMBt4PfAr4ALBXrCH+Nnef39iQ8yFJ1KQuIiKNVdN94u7+IHCEmW0cs6ptGYuVPObuTzQuzHwqJfFDD806EhERKaIBrSfu7s8Dz9c/nGLp6ICf/CTrKEREpKh6bU43s1+b2UfLnl9hZscNWWQFoOZ0ERFppL76xA8CDix7/k5g7yGIqTC22QZefhmefTbrSEREpIj6ak7/C3CwmZ0OlCZ42cHMjunvoO7+8/qFmF+lwW1z58LMAc1hJyIi0ru+kviJwLXAl2KBkxTYP356k8R+SuKhdL+4kriIiNRbr0nc3eeb2U7APsAY4Mcxr/mPhzbEfOvogEtbfl03ERFphD5Hp7v7kqiNl2Zku8HdfzZk0RVARwfM0vQ3IiLSAFXfYubuH2xsKMW0/fawaBG88AKMHp11NCIiUiS1zJ0uA7DGGjB5sm41ExGR+lMSHwK6X1xERBpBSXwIKImLiEgjKIkPASVxERFpBCXxIbDTTt1Lki5ZknUkIiJSJDUtgGJmmwMnxVKkWwKr3X1CvLYB8B/AX939tw2LOIeGDYNdd4V582D/vqbKERERqUHVSTwmfrkNGAksAkYA65Tt8grw2ZiuVUm8QqlJXUlcRETqpZbm9G/E/vu6+8bAJeUvuvurkcAn1T/M/FO/uIiI1FstSXx34Fp3vyWepz3s8zywfp1iK5SODrjjjqyjEBGRIqklia+IBU76si3wyCBjKqSJE+HRR2HZsqwjERGRoqglif8ZmGFme/T0opm9D3hT7CcV1loLJkyA+fOzjkRERIqiltHpnwbeAtxiZvOBjehO3lcD2wPjY8DbqY0LN99K/eJvelPWkYiISBFUXRN396eADuAnwE5xi1kCvB3YDrgM2Mvdn29syPlVWltcRESkHmq6T9zdnwCOB443s82ArYCFwBPuvqpxYRZDRwdccEHWUYiISFHUlMTLufsCYEF9wym2XXaBBx6A5cth+PCsoxERkbzTtKtDaO21u9cXv+eerCMREZEiqGXGttOqTPqpu581uLCKq3S/+NSpWUciIiJ5V0tz+gnAhv3sU5oARkm8F5q5TURE6qWWJH4I0FdP7q7AucC76xBXYXV0wMUXZx2FiIgUQdVJ3N1v7WeXG83sM8A+wDWDD62YJk/u7hNfuRLWXDPraEREJM/qPbDtPuCIOh+zUNrbYeut4b77so5ERETyrt5JfDQwts7HLBz1i4uISD3UMjr9/X0sgDIcOAjYA9BaXf0oJfFjj806EhERybNaBradD4zq5bVScl8Uc6xLHzo64PLLs45CRETyrpYk/m/AOr28lgJPAXe6+z/qFFthTZkCd90FXV3Q1pZ1NCIikle1jE6/rLGhtI7114dNNoEHH+xenlRERGQgNO1qRjS4TUREBmvAC6DUi
5ldBOwXo9oXAQ7McvcVWcfWSKUkfvTRWUciIiJ5laRp2uMLZnYtMJDpSFJ3f2u1O5vZ74AxwAsxUcw6wFfdfVZP+3d2dqZtdexIXrp0Ke3t7XU7XrVuu20DfvnLsZx33l1Dcr6syjnUVM5iUTmLReUcuK6uLqZPn/6GO8T6qonvAmw2gHP1fFXQC3d/e+l3MzsM+G3UzHs1bdq0AYTVszlz5tT1eNWaOBHOPhv2338aawxBp0ZW5RxqKmexqJzFonIOXGdnZ4/be03i7r5FXSOoTeHvNd94Y1hvPXj0Udhuu6yjERGRPGqGPvGNgG8DawMHAH8CvpZ1XEOh1C+uJC4iIgNR14ZcM/uImb2vxreNBI4CDgPWB54HXq5nXM1KI9RFRGQw6t0be3DM7FY1d3/c3ZOYDe6kSOZX1zmuptTRAXcUvuNAREQapZa50ydX2Vc9byCBuPsS4JtmdhKwr5mNKvrsb6WaeJpC0tus9CIiIr2opU/8hBh5fmM83zFGr8+J5yOACcAhg4wpAbqAQt8nDrDZZjBsGDz5JIzV2m8iIlKjWpL4HsDt7n4g3TXzC4EPlT1PgMeB6cDF1R40+tDnAC8C/xoXBre7+ysDKVCeJAlMndpdG1cSFxGRWtXSJz4W+FtvL7p7CtxeyypmZrZeJPynYzDb94ElwIdriCvXNLhNREQGqpYkvqRiPfFFdCfiDcu2jQB2qOGYq4D3A18GfgScAmzn7gPqV88jJXERERmoWprTHwS2LXv+10jq3zWzn8QMb2+JJvWquPvLwC9rC7lYOjrgYx/LOgoREcmjWpL4RcB6Zc8vAT4HGPCeSOgpcHYD4iyssWNh+XJYsKB7oJuIiEi1allP/BcVz1ea2e7AsTEq/R/ANe7+54ZEWlBJ8lqT+jvekXU0IiKSJ4OadjVGkH+vfuG0JiVxEREZiFome/lN3CN+tbv3OkpdatfRAZdcknUUIiKSN7XUxKcDhwPfMLMHgSuBq9z95gbG1xKmToWTT846ChERyZtakvgYYEYk8ncAnwU+Y2ZPxwjz/3H3uxoYa2GNGweLF8PChbDRRllHIyIieVH1feLu/qq7X+HuxwGbxrKh/wWsBE4G7jSze81sVmNDLp411oApU2Du3KwjERGRPBnQwDZ3Xw3cFD8nmtnesQLZkcBZwFfqH2qxlQa3HXRQ1pGIiEheDHh0upltDBwazevTgeFxr/it9Q2xNXR0wFVXZR2FiIjkSU1J3MzGRdI+HHhTNMcnwPyY/OUSd3+sceEWV0cHnHFG1lGIiEie1HKL2Z3AbmUzs90JXN3duu73NTbM4ttxR3j22e4Bbuuvn3U0IiKSB7XUxDcBrorEfY27L2hgXC2nrQ123RXmzYNp07KORkRE8qCWJL6Nu69qYCwtr7S2uJK4iIhUo5ZbzJTAG0zLkoqISC1qWU9cGkxJXEREaqEk3kR23hkeewyWLcs6EhERyQMl8Say5powcSLcpclrRUSkClUncTMbY2Z9DoQzs5PM7D/qElmL6uiAO+7IOgoREcmDWmrizwDf6WefmcDHBhlTS1O/uIiIVKuWJJ7ET19WA1sNMqaWpiQuIiLV6q95fCawe9mmDjP7zx52TYDNgWmaO31wdtkFHnoIXn0VRozIOhoREWlm/U32shPwxfg9BabGT28eUnP64IwYATvsAHffDXvskXU0IiLSzPpL4j8H7oia9vUx5erXe9ivC3gWeDSWKZVBKDWpK4mLiEhf+kzi7r4QuJHupvUzgDvd/cYhi65FqV9cRESqUfXc6e7+pcaGIiUdHfDTn2YdhYiINLua1hPviZltAOwDLAJu1xzrg7fbbnDvvbBiBay1VtbRiIhIs6plspdrzexPFdv2BP4GXAn8EfizmW3ckEhbyMiRsM028Ne/Zh2JiIg0s1ruE98RWLv0JGZv+yGwQQx6+wPQAZzWmFBbi/rFRUSkP7Uk8TFxC1nJccBE4AZ3f6u7z4iR7Ic1IM6WU1pbX
EREpDe1JPEngS3oroWvAZwQ945/pWyfB4FN6x9m61FNXERE+lNLEv81sJeZ/SbuF58A3Oju/1u2z1jghQbE2XImT4b586GrK+tIRESkWdWSxM8B/gQcDrwNWAx8vPSimY0G9gLmNibU1jJqFGy2GTzwQNaRiIhIs6o6ibv7P4C3AHsCRwA7uvu9ZbssA94FnN2YUFuPmtRFRKQvNd0n7u5dwO3xU/nacuCqukbX4kpri7///VlHIiIizWhAk73EwLYxQOruz9Y/LCGS+FlnZR2FiIg0q5qSuJkdBnwemBTvTUvHMLPNgJ8Cne5+TsMibiEdHTB3LqxeDWvUMnpBRERaQi0ztr0TuDxuM7sSeDhWN4Pu5vQFMSHMuxoWbYvZcEMYPRoefjjrSEREpBnVUr+bBSwAdnL39wBzethnLrBtHeNreRrcJiIivaklie8MzHH3xX3s8yqwZh3ikqAkLiIivakliT/T12xsZpYAuwG6s7mOlMRFRKQ3tQxsuxb4lJmdBPyg/AUz2xT4ArADcEa1BzSzL8Zc6+Ojqf777v71mkpQcKUknqaQJFW8QUREWkatfeK3A+cCL8YCKJjZMuBp4N9jAZSv9H+ofxoTTfA3A1sC55rZsbUXo7g23RSGD4fHH886EhERaTZJmqZV7xxN5kcDxwDbAFsBC4FHgF8BF8aEMNUebw13Xx2/Hx81/Evd3Xp7T2dnZ9rW1lZ1zP1ZunQp7e3tdTteI8yatQszZy5g//0XDvgYeShnPaicxaJyFovKOXBdXV1Mnz79De2xtc7YlgIXx8+glRJ4eDoeV/ey+z9NmzatHqcHYM6cOXU9XiMcdBCsWLEhgwkzD+WsB5WzWFTOYlE5B66zs7PH7c00hciH4/HKjONoOlpbXEREelJ1TdzMxgJ/jznSS9veDhwYi5/8xd2vGUgQZnZKLKoyG7hkIMcostIc6hrcJiIi5fpN4mZ2BnAiMBJ4xcy+4O7fMLOvA58um7UtNbMrgKPcfUU1J4852M8GPgf8DrCKJnYBttyye+rVBQtg882zjkZERJpFn0nczI6LudKXx2xsY4Cvmdk6wAmx7WJgBHB83C72IeB7VZ7/UuDIOP4y4CIzA/i8uz9UlxIWQJK8dquZkriIiJaFOLsAABkWSURBVJT01yf+YeAVYIq77w6MBc4Dvgg8C+zn7t9y968CewGLS7eeVakjHocD7wGOip8Bra5WZJr0RUREKvWXLLcDfu/u9xOj083s7Ghev9HdXy3t6O6LzOwGYP9qT+7umme9Sh0d8ItfZB2FiIg0k/5q4qOAF8o3uPuL8etLPey/BFi/fuFJSWlwm4iISEl/SXwhsE4NxxsJrBxkTNKDcePgpZfgueeyjkRERJpFf83pC4HDzOzvFdtT4BgzO7xi+yhgUZ1jlLLBbXPnwowZWUcjIiLNoL8kvigGnW3cw2sj4qfSM3WKTSqUBrcpiYuICP0lcXefPnShSH86OuDyy7OOQkREmkUzTbsq/dBtZiIiUk5JPEfGj+8e2Pbii1XsLCIihackniNtbbDbbjBvXtaRiIhIM1ASzxndLy4iIiVK4jmjfnERESlREs8ZrS0uIiIlSuI5M2ECPPlk9+xtIiLS2pTEc2bYMJg0Ce66K+tIREQka0riOaR+cRERQUk8n5TERUQEJfF8UhIXERGUxPNp0iT429/glVeyjkRERLKkJJ5Dw4fDjjvC/PlZRyIiIllSEs8p3S8uIiJK4jmlfnEREVESzyklcRERURLPqV13hfvugxUrso5ERESyoiSeU+usA+PGwb33Zh2JiIhkRUk8x9SkLiLS2pTEc0xri4uItDYl8RxTTVxEpLUpiefYlClw992walXWkYiISBaUxHNs3XVhyy3h/vuzjkRERLKgJJ5zalIXEWldSuI5pyQuItK6lMRzTklcRKR1KYnn3JQpMG8erF6ddSQiIjLUlMRzbvRo2HBDeOihrCMREZGhpiReAGpSFxFpTUriB
aC1xUVEWpOSeAGoJi4i0pqUxAtgypTuJJ6mWUciIiJDSUm8AMaMgZEj4bHHso5ERESGkpJ4QahJXUSk9SiJF4SSuIhI61ESLwitLS4i0nqUxAuiVBPX4DYRkdahJF4QW2wBSQJPP511JCIiMlSUxAsiSdQvLiLSaoZleXIz+xjwDmAqsAnwE3c/LsuY8qyUxA89NOtIRERkKGRdEz8YOAhYkXEchaCauIhIa0nSDEdCmdloYAmwBfBYNTXxzs7OtK2trW4xLF26lPb29rodL0sLFozgk5+cwqWX3vKG14pUzr6onMWichaLyjlwXV1dTJ8+Pancnmlzuru/QHcyr+l906ZNq1sMc+bMqevxspSm8NGPwoQJ0xgz5vWvFamcfVE5i0XlLBaVc+A6Ozt73J51c7rUkQa3iYi0FiXxglESFxFpHUriBaO1xUVEWkfWt5j9IQa1rRmbjjCzvYHz3P3CLGPLq44OmDUr6yhERGQoZJrEge2Abcqej4qfNft4j/Rh++1h0SJ44QUYPTrraEREpJGyHp2+bZbnL6I11oDJk2HuXJg+PetoRESkkdQnXkAa3CYi0hqUxAtISVxEpDUoiReQkriISGtQEi+gnXaCp56CJUuyjkRERBpJSbyAhg2DXXeFefOyjkRERBpJSbyg1KQuIlJ8SuIFpSQuIlJ8SuIFpSQuIlJ8SuIFNXEiPPIIvPxy1pGIiEijKIkX1FprwYQJMH9+1pGIiEijKIkXmJrURUSKTUm8wDo64I47so5CREQaRUm8wLS2uIhIsSmJF9guu8ADD8Dy5VlHIiIijaAkXmBrr929vvg992QdiYiINIKSeMFpcJuISHEpiReckriISHEpiReckriISHEpiRfcbrt194mvWpVkHYqIiNSZknjBrbsubLUVPP74OlmHIiIidaYk3gKmToWHHlo36zBERKTOlMQLrqsLRoyAK6/cjKuv7n5eRF1dcPXV8LOfba1yFoDKWSwqZwOlaZqrnz/84Q9pPV1//fV1PV4zWbUqTadPT9O1105TWJ22t3c/X7Uq68jqq1TO9vY0TRKVM+9Uzqwjqy+Vsz7Hj9z3hpw4bAiuEyQjs2fDrbfCK68AJCxdCp2d0N4Owwr0L79qFbz6aunZa+UcNaq7FWKNNV77SZLXP6/2tcG8t17HfewxuPFGWLnytXLeeCMcdVT3pD5F8be/qZwqZ/70VM5bb+3+Hj7kkMadt0Bf5VJp7lxYtuz125IEPvMZOPnkrKKqv3POgS9/GdL0tW1JAp/8JJx4Iqxe/fqfNH3jtmpeG8x763Hc++4rfUG8ZuVKWLwY1ltvyD/2hlm8WOVUOfOnp3IuWwbz5imJywBNmQIjR8LSpa9tGzkS9tqre9R6Uey9d8/l3Gcf2GijLCOrr/Hj4Y9/fH0529vhU59q7JfEUNtlF/jzn1XOomjlco4cCZMnN/a8GthWYDNndifs9nZIkpT29u7nM2dmHVl9qZxZR1ZfKmfWkdWXytnY86omXmBtbXDddd19Mpdf/hhHHLEtM2d2by8SlTPryOpL5cw6svpSORt7XiXxgmtr626yam9/nGnTts06nIZROYtF5SwWlbNx1JwuIiKSU0riIiIiOaUkLiIiklNK4iIiIjmlJC4iIpJTSuIiIiI5pSQuIiKSU0riIiIiOaUkLiIiklNK4iIiIjmVpOXrN+ZAZ2dnvgIWERGpg+nTpyeV23KXxEVERKSbmtNFRERySklcREQkp5TERUREckpJXEREJKeUxEVERHJKSVxERCSnhmUdgIhIKzOzycA5wJ7xnfw9dz8567ikdma2PvAN4B3ACOAW4AR3f6BR51QSFxHJiJntAtwIrAZuBtYEHsk6LhmwbwPvB+4HXgbeBlwG7NKoE7ZsEjezj8XV0lRgE+An7n5c1nHVm5l9ETgMGA8sAL7v7l/POi6RvpjZDODais0vu/vIjEJqlBPje3hnd38s62Bk0N4OPAVMdPfVZnYLsLeZr
e/uixtxwpZN4sDBwEHAs1kH0mBjgFfjKv8A4FwzW+juP8k6sHoys4uA/YCxwCLAgVnuviLr2BrBzKZEkjvf3c/OOp4G2CoeO4Hn4/dFGcZTd2a2BvDeqInvb2anAvOAC9y9K+v46sXMNgee7uXlae5+wxCH1EjPAxsCpelR1wIeb1QCp8WT+HHAEmALoMhXwB9199V0/2c6HvhBXC0WKokDmwPLgD8C+0QNZwUwK+vAGuTsaEG6PutAGqSUxE909/kZx9IomwBrA/sCM8q27wp8JMO46m0xcEnFtunAxtHkXCTnA98F5pjZI8BOwHsaecKWTeLu/gLdiS3rUBqqlMBD6Wp4dS+755a7v730u5kdBvw2auaFY2Z7ADOBTnf/U9bxNMiG8Vio2neFzeOxC5gUlYobgOPN7GR3X5JxfHXh7i8D/1J6HjXzR4B73f22bKOrux8Ax0dFYj/gAeCeRp5Qt5i1lg/H45UZxzFU7sg6gAY5NR7PyjiORir1fX/ZzE43sx0yjqcR2uLx/9z9Xnd/Ergitm+XcWyNdCIwHDgv60AawIGty1pXxgC3mFnDKswtWxNvNWZ2CnAEMLuHpq3cM7ONYmTo2tH3/yfga1nHVW9mth1waHQbLIp/12vcvaFX+xkoJfF/jccvmdmH3P3HGcZUb3+Px/JxG6t62FYYZjYa+PfoO/5F1vHUk5ltGoOIf+juf45tvwY+FEm9IX3/SuIFF4NnzgY+B/wOsIom9qIYCRxV9vz5Ava3AXwiWtDOBDqAr8bgzKIl8WNi0NfawCHxhX++mf2sQIO+FkQT+u5RU+uKL/uVwKNZB9cgnwTagW+4+6tZB1NnG8WAti3LtpV+X6tRJ23ZJG5mf4hBbWvGpiPMbG/gPHe/MOPw6ulS4EhgeQz8uijGAXze3R/KOrh6cffHgcTM1os+qW8AVxepX9zM1o0Bmbe5+3Vm9oGsY2oUd18ev74MuJnNAiYD44BC/N26+0oz+2lcmN0WtfDd4zbQwl2Amll7lPUV4L+zjqcBHgCeBN5mZvfExfaEGIv050adtJX7xLeLkYOlvqdR8XzNft6XNx3xODxGSR4VP4W8gHP3Je7+TeAZYF8zG5V1THU0DVgPGGNmNwGnxfbTzOzdGcfWMGa2AbBz1FCfyDqeOjsFuAjYLGpt3wU+nXVQDfLvwOi4SHku62Dqzd1XRqtRZ9zquglwDTDD3V9q1HkL+UVeDXffNusYhkKrlLMHSTRPFqlvsXTRPTZ+SsY3srkuC2b2vWhFeTKamNcCzimroReCu79SNuC0sMxsOHBSzFlRuLEqJXE75FuH8pwtm8SlWMzsfcAc4MUYDLUZcHt8SRaCu19RNokEZvYF4AzgWHf/ZbbR1d2S+H46IOZxuBD4TtZByYAdG/8n/8vdiz7B1pBSEpfci37wi8sTXCSBwtdwisrdT4mmZsk5M2sDTo5a+DlZx1M0rdwnLsWxKhYd+DLwo/jy387d52UdWIPNiZp40cspORZ3E0wH3uzuC7KOp2iSNE2zjkFEREQGQDVxERGRnFISFxERySkNbJPCMbP3ABcAp7t7riaVMLPdYqKavYA/uPsRWcdEd1znxT2we7r7i1nHIyLdlMSlamb2FLAB8C/u/oZFVMwsBea4+1uyifCf1ok4kyr2bRoxMc3smNDlf4H/G8JzrwN8EFjs7hf3sMv2MRnJyLiNTzJiZvvGKnbfcveFWccj2VJzutRii0iQ55pZ0Wa2awb7xL2057r7ke4+lPdFbxwLyEzv5fXDgE3c/akhjEl6dnzM1jeyin2l4FQTl1otA3YA/gM4P+tgCmaTeHws4zjeIBbNadjUkSIyMLrFTKoWzeXXxjzW7cD25f2jPTWnR//0hcCn3P1nZdv3jgVK/svdzyzbfjzwrZjffUYsu7l2rMD2b1FbPSMWw1ga8zCfVRHnB4CfRG1lDPBOYMO4n/okd7+9h7LtH8edArwQTdmfdffFZftsGvdm3xBLn
Z4exx/l7j3+R4pV5D4T89bvFHN//w74QvkiF2b245jVqtzV7v7OXo57cEygMT5q0U8B84Fvu/sbljw0s7dGvJOjBe4J4HPufnXZ51VpH3e/Jd5/PTDN3ZN4/i7gp7GQzjcrznVszK52nLtfWrb9I7Es406xStf/xFSqva6qZ2aTo2vh+8A/YgKfLYH7gC+5+1UV+78/aqo7RLfEE7G4yLfcfW7FvvfE69+J1eB2ALZ19wVmtlXMY75DrA/9HPBwzEPwq5gnu3ScL8a/xQGxwMd0YHWU72TgXTHl6KQ4zlfd/Qc9lHUr4JvAm2Ka2dvj7/W+sn0eBbapeOv/uvuMsn3agFnxNzcuFub4vrtfVHG+fwO+EP+v3gUcDvzS3T8er28efzNvixaip4Gb3b3y71QypOZ0qdUI4KxYyOA/q9h/HWD9Hvqnh0dirfwb3CKaCX8FHBjJ8gXgA8AdwO/jPdfGMc40s6Po2VlxEXA7cCewP/BHMxtfvpOZHRZJewJwXXyxHw9cb2blcQ8HdgTeDZwH3Atc0lsCD5fHLFUj46IliaT+p1h+suTG+Cn9fhHw2z6OW/r8bgV+E2tTHwrMMbOtK8p3bCTCXaOcfwY2LVuq9aH4vEu/XxQ/fbUI3BAXVz0NvDssXptTFsPXge/F9ivj3/DL8Tn2ZVRcpMyKBHlPnHsKcEVcnJRbL/5Gb4jzLI2JgO7s4dgT46Lwkpij/cqyNb6XxOv3Aw78FZgK/Bx4X8VxxsY5r4+/oRsiiX8mLhwvjhas2XHRd4GZ7VN+gPg3mxuDB2+P9eLfAtxuZuVLW/6qbBW3X8W/028q4vlVLFW7HLgq/q/+wMwqF1bZDNgc+Bmwd/yfKq2DvV78rXw4/g6ui+O19/PvJUNMzekyED+KL9SPmdl/u/vfGnCO0939PLq/UEbG+uCTgN3d/Y7YfmCsGPS2+CKu9F/AZ0rrT5vZfwJfioTwQV6rtZwTX96TSi0LZnZa2UXAFRXHfQY40N2f76sAcXFwKPBr4L3u3hXn+1XUfI6PGibu/mMzWx0XGj9y95/2deyogVbWQt9Qm4/lH78KLAQ6yvu0Sxco7v4nM3s61u/+o7v3O12tuy80sz8D+5jZGHf/exxzbeBg4NbS52Nm28Y60lcDh8fnsEYktY+Y2Vn9fZZR6//3sn/LYyL5zAL+UBbXf1cuc1lqRejluMtiJrFHKsr3j7IVDkvH6a3FouT97n557Dsuau5bAbuUlv01sw8CP4zP6E9l7/3PSLb7uHspke4L3By1+E9HXJ8zszHRAnNKLMFbHuN+8bf1PXf/KK8tPnIHcKqZnd/DeuwXAqdWXIzOiAvW8939U32UWTKmmrjULL4ETiqtLNWg0/yz5uTuy0rPSwk8lGquY9/w7m7zKr6wvhuP5bWgPePL6tKKW6dKiXtyD8e9vYqkQyy9CHBWKY54PDu2H1nFMQbrbVH7+3HloLR+WhCqcVV8h5Q3+R8UrS/lFxhHxRK/Pyz7HFbHPmtF90x/nqz4t/wlsBjYr6K1pFYPVibwQSj/m30kmp//UUrg4Q1/sxH/0cD8UgKPY/wRWNTL32BvjonHC8uOszxq2RtFV0Sla3v4W9giHtep4dySAdXEZUDc/fdmdg1wpJnt7+43VvG2wVhZucHdV5kZ1d5K5u4vmNlLFUm/tFTrvmZW3k+5Xjz29KVXrfFAGk2x5e6N7eMGeuCoyR4RzcqbxGewbw+7lmqT9/Xw2mBdDXwl+lJL/a2Hx2N5Ei99xh80s3eUbd8xHmv+jKM2/1S0zmxSagY3s7XiomHn6G4oP09NzOxNsazk5vFduUONh3jD32zZtvK/2c2iOX6jir9BYnstn0/psz7ZzJaVbd89HrcEHu/hfZV+H7Eeb2a7RrfHr+pw4Sd1piQug/GZaHb7ppntkXUwVVoKbGJmbVGz2yq27172RVduMF9aWwIvufvr1jR39xVmtjSSQ80igf8+xgz0p9Q//o+BnKsv7n6Pm
T0GvNXM1o0+9ncCj7r7PWW7lj7jHgfpDeIzLo2WX5vX7nW/rcqafZ/M7IwY9DUUSp/PFtHFUqmWz6d0rH/p5fWqjuXu95rZ9Bj0t2cM0vu4mX3I3R+oIR5pMCVxGTB3v9/Mvg98PEbCVva1NZVotty4omm2NFnGKe7+tTqf8llg67ILhlIca8UAoYGOJXhbJPCbgPeVmsl7GeFe6iIYPeBS9O1S4LMR07PRZHthxT6lz3gvd/9LHc+9YfzNlboJjo0Efhnw0bI++b76xN8gBnWdHnEfHM3cXVX0iQ9U6fOZ7e5vr9Oxxrj7c4M5kLvfZGYdMZDzc9HS8zsz26l8dL5kS33iMlhnRN/kCcArPbxeuvIfPsRx9WRSXLiWJ89Sf+j2DTjfY9FsWtkMu3Nsf6iX9/VnQjxeUsXkK6UR5jv1s1/pIqPW/uX/icfDY1R6+baSun/GZrZ+tDI85e6rYnPpc/lxlWMWejMeaAOud/e5PQwEq7cn4vOv9vPp69+qrp+1u3e5+69iGuD7owuov78lGUJK4jIo7r4oRny/qZfbT0pJ5qDShqgR7zIE4f1zVrmYYe6MePq9sn1uiS/Ro8zsdV9OZjam4jawWl0WjyeXBl/F6PTTYvtFvb+1T6Wm8crPcEQ8tpVtuyYurj5QcasS0QRe8lxccNXUxB/3Xj8IvDlG1s+vaEonPofVwKdjatnyGKrt722reH5qXBj+sGzbGz6X+NyH81o3RDVq+XwHLWq1vwHGm9nR5a+Z2dpmVtmK8mw89vRvVbpVcFaMSi8/VtV962Y2qvzvI7qESuVeVO1xpPHUnC718F3go71c/d8Z/+nfbWZ/iXu+J8SI6UYp1cy+a2ZHxpfy7hHf/8a929D95fSqmZ0Utcc7zezmuJ1t66h9vCVu8xmIC4GPRDPvZDO7L0YaTwCuKd2ONACdwKvAh81sQsS7W9lneqGZ3eDuZ7r7M2Z2VoyIv9vMboyk3hGJ43O81k9/ffRv3xz9zQ+4+wlVxPM/0X+8Zel45dz9bjP7NvAp4IG4NW1F1Oi2M7PRMYK6L6fHv+U9MVBtctzC9a2yfa6J5P5FM5sWn9HuZYMUf21mv69iUZxHYjDizjEhzPwYIFhqUTkpBr19wt1freLzqcasmCTmYjP7WEyGs3H0R389bncsuS4uBH8b/547Aoe4++Pu/jszuyJaRR40s9LERrtEpa3aGvoHgK+a2a3AghhAOT6a/J+pU5mlDlQTl0GLq/TPRvJ8peK1xXEr1Z3RnL1NfOm/p4Hx/CLuz74hzvnOuHg4FZhZOcLW3S+LyS4644vqyEhIvxnMVKNRw9o77lcfHjG9ApwSvw/0uI/GvcD3AnvERcEFceExJ7atKNv/y9HcPTcuTA6JBPdwxaGPj4ucSZGwqi37Jb38Xh7zCTFJyv1xi9/b4/ayn1R5G9MD0Yx8SLzvu8AUd/9njDG73HHRsrJ//K2dHt0X82JGtSX9nShufzsyBg+OjXjviKbk78QAtI3rmMBx94cjzh/FBDdHRuK9vfLuBne/KbqvXoo++xcqPsMj4r78p+Mi9K3x7315DS1Lj8REQhMjljVjch6rV5mlPjTtqog0LTM7IC5MznD3L2Ydj0izUU1cREQkp5TERUREckpJXEREJKfUJy4iIpJTqomLiIjklJK4iIhITimJi4iI5JSSuIiISE4piYuIiOSUkriIiEhO/T/pzs5llxWyfwAAAABJRU5ErkJggg==\n" 806 | }, 807 | "metadata": { 808 | "needs_background": "light" 809 | } 810 | } 811 | ] 812 | }, 813 | { 814 | "cell_type": "markdown", 815 | "metadata": { 816 | "originalKey": "1ba68dc9-d60b-4b39-8e58-ea9bdc06b44c", 817 | "showInput": false, 818 | 
"customInput": null 819 | }, 820 | "source": [ 821 | "# Demo of Using GenerationStrategy and Service API \n", 822 | "\n", 823 | "Please check [Service API tutorial](https://ax.dev/tutorials/gpei_hartmann_service.html) for more detailed information. " 824 | ], 825 | "attachments": {} 826 | }, 827 | { 828 | "cell_type": "code", 829 | "metadata": { 830 | "originalKey": "c9eac7a0-d8c2-49c9-a53a-4e05b6694ced", 831 | "showInput": true, 832 | "customInput": null, 833 | "collapsed": false, 834 | "requestMsgId": "0dc37045-8f54-4091-913f-69b11b072e19", 835 | "executionStartTime": 1689124191398, 836 | "executionStopTime": 1689124192949 837 | }, 838 | "source": [ 839 | "from ax.service.ax_client import AxClient, ObjectiveProperties\n", 840 | "\n", 841 | "from ax.modelbridge.generation_strategy import GenerationStrategy, GenerationStep" 842 | ], 843 | "execution_count": 14, 844 | "outputs": [] 845 | }, 846 | { 847 | "cell_type": "code", 848 | "metadata": { 849 | "originalKey": "92678568-4757-4a9e-9424-837352f04bbc", 850 | "showInput": true, 851 | "customInput": null, 852 | "collapsed": false, 853 | "requestMsgId": "e1e19d91-06a5-4f8c-a8be-ca7bd1856c58", 854 | "executionStartTime": 1689124192961, 855 | "executionStopTime": 1689124192970 856 | }, 857 | "source": [ 858 | "N_INIT = 10\n", 859 | "BATCH_SIZE = 1\n", 860 | "\n", 861 | "if SMOKE_TEST:\n", 862 | " NUM_TRIALS = 1\n", 863 | " SURROGATE_CLASS = FixedNoiseGP\n", 864 | "else:\n", 865 | " NUM_TRIALS = 40\n", 866 | " SURROGATE_CLASS = SaasFullyBayesianSingleTaskGP\n", 867 | "\n", 868 | "print(f\"Doing {N_INIT + NUM_TRIALS * BATCH_SIZE} evaluations\")" 869 | ], 870 | "execution_count": 15, 871 | "outputs": [ 872 | { 873 | "output_type": "stream", 874 | "name": "stdout", 875 | "text": [ 876 | "Doing 50 evaluations\n" 877 | ] 878 | } 879 | ] 880 | }, 881 | { 882 | "cell_type": "markdown", 883 | "metadata": { 884 | "originalKey": "45e5586c-55eb-4908-aa73-bca4ee883b56", 885 | "showInput": false, 886 | "customInput": null 887 | }, 
888 | "source": [ 889 | "## Create `GenerationStrategy`" 890 | ], 891 | "attachments": {} 892 | }, 893 | { 894 | "cell_type": "code", 895 | "metadata": { 896 | "originalKey": "7c0bfe37-8f1f-4999-8833-42ffb2569c04", 897 | "showInput": true, 898 | "customInput": null, 899 | "collapsed": false, 900 | "requestMsgId": "bbd9058a-709e-4262-abe1-720d37e8786f", 901 | "executionStartTime": 1689124192972, 902 | "executionStopTime": 1689124192975 903 | }, 904 | "source": [ 905 | "gs = GenerationStrategy(\n", 906 | " name=\"SEBO_L0\",\n", 907 | " steps=[\n", 908 | " GenerationStep( # Initialization step\n", 909 | " model=Models.SOBOL, \n", 910 | " num_trials=N_INIT,\n", 911 | " ),\n", 912 | " GenerationStep( # BayesOpt step\n", 913 | " model=Models.BOTORCH_MODULAR,\n", 914 | " # No limit on how many generator runs will be produced\n", 915 | " num_trials=-1,\n", 916 | " model_kwargs={ # Kwargs to pass to `BoTorchModel.__init__`\n", 917 | " \"surrogate\": Surrogate(botorch_model_class=SURROGATE_CLASS),\n", 918 | " \"acquisition_class\": SEBOAcquisition,\n", 919 | " \"botorch_acqf_class\": qNoisyExpectedHypervolumeImprovement,\n", 920 | " \"acquisition_options\": {\n", 921 | " \"penalty\": \"L0_norm\", # it can be L0_norm or L1_norm.\n", 922 | " \"target_point\": target_point, \n", 923 | " \"sparsity_threshold\": aug_dim,\n", 924 | " },\n", 925 | " },\n", 926 | " )\n", 927 | " ]\n", 928 | ")" 929 | ], 930 | "execution_count": 16, 931 | "outputs": [] 932 | }, 933 | { 934 | "cell_type": "markdown", 935 | "metadata": { 936 | "originalKey": "e4911bc6-32cb-42a5-908f-57f3f04e58e5", 937 | "showInput": false, 938 | "customInput": null 939 | }, 940 | "source": [ 941 | "## Initialize client and set up experiment" 942 | ], 943 | "attachments": {} 944 | }, 945 | { 946 | "cell_type": "code", 947 | "metadata": { 948 | "originalKey": "47938102-0613-4b37-acb2-9f1f5f3fe6b1", 949 | "showInput": true, 950 | "customInput": null, 951 | "collapsed": false, 952 | "requestMsgId": 
"38b4b17c-6aae-43b8-aa58-2df045f522fe", 953 | "executionStartTime": 1689124192979, 954 | "executionStopTime": 1689124192984 955 | }, 956 | "source": [ 957 | "ax_client = AxClient(generation_strategy=gs)\n", 958 | "\n", 959 | "experiment_parameters = [\n", 960 | " {\n", 961 | " \"name\": f\"x{i}\",\n", 962 | " \"type\": \"range\",\n", 963 | " \"bounds\": [0, 1],\n", 964 | " \"value_type\": \"float\",\n", 965 | " \"log_scale\": False,\n", 966 | " }\n", 967 | " for i in range(aug_dim)\n", 968 | "]\n", 969 | "\n", 970 | "objective_metrics = {\n", 971 | " \"objective\": ObjectiveProperties(minimize=False, threshold=-10),\n", 972 | "}\n", 973 | "\n", 974 | "ax_client.create_experiment(\n", 975 | " name=\"branin_augment_sebo_experiment\",\n", 976 | " parameters=experiment_parameters,\n", 977 | " objectives=objective_metrics,\n", 978 | ")" 979 | ], 980 | "execution_count": 17, 981 | "outputs": [ 982 | { 983 | "output_type": "stream", 984 | "name": "stderr", 985 | "text": [ 986 | "[INFO 07-11 18:09:53] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. 
Note that float values in the logs are rounded to 6 decimal points.\n" 987 | ] 988 | }, 989 | { 990 | "output_type": "stream", 991 | "name": "stderr", 992 | "text": [ 993 | "[INFO 07-11 18:09:53] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x0', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x1', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x3', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x4', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x5', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x6', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x7', parameter_type=FLOAT, range=[0.0, 1.0])], parameter_constraints=[]).\n" 994 | ] 995 | } 996 | ] 997 | }, 998 | { 999 | "cell_type": "markdown", 1000 | "metadata": { 1001 | "originalKey": "6a7942e4-9727-43d9-8d8d-c327d38c2373", 1002 | "showInput": false, 1003 | "customInput": null 1004 | }, 1005 | "source": [ 1006 | "## Define evaluation function " 1007 | ], 1008 | "attachments": {} 1009 | }, 1010 | { 1011 | "cell_type": "code", 1012 | "metadata": { 1013 | "originalKey": "4e2994ff-36ac-4d48-a789-3d0398e1e856", 1014 | "showInput": true, 1015 | "customInput": null, 1016 | "collapsed": false, 1017 | "requestMsgId": "8f74a775-a8ce-462d-993c-5c9291c748b9", 1018 | "executionStartTime": 1689124192990, 1019 | "executionStopTime": 1689124192992 1020 | }, 1021 | "source": [ 1022 | "def evaluation(parameters):\n", 1023 | " # put parameters into 1-D array\n", 1024 | " x = [parameters.get(param[\"name\"]) for param in experiment_parameters]\n", 1025 | " res = branin_augment(x_vec=x, augment_dim=aug_dim)\n", 1026 | " eval_res = {\n", 1027 | " # flip the sign to maximize\n", 1028 | " \"objective\": (res * -1, 0.0),\n", 1029 | " }\n", 1030 | " return eval_res" 1031 | ], 1032 | "execution_count": 18, 1033 | "outputs": [] 1034 | 
}, 1035 | { 1036 | "cell_type": "markdown", 1037 | "metadata": { 1038 | "originalKey": "4597531b-7ac8-4dd0-94c4-836672e0f4c4", 1039 | "showInput": false, 1040 | "customInput": null 1041 | }, 1042 | "source": [ 1043 | "## Run optimization loop" 1044 | ], 1045 | "attachments": {} 1046 | }, 1047 | { 1048 | "cell_type": "code", 1049 | "metadata": { 1050 | "originalKey": "bc7accb2-48a2-4c88-a932-7c79ec81075a", 1051 | "showInput": true, 1052 | "customInput": null, 1053 | "collapsed": false, 1054 | "requestMsgId": "f054e5b1-12eb-459b-a508-6944baf82dfb", 1055 | "executionStartTime": 1689124193044, 1056 | "executionStopTime": 1689130398208 1057 | }, 1058 | "source": [ 1059 | "for _ in range(NUM_TRIALS+N_INIT): \n", 1060 | " parameters, trial_index = ax_client.get_next_trial()\n", 1061 | " res = evaluation(parameters)\n", 1062 | " ax_client.complete_trial(trial_index=trial_index, raw_data=res)" 1063 | ], 1064 | "execution_count": 19, 1065 | "outputs": [ 1066 | { 1067 | "output_type": "stream", 1068 | "name": "stderr", 1069 | "text": [ 1070 | "[INFO 07-11 18:09:53] ax.service.ax_client: Generated new trial 0 with parameters {'x0': 0.340745, 'x1': 0.592392, 'x2': 0.307124, 'x3': 0.136736, 'x4': 0.453162, 'x5': 0.407409, 'x6': 0.898588, 'x7': 0.712434}.\n" 1071 | ] 1072 | }, 1073 | { 1074 | "output_type": "stream", 1075 | "name": "stderr", 1076 | "text": [ 1077 | "[INFO 07-11 18:09:53] ax.service.ax_client: Completed trial 0 with data: {'objective': (-28.913967, 0.0)}.\n" 1078 | ] 1079 | }, 1080 | { 1081 | "output_type": "stream", 1082 | "name": "stderr", 1083 | "text": [ 1084 | "[INFO 07-11 18:09:53] ax.service.ax_client: Generated new trial 1 with parameters {'x0': 0.596941, 'x1': 0.798649, 'x2': 0.111305, 'x3': 0.329006, 'x4': 0.187743, 'x5': 0.589378, 'x6': 0.500772, 'x7': 0.008061}.\n" 1085 | ] 1086 | }, 1087 | { 1088 | "output_type": "stream", 1089 | "name": "stderr", 1090 | "text": [ 1091 | "[INFO 07-11 18:09:53] ax.service.ax_client: Completed trial 1 with data: 
{'objective': (-108.522848, 0.0)}.\n" 1092 | ] 1093 | }, 1094 | { 1095 | "output_type": "stream", 1096 | "name": "stderr", 1097 | "text": [ 1098 | "[INFO 07-11 18:09:53] ax.service.ax_client: Generated new trial 2 with parameters {'x0': 0.310899, 'x1': 0.906665, 'x2': 0.859498, 'x3': 0.861769, 'x4': 0.565173, 'x5': 0.849893, 'x6': 0.743119, 'x7': 0.485293}.\n" 1099 | ] 1100 | }, 1101 | { 1102 | "output_type": "stream", 1103 | "name": "stderr", 1104 | "text": [ 1105 | "[INFO 07-11 18:09:53] ax.service.ax_client: Completed trial 2 with data: {'objective': (-68.762484, 0.0)}.\n" 1106 | ] 1107 | }, 1108 | { 1109 | "output_type": "stream", 1110 | "name": "stderr", 1111 | "text": [ 1112 | "[INFO 07-11 18:09:53] ax.service.ax_client: Generated new trial 3 with parameters {'x0': 0.222246, 'x1': 0.682503, 'x2': 0.697094, 'x3': 0.262685, 'x4': 0.660106, 'x5': 0.783381, 'x6': 0.537969, 'x7': 0.607574}.\n" 1113 | ] 1114 | }, 1115 | { 1116 | "output_type": "stream", 1117 | "name": "stderr", 1118 | "text": [ 1119 | "[INFO 07-11 18:09:53] ax.service.ax_client: Completed trial 3 with data: {'objective': (-10.589478, 0.0)}.\n" 1120 | ] 1121 | }, 1122 | { 1123 | "output_type": "stream", 1124 | "name": "stderr", 1125 | "text": [ 1126 | "[INFO 07-11 18:09:53] ax.service.ax_client: Generated new trial 4 with parameters {'x0': 0.391554, 'x1': 0.769673, 'x2': 0.363151, 'x3': 0.522279, 'x4': 0.8752, 'x5': 0.921642, 'x6': 0.892081, 'x7': 0.614701}.\n" 1127 | ] 1128 | }, 1129 | { 1130 | "output_type": "stream", 1131 | "name": "stderr", 1132 | "text": [ 1133 | "[INFO 07-11 18:09:53] ax.service.ax_client: Completed trial 4 with data: {'objective': (-62.905011, 0.0)}.\n" 1134 | ] 1135 | }, 1136 | { 1137 | "output_type": "stream", 1138 | "name": "stderr", 1139 | "text": [ 1140 | "[INFO 07-11 18:09:53] ax.service.ax_client: Generated new trial 5 with parameters {'x0': 0.319981, 'x1': 0.578814, 'x2': 0.58387, 'x3': 0.310305, 'x4': 0.198673, 'x5': 0.78394, 'x6': 0.423361, 'x7': 0.853005}.\n" 1141 
| ] 1142 | }, 1143 | { 1144 | "output_type": "stream", 1145 | "name": "stderr", 1146 | "text": [ 1147 | "[INFO 07-11 18:09:53] ax.service.ax_client: Completed trial 5 with data: {'objective': (-24.971551, 0.0)}.\n" 1148 | ] 1149 | }, 1150 | { 1151 | "output_type": "stream", 1152 | "name": "stderr", 1153 | "text": [ 1154 | "[INFO 07-11 18:09:53] ax.service.ax_client: Generated new trial 6 with parameters {'x0': 0.889574, 'x1': 0.540804, 'x2': 0.668386, 'x3': 0.511087, 'x4': 0.587279, 'x5': 0.966997, 'x6': 0.699696, 'x7': 0.919272}.\n" 1155 | ] 1156 | }, 1157 | { 1158 | "output_type": "stream", 1159 | "name": "stderr", 1160 | "text": [ 1161 | "[INFO 07-11 18:09:53] ax.service.ax_client: Completed trial 6 with data: {'objective': (-46.419155, 0.0)}.\n" 1162 | ] 1163 | }, 1164 | { 1165 | "output_type": "stream", 1166 | "name": "stderr", 1167 | "text": [ 1168 | "[INFO 07-11 18:09:53] ax.service.ax_client: Generated new trial 7 with parameters {'x0': 0.816103, 'x1': 0.454254, 'x2': 0.498263, 'x3': 0.609042, 'x4': 0.080031, 'x5': 0.321146, 'x6': 0.505942, 'x7': 0.386978}.\n" 1169 | ] 1170 | }, 1171 | { 1172 | "output_type": "stream", 1173 | "name": "stderr", 1174 | "text": [ 1175 | "[INFO 07-11 18:09:53] ax.service.ax_client: Completed trial 7 with data: {'objective': (-46.485345, 0.0)}.\n" 1176 | ] 1177 | }, 1178 | { 1179 | "output_type": "stream", 1180 | "name": "stderr", 1181 | "text": [ 1182 | "[INFO 07-11 18:09:53] ax.service.ax_client: Generated new trial 8 with parameters {'x0': 0.687349, 'x1': 0.282216, 'x2': 0.751967, 'x3': 0.566662, 'x4': 0.79098, 'x5': 0.641958, 'x6': 0.724017, 'x7': 0.590121}.\n" 1183 | ] 1184 | }, 1185 | { 1186 | "output_type": "stream", 1187 | "name": "stderr", 1188 | "text": [ 1189 | "[INFO 07-11 18:09:53] ax.service.ax_client: Completed trial 8 with data: {'objective': (-24.65791, 0.0)}.\n" 1190 | ] 1191 | }, 1192 | { 1193 | "output_type": "stream", 1194 | "name": "stderr", 1195 | "text": [ 1196 | "[INFO 07-11 18:09:53] 
ax.service.ax_client: Generated new trial 9 with parameters {'x0': 0.130133, 'x1': 0.712254, 'x2': 0.760572, 'x3': 0.411107, 'x4': 0.542096, 'x5': 0.526756, 'x6': 0.787764, 'x7': 0.674992}.\n" 1197 | ] 1198 | }, 1199 | { 1200 | "output_type": "stream", 1201 | "name": "stderr", 1202 | "text": [ 1203 | "[INFO 07-11 18:09:53] ax.service.ax_client: Completed trial 9 with data: {'objective': (-2.309687, 0.0)}.\n" 1204 | ] 1205 | }, 1206 | { 1207 | "output_type": "stream", 1208 | "name": "stderr", 1209 | "text": [ 1210 | "[INFO 07-11 18:11:36] ax.service.ax_client: Generated new trial 10 with parameters {'x0': 0.0, 'x1': 0.0, 'x2': 1.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.892852}.\n" 1211 | ] 1212 | }, 1213 | { 1214 | "output_type": "stream", 1215 | "name": "stderr", 1216 | "text": [ 1217 | "[INFO 07-11 18:11:36] ax.service.ax_client: Completed trial 10 with data: {'objective': (-308.129096, 0.0)}.\n" 1218 | ] 1219 | }, 1220 | { 1221 | "output_type": "stream", 1222 | "name": "stderr", 1223 | "text": [ 1224 | "[INFO 07-11 18:13:39] ax.service.ax_client: Generated new trial 11 with parameters {'x0': 0.0, 'x1': 0.640271, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.946358, 'x7': 0.0}.\n" 1225 | ] 1226 | }, 1227 | { 1228 | "output_type": "stream", 1229 | "name": "stderr", 1230 | "text": [ 1231 | "[INFO 07-11 18:13:39] ax.service.ax_client: Completed trial 11 with data: {'objective': (-70.230069, 0.0)}.\n" 1232 | ] 1233 | }, 1234 | { 1235 | "output_type": "stream", 1236 | "name": "stderr", 1237 | "text": [ 1238 | "[INFO 07-11 18:15:44] ax.service.ax_client: Generated new trial 12 with parameters {'x0': 0.0, 'x1': 0.519038, 'x2': 1.0, 'x3': 0.0, 'x4': 1.0, 'x5': 0.0, 'x6': 0.870499, 'x7': 0.0}.\n" 1239 | ] 1240 | }, 1241 | { 1242 | "output_type": "stream", 1243 | "name": "stderr", 1244 | "text": [ 1245 | "[INFO 07-11 18:15:44] ax.service.ax_client: Completed trial 12 with data: {'objective': (-101.117533, 0.0)}.\n" 1246 | ] 1247 | }, 1248 | { 1249 | 
"output_type": "stream", 1250 | "name": "stderr", 1251 | "text": [ 1252 | "[INFO 07-11 18:17:37] ax.service.ax_client: Generated new trial 13 with parameters {'x0': 0.0, 'x1': 0.0, 'x2': 0.0, 'x3': 1.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.727362, 'x7': 1.0}.\n" 1253 | ] 1254 | }, 1255 | { 1256 | "output_type": "stream", 1257 | "name": "stderr", 1258 | "text": [ 1259 | "[INFO 07-11 18:17:37] ax.service.ax_client: Completed trial 13 with data: {'objective': (-308.129096, 0.0)}.\n" 1260 | ] 1261 | }, 1262 | { 1263 | "output_type": "stream", 1264 | "name": "stderr", 1265 | "text": [ 1266 | "[INFO 07-11 18:19:31] ax.service.ax_client: Generated new trial 14 with parameters {'x0': 0.0, 'x1': 0.784581, 'x2': 1.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.700215, 'x6': 0.0, 'x7': 0.724654}.\n" 1267 | ] 1268 | }, 1269 | { 1270 | "output_type": "stream", 1271 | "name": "stderr", 1272 | "text": [ 1273 | "[INFO 07-11 18:19:31] ax.service.ax_client: Completed trial 14 with data: {'objective': (-42.085428, 0.0)}.\n" 1274 | ] 1275 | }, 1276 | { 1277 | "output_type": "stream", 1278 | "name": "stderr", 1279 | "text": [ 1280 | "[INFO 07-11 18:21:33] ax.service.ax_client: Generated new trial 15 with parameters {'x0': 0.0, 'x1': 0.710437, 'x2': 0.953762, 'x3': 0.0, 'x4': 0.0, 'x5': 0.662267, 'x6': 1.0, 'x7': 0.840749}.\n" 1281 | ] 1282 | }, 1283 | { 1284 | "output_type": "stream", 1285 | "name": "stderr", 1286 | "text": [ 1287 | "[INFO 07-11 18:21:33] ax.service.ax_client: Completed trial 15 with data: {'objective': (-55.375218, 0.0)}.\n" 1288 | ] 1289 | }, 1290 | { 1291 | "output_type": "stream", 1292 | "name": "stderr", 1293 | "text": [ 1294 | "[INFO 07-11 18:23:43] ax.service.ax_client: Generated new trial 16 with parameters {'x0': 0.0, 'x1': 0.712456, 'x2': 0.0, 'x3': 0.0, 'x4': 1.0, 'x5': 0.628146, 'x6': 0.0, 'x7': 0.846157}.\n" 1295 | ] 1296 | }, 1297 | { 1298 | "output_type": "stream", 1299 | "name": "stderr", 1300 | "text": [ 1301 | "[INFO 07-11 18:23:43] ax.service.ax_client: Completed trial 
16 with data: {'objective': (-54.980534, 0.0)}.\n" 1302 | ] 1303 | }, 1304 | { 1305 | "output_type": "stream", 1306 | "name": "stderr", 1307 | "text": [ 1308 | "[INFO 07-11 18:26:09] ax.service.ax_client: Generated new trial 17 with parameters {'x0': 1.0, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1309 | ] 1310 | }, 1311 | { 1312 | "output_type": "stream", 1313 | "name": "stderr", 1314 | "text": [ 1315 | "[INFO 07-11 18:26:09] ax.service.ax_client: Completed trial 17 with data: {'objective': (-10.960889, 0.0)}.\n" 1316 | ] 1317 | }, 1318 | { 1319 | "output_type": "stream", 1320 | "name": "stderr", 1321 | "text": [ 1322 | "[INFO 07-11 18:28:06] ax.service.ax_client: Generated new trial 18 with parameters {'x0': 1.0, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 1.0, 'x6': 0.0, 'x7': 1.0}.\n" 1323 | ] 1324 | }, 1325 | { 1326 | "output_type": "stream", 1327 | "name": "stderr", 1328 | "text": [ 1329 | "[INFO 07-11 18:28:06] ax.service.ax_client: Completed trial 18 with data: {'objective': (-10.960889, 0.0)}.\n" 1330 | ] 1331 | }, 1332 | { 1333 | "output_type": "stream", 1334 | "name": "stderr", 1335 | "text": [ 1336 | "[INFO 07-11 18:29:44] ax.service.ax_client: Generated new trial 19 with parameters {'x0': 0.770094, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1337 | ] 1338 | }, 1339 | { 1340 | "output_type": "stream", 1341 | "name": "stderr", 1342 | "text": [ 1343 | "[INFO 07-11 18:29:44] ax.service.ax_client: Completed trial 19 with data: {'objective': (-20.508312, 0.0)}.\n" 1344 | ] 1345 | }, 1346 | { 1347 | "output_type": "stream", 1348 | "name": "stderr", 1349 | "text": [ 1350 | "[INFO 07-11 18:31:30] ax.service.ax_client: Generated new trial 20 with parameters {'x0': 0.137802, 'x1': 0.779453, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1351 | ] 1352 | }, 1353 | { 1354 | "output_type": "stream", 1355 | "name": "stderr", 1356 | "text": [ 1357 | "[INFO 07-11 
18:31:30] ax.service.ax_client: Completed trial 20 with data: {'objective': (-0.613746, 0.0)}.\n" 1358 | ] 1359 | }, 1360 | { 1361 | "output_type": "stream", 1362 | "name": "stderr", 1363 | "text": [ 1364 | "[INFO 07-11 18:33:09] ax.service.ax_client: Generated new trial 21 with parameters {'x0': 0.536321, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1365 | ] 1366 | }, 1367 | { 1368 | "output_type": "stream", 1369 | "name": "stderr", 1370 | "text": [ 1371 | "[INFO 07-11 18:33:09] ax.service.ax_client: Completed trial 21 with data: {'objective': (-5.973257, 0.0)}.\n" 1372 | ] 1373 | }, 1374 | { 1375 | "output_type": "stream", 1376 | "name": "stderr", 1377 | "text": [ 1378 | "[INFO 07-11 18:34:47] ax.service.ax_client: Generated new trial 22 with parameters {'x0': 0.503722, 'x1': 0.219186, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1379 | ] 1380 | }, 1381 | { 1382 | "output_type": "stream", 1383 | "name": "stderr", 1384 | "text": [ 1385 | "[INFO 07-11 18:34:47] ax.service.ax_client: Completed trial 22 with data: {'objective': (-2.260464, 0.0)}.\n" 1386 | ] 1387 | }, 1388 | { 1389 | "output_type": "stream", 1390 | "name": "stderr", 1391 | "text": [ 1392 | "[INFO 07-11 18:36:41] ax.service.ax_client: Generated new trial 23 with parameters {'x0': 1.0, 'x1': 0.281918, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1393 | ] 1394 | }, 1395 | { 1396 | "output_type": "stream", 1397 | "name": "stderr", 1398 | "text": [ 1399 | "[INFO 07-11 18:36:41] ax.service.ax_client: Completed trial 23 with data: {'objective': (-3.445743, 0.0)}.\n" 1400 | ] 1401 | }, 1402 | { 1403 | "output_type": "stream", 1404 | "name": "stderr", 1405 | "text": [ 1406 | "[INFO 07-11 18:38:29] ax.service.ax_client: Generated new trial 24 with parameters {'x0': 0.549118, 'x1': 0.133697, 'x2': 0.0, 'x3': 1.0, 'x4': 0.0, 'x5': 1.0, 'x6': 0.0, 'x7': 0.0}.\n" 1407 | ] 1408 | }, 1409 | { 1410 | "output_type": "stream", 1411 | 
"name": "stderr", 1412 | "text": [ 1413 | "[INFO 07-11 18:38:29] ax.service.ax_client: Completed trial 24 with data: {'objective': (-0.479951, 0.0)}.\n" 1414 | ] 1415 | }, 1416 | { 1417 | "output_type": "stream", 1418 | "name": "stderr", 1419 | "text": [ 1420 | "[INFO 07-11 18:40:18] ax.service.ax_client: Generated new trial 25 with parameters {'x0': 0.080214, 'x1': 1.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1421 | ] 1422 | }, 1423 | { 1424 | "output_type": "stream", 1425 | "name": "stderr", 1426 | "text": [ 1427 | "[INFO 07-11 18:40:19] ax.service.ax_client: Completed trial 25 with data: {'objective': (-3.585129, 0.0)}.\n" 1428 | ] 1429 | }, 1430 | { 1431 | "output_type": "stream", 1432 | "name": "stderr", 1433 | "text": [ 1434 | "[INFO 07-11 18:42:08] ax.service.ax_client: Generated new trial 26 with parameters {'x0': 1.0, 'x1': 1.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1435 | ] 1436 | }, 1437 | { 1438 | "output_type": "stream", 1439 | "name": "stderr", 1440 | "text": [ 1441 | "[INFO 07-11 18:42:08] ax.service.ax_client: Completed trial 26 with data: {'objective': (-145.872191, 0.0)}.\n" 1442 | ] 1443 | }, 1444 | { 1445 | "output_type": "stream", 1446 | "name": "stderr", 1447 | "text": [ 1448 | "[INFO 07-11 18:44:13] ax.service.ax_client: Generated new trial 27 with parameters {'x0': 0.542029, 'x1': 0.136864, 'x2': 0.0, 'x3': 0.0, 'x4': 1.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1449 | ] 1450 | }, 1451 | { 1452 | "output_type": "stream", 1453 | "name": "stderr", 1454 | "text": [ 1455 | "[INFO 07-11 18:44:14] ax.service.ax_client: Completed trial 27 with data: {'objective': (-0.451738, 0.0)}.\n" 1456 | ] 1457 | }, 1458 | { 1459 | "output_type": "stream", 1460 | "name": "stderr", 1461 | "text": [ 1462 | "[INFO 07-11 18:46:25] ax.service.ax_client: Generated new trial 28 with parameters {'x0': 0.117749, 'x1': 0.847684, 'x2': 0.0, 'x3': 0.0, 'x4': 1.0, 'x5': 1.0, 'x6': 0.0, 'x7': 0.0}.\n" 1463 | ] 1464 | 
}, 1465 | { 1466 | "output_type": "stream", 1467 | "name": "stderr", 1468 | "text": [ 1469 | "[INFO 07-11 18:46:25] ax.service.ax_client: Completed trial 28 with data: {'objective': (-0.486016, 0.0)}.\n" 1470 | ] 1471 | }, 1472 | { 1473 | "output_type": "stream", 1474 | "name": "stderr", 1475 | "text": [ 1476 | "[INFO 07-11 18:48:39] ax.service.ax_client: Generated new trial 29 with parameters {'x0': 0.122207, 'x1': 0.831379, 'x2': 1.0, 'x3': 1.0, 'x4': 1.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1477 | ] 1478 | }, 1479 | { 1480 | "output_type": "stream", 1481 | "name": "stderr", 1482 | "text": [ 1483 | "[INFO 07-11 18:48:39] ax.service.ax_client: Completed trial 29 with data: {'objective': (-0.41913, 0.0)}.\n" 1484 | ] 1485 | }, 1486 | { 1487 | "output_type": "stream", 1488 | "name": "stderr", 1489 | "text": [ 1490 | "[INFO 07-11 18:51:07] ax.service.ax_client: Generated new trial 30 with parameters {'x0': 0.608958, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1491 | ] 1492 | }, 1493 | { 1494 | "output_type": "stream", 1495 | "name": "stderr", 1496 | "text": [ 1497 | "[INFO 07-11 18:51:07] ax.service.ax_client: Completed trial 30 with data: {'objective': (-7.404426, 0.0)}.\n" 1498 | ] 1499 | }, 1500 | { 1501 | "output_type": "stream", 1502 | "name": "stderr", 1503 | "text": [ 1504 | "[INFO 07-11 18:53:20] ax.service.ax_client: Generated new trial 31 with parameters {'x0': 0.532365, 'x1': 0.141486, 'x2': 0.0, 'x3': 1.0, 'x4': 0.0, 'x5': 0.0, 'x6': 1.0, 'x7': 1.0}.\n" 1505 | ] 1506 | }, 1507 | { 1508 | "output_type": "stream", 1509 | "name": "stderr", 1510 | "text": [ 1511 | "[INFO 07-11 18:53:20] ax.service.ax_client: Completed trial 31 with data: {'objective': (-0.591731, 0.0)}.\n" 1512 | ] 1513 | }, 1514 | { 1515 | "output_type": "stream", 1516 | "name": "stderr", 1517 | "text": [ 1518 | "[INFO 07-11 18:55:48] ax.service.ax_client: Generated new trial 32 with parameters {'x0': 0.950988, 'x1': 0.171879, 'x2': 0.0, 'x3': 0.0, 'x4': 
0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1519 | ] 1520 | }, 1521 | { 1522 | "output_type": "stream", 1523 | "name": "stderr", 1524 | "text": [ 1525 | "[INFO 07-11 18:55:48] ax.service.ax_client: Completed trial 32 with data: {'objective': (-0.575591, 0.0)}.\n" 1526 | ] 1527 | }, 1528 | { 1529 | "output_type": "stream", 1530 | "name": "stderr", 1531 | "text": [ 1532 | "[INFO 07-11 18:58:07] ax.service.ax_client: Generated new trial 33 with parameters {'x0': 0.973297, 'x1': 0.183923, 'x2': 1.0, 'x3': 0.0, 'x4': 1.0, 'x5': 1.0, 'x6': 0.0, 'x7': 0.0}.\n" 1533 | ] 1534 | }, 1535 | { 1536 | "output_type": "stream", 1537 | "name": "stderr", 1538 | "text": [ 1539 | "[INFO 07-11 18:58:07] ax.service.ax_client: Completed trial 33 with data: {'objective': (-0.561572, 0.0)}.\n" 1540 | ] 1541 | }, 1542 | { 1543 | "output_type": "stream", 1544 | "name": "stderr", 1545 | "text": [ 1546 | "[INFO 07-11 19:00:53] ax.service.ax_client: Generated new trial 34 with parameters {'x0': 0.972473, 'x1': 0.184526, 'x2': 0.0, 'x3': 1.0, 'x4': 0.0, 'x5': 0.0, 'x6': 1.0, 'x7': 0.0}.\n" 1547 | ] 1548 | }, 1549 | { 1550 | "output_type": "stream", 1551 | "name": "stderr", 1552 | "text": [ 1553 | "[INFO 07-11 19:00:53] ax.service.ax_client: Completed trial 34 with data: {'objective': (-0.547382, 0.0)}.\n" 1554 | ] 1555 | }, 1556 | { 1557 | "output_type": "stream", 1558 | "name": "stderr", 1559 | "text": [ 1560 | "[INFO 07-11 19:03:58] ax.service.ax_client: Generated new trial 35 with parameters {'x0': 0.543579, 'x1': 0.145004, 'x2': 1.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1561 | ] 1562 | }, 1563 | { 1564 | "output_type": "stream", 1565 | "name": "stderr", 1566 | "text": [ 1567 | "[INFO 07-11 19:03:58] ax.service.ax_client: Completed trial 35 with data: {'objective': (-0.406784, 0.0)}.\n" 1568 | ] 1569 | }, 1570 | { 1571 | "output_type": "stream", 1572 | "name": "stderr", 1573 | "text": [ 1574 | "[INFO 07-11 19:08:04] ax.service.ax_client: Generated new trial 36 with 
parameters {'x0': 0.56372, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1575 | ] 1576 | }, 1577 | { 1578 | "output_type": "stream", 1579 | "name": "stderr", 1580 | "text": [ 1581 | "[INFO 07-11 19:08:04] ax.service.ax_client: Completed trial 36 with data: {'objective': (-5.040681, 0.0)}.\n" 1582 | ] 1583 | }, 1584 | { 1585 | "output_type": "stream", 1586 | "name": "stderr", 1587 | "text": [ 1588 | "[INFO 07-11 19:10:39] ax.service.ax_client: Generated new trial 37 with parameters {'x0': 0.128424, 'x1': 0.814467, 'x2': 1.0, 'x3': 0.0, 'x4': 0.0, 'x5': 1.0, 'x6': 0.0, 'x7': 1.0}.\n" 1589 | ] 1590 | }, 1591 | { 1592 | "output_type": "stream", 1593 | "name": "stderr", 1594 | "text": [ 1595 | "[INFO 07-11 19:10:39] ax.service.ax_client: Completed trial 37 with data: {'objective': (-0.431012, 0.0)}.\n" 1596 | ] 1597 | }, 1598 | { 1599 | "output_type": "stream", 1600 | "name": "stderr", 1601 | "text": [ 1602 | "[INFO 07-11 19:13:21] ax.service.ax_client: Generated new trial 38 with parameters {'x0': 0.967249, 'x1': 0.189143, 'x2': 0.0, 'x3': 1.0, 'x4': 1.0, 'x5': 0.0, 'x6': 0.0, 'x7': 1.0}.\n" 1603 | ] 1604 | }, 1605 | { 1606 | "output_type": "stream", 1607 | "name": "stderr", 1608 | "text": [ 1609 | "[INFO 07-11 19:13:21] ax.service.ax_client: Completed trial 38 with data: {'objective': (-0.516046, 0.0)}.\n" 1610 | ] 1611 | }, 1612 | { 1613 | "output_type": "stream", 1614 | "name": "stderr", 1615 | "text": [ 1616 | "[INFO 07-11 19:17:03] ax.service.ax_client: Generated new trial 39 with parameters {'x0': 0.563272, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1617 | ] 1618 | }, 1619 | { 1620 | "output_type": "stream", 1621 | "name": "stderr", 1622 | "text": [ 1623 | "[INFO 07-11 19:17:03] ax.service.ax_client: Completed trial 39 with data: {'objective': (-5.040172, 0.0)}.\n" 1624 | ] 1625 | }, 1626 | { 1627 | "output_type": "stream", 1628 | "name": "stderr", 1629 | "text": [ 1630 | "[INFO 07-11 
19:20:43] ax.service.ax_client: Generated new trial 40 with parameters {'x0': 0.111004, 'x1': 0.841851, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1631 | ] 1632 | }, 1633 | { 1634 | "output_type": "stream", 1635 | "name": "stderr", 1636 | "text": [ 1637 | "[INFO 07-11 19:20:43] ax.service.ax_client: Completed trial 40 with data: {'objective': (-0.590424, 0.0)}.\n" 1638 | ] 1639 | }, 1640 | { 1641 | "output_type": "stream", 1642 | "name": "stderr", 1643 | "text": [ 1644 | "[INFO 07-11 19:24:20] ax.service.ax_client: Generated new trial 41 with parameters {'x0': 0.563578, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1645 | ] 1646 | }, 1647 | { 1648 | "output_type": "stream", 1649 | "name": "stderr", 1650 | "text": [ 1651 | "[INFO 07-11 19:24:20] ax.service.ax_client: Completed trial 41 with data: {'objective': (-5.040465, 0.0)}.\n" 1652 | ] 1653 | }, 1654 | { 1655 | "output_type": "stream", 1656 | "name": "stderr", 1657 | "text": [ 1658 | "[INFO 07-11 19:27:24] ax.service.ax_client: Generated new trial 42 with parameters {'x0': 1.0, 'x1': 0.173494, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1659 | ] 1660 | }, 1661 | { 1662 | "output_type": "stream", 1663 | "name": "stderr", 1664 | "text": [ 1665 | "[INFO 07-11 19:27:24] ax.service.ax_client: Completed trial 42 with data: {'objective': (-2.103578, 0.0)}.\n" 1666 | ] 1667 | }, 1668 | { 1669 | "output_type": "stream", 1670 | "name": "stderr", 1671 | "text": [ 1672 | "[INFO 07-11 19:30:34] ax.service.ax_client: Generated new trial 43 with parameters {'x0': 0.563448, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1673 | ] 1674 | }, 1675 | { 1676 | "output_type": "stream", 1677 | "name": "stderr", 1678 | "text": [ 1679 | "[INFO 07-11 19:30:34] ax.service.ax_client: Completed trial 43 with data: {'objective': (-5.040312, 0.0)}.\n" 1680 | ] 1681 | }, 1682 | { 1683 | "output_type": "stream", 1684 | 
"name": "stderr", 1685 | "text": [ 1686 | "[INFO 07-11 19:34:21] ax.service.ax_client: Generated new trial 44 with parameters {'x0': 0.563267, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1687 | ] 1688 | }, 1689 | { 1690 | "output_type": "stream", 1691 | "name": "stderr", 1692 | "text": [ 1693 | "[INFO 07-11 19:34:21] ax.service.ax_client: Completed trial 44 with data: {'objective': (-5.04017, 0.0)}.\n" 1694 | ] 1695 | }, 1696 | { 1697 | "output_type": "stream", 1698 | "name": "stderr", 1699 | "text": [ 1700 | "[INFO 07-11 19:38:27] ax.service.ax_client: Generated new trial 45 with parameters {'x0': 0.563496, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1701 | ] 1702 | }, 1703 | { 1704 | "output_type": "stream", 1705 | "name": "stderr", 1706 | "text": [ 1707 | "[INFO 07-11 19:38:27] ax.service.ax_client: Completed trial 45 with data: {'objective': (-5.040364, 0.0)}.\n" 1708 | ] 1709 | }, 1710 | { 1711 | "output_type": "stream", 1712 | "name": "stderr", 1713 | "text": [ 1714 | "[INFO 07-11 19:41:52] ax.service.ax_client: Generated new trial 46 with parameters {'x0': 0.563076, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1715 | ] 1716 | }, 1717 | { 1718 | "output_type": "stream", 1719 | "name": "stderr", 1720 | "text": [ 1721 | "[INFO 07-11 19:41:52] ax.service.ax_client: Completed trial 46 with data: {'objective': (-5.040109, 0.0)}.\n" 1722 | ] 1723 | }, 1724 | { 1725 | "output_type": "stream", 1726 | "name": "stderr", 1727 | "text": [ 1728 | "[INFO 07-11 19:45:10] ax.service.ax_client: Generated new trial 47 with parameters {'x0': 0.563165, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1729 | ] 1730 | }, 1731 | { 1732 | "output_type": "stream", 1733 | "name": "stderr", 1734 | "text": [ 1735 | "[INFO 07-11 19:45:10] ax.service.ax_client: Completed trial 47 with data: {'objective': (-5.040126, 0.0)}.\n" 1736 | ] 1737 | }, 1738 
| { 1739 | "output_type": "stream", 1740 | "name": "stderr", 1741 | "text": [ 1742 | "[INFO 07-11 19:48:48] ax.service.ax_client: Generated new trial 48 with parameters {'x0': 0.562984, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1743 | ] 1744 | }, 1745 | { 1746 | "output_type": "stream", 1747 | "name": "stderr", 1748 | "text": [ 1749 | "[INFO 07-11 19:48:48] ax.service.ax_client: Completed trial 48 with data: {'objective': (-5.040112, 0.0)}.\n" 1750 | ] 1751 | }, 1752 | { 1753 | "output_type": "stream", 1754 | "name": "stderr", 1755 | "text": [ 1756 | "[INFO 07-11 19:53:17] ax.service.ax_client: Generated new trial 49 with parameters {'x0': 0.563213, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x4': 0.0, 'x5': 0.0, 'x6': 0.0, 'x7': 0.0}.\n" 1757 | ] 1758 | }, 1759 | { 1760 | "output_type": "stream", 1761 | "name": "stderr", 1762 | "text": [ 1763 | "[INFO 07-11 19:53:17] ax.service.ax_client: Completed trial 49 with data: {'objective': (-5.040143, 0.0)}.\n" 1764 | ] 1765 | } 1766 | ] 1767 | } 1768 | ] 1769 | } 1770 | --------------------------------------------------------------------------------