├── .gitignore ├── LICENSE ├── README.md ├── azure_tools ├── __init__.py ├── azurepool.py ├── bo_wrapper.py └── common │ ├── __init__.py │ └── helpers.py ├── configs ├── configuration.cfg ├── ebo.cfg ├── py_files └── start_commands ├── ebo_core ├── __init__.py ├── bo.py ├── ebo.py ├── gibbs.py ├── helper.py ├── mondrian.py └── mypool.py ├── gp_tools ├── __init__.py ├── gp.py ├── representation.py └── test_gp_solvers.py ├── test_ebo.py └── test_functions ├── __init__.py ├── push_function.py ├── push_utils.py ├── rover_function.py ├── rover_utils.py └── simple_functions.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Zi Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ensemble-Bayesian-Optimization 2 | This is the code repository associated with the paper [_Batched Large-scale Bayesian Optimization in High-dimensional Spaces_](https://arxiv.org/pdf/1706.01445.pdf). We propose a new batch/distributed Bayesian optimization technique called **Ensemble Bayesian Optimization**, which scales up Bayesian optimization to an unprecedented number of input dimensions and observations. Please refer to the paper for more details on the algorithm. 3 | 4 | ## Requirements 5 | We tested our code with Python 2.7 on Ubuntu 14.04 LTS (64-bit). 6 | 7 | See configs/start_commands for the required packages. 8 | 9 | ## Implementations of Gaussian processes 10 | We implemented 4 versions of Gaussian processes in gp_tools/gp.py, which can be used independently of the BO functionality. 11 | 12 | * DenseKernelGP: a GP which has a dense kernel matrix. 13 | * SparseKernelGP: a GP which has a sparse kernel matrix. 14 | * SparseFeatureGP: a GP whose kernel is defined by the inner product of two sparse feature vectors. 15 | * DenseFeatureGP: a GP whose kernel is defined by the inner product of two dense feature vectors. 16 | 17 | ## Example 18 | test_ebo.py gives an example of running EBO on a 2-dimensional function with visualizations. 19 | 20 | To run EBO on expensive functions using Microsoft Azure, set the account information in configuration.cfg and the desired pool information in ebo.cfg. Then, in the options, set "useAzure" to True and "func_cheap" to False; illustrative sketches of the options and of the Azure helpers are given at the end of this README. 21 | 22 | ## Test functions 23 | We provide 3 examples of black-box functions: 24 | 25 | * test_functions/simple_functions.py: functions sampled from a GP. 26 | * test_functions/push_function.py: a reward function for two robots pushing two objects. 27 | * test_functions/rover_function.py: a reward function for the trajectory of a 2D rover. 28 | 29 | ## Caveats on the hyperparameters of EBO 30 | In more extensive experiments we found that EBO is not robust to the hyperparameters of the Mondrian trees, including the size of each leaf (min_leaf_size), the number of leaves (max_n_leaves), and the number of selections per partition (n_bo). Principled ways of setting these parameters remain future work.
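The caveats above refer to several options by name. The following sketch is illustrative only: the option names (useAzure, func_cheap, min_leaf_size, max_n_leaves, n_bo) are taken from this README, but the values are placeholders and the exact way test_ebo.py wires them into the solver is not shown in this section.

```python
# Illustrative only: option names come from this README; the values are placeholders,
# and how test_ebo.py actually passes them to EBO is not shown here.
options = {
    'useAzure': False,     # True: evaluate the objective remotely on an Azure Batch pool
    'func_cheap': True,    # False: the objective is expensive, so use the remote path
    'min_leaf_size': 10,   # Mondrian trees: size of each leaf (EBO is sensitive to this)
    'max_n_leaves': 50,    # Mondrian trees: number of leaves
    'n_bo': 5,             # selections per partition
}
```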
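For the Azure path, azure_tools/azurepool.py (included later in this dump) wraps pool creation, task submission, and result collection. Below is a minimal life-cycle sketch of that class, assuming configs/configuration.cfg and configs/ebo.cfg are filled in and that azure_tools is importable as a package; the pool id, job id, and data directory are placeholder names. In normal use the EBO core drives this object when "useAzure" is set, and each element of the parameter list is an argument tuple that bo_wrapper.py unpacks into ebo_core.bo.bo (constructing those tuples is not shown here).

```python
# Minimal life-cycle sketch for azure_tools/azurepool.py. Assumes the Azure account
# details in configs/configuration.cfg and the pool settings in configs/ebo.cfg are set.
from azure_tools.azurepool import AzurePool

pool = AzurePool('ebo-pool', 'data/')   # creates blob containers and the Batch pool
pool.install()                          # runs configs/start_commands on every node

# Each element of params must be a tuple of positional arguments for ebo_core.bo.bo;
# bo_wrapper.py unpickles it on a node and calls bo(*parameter).run().
params = []                             # building these tuples is not shown in this section
results = pool.map(params, 'ebo-job')   # one Batch task per tuple; returns the unpickled outputs

pool.end()                              # deletes the blob containers and the pool
```

If some nodes fail their start task, reboot_failed_nodes() can be called on the pool before submitting work.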
31 | -------------------------------------------------------------------------------- /azure_tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zi-w/Ensemble-Bayesian-Optimization/4e6f9ed04833cc2e21b5906b1181bc067298f914/azure_tools/__init__.py -------------------------------------------------------------------------------- /azure_tools/azurepool.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | try: 4 | import configparser 5 | except ImportError: 6 | import ConfigParser as configparser 7 | import datetime 8 | import os 9 | 10 | import azure.storage.blob as azureblob 11 | import azure.batch.batch_service_client as batch 12 | import azure.batch.batch_auth as batchauth 13 | import azure.batch.models as batchmodels 14 | 15 | import common.helpers 16 | import time 17 | 18 | try: 19 | import cPickle as pickle 20 | except: 21 | import pickle 22 | import sys 23 | 24 | sys.path.append('.') 25 | 26 | import logging 27 | 28 | _TASK_FILE = 'bo_wrapper.py' 29 | 30 | 31 | class AzurePool(object): 32 | def __init__(self, pool_id, data_dir): 33 | self.pool_id = pool_id 34 | if data_dir[-1] != '/': 35 | data_dir += '/' 36 | self.data_dir = data_dir 37 | self.app = common.helpers.generate_unique_resource_name('app') 38 | self.inp = common.helpers.generate_unique_resource_name('inp') 39 | self.out = common.helpers.generate_unique_resource_name('out') 40 | global_config = configparser.ConfigParser() 41 | global_config.read('configs/configuration.cfg') 42 | 43 | our_config = configparser.ConfigParser() 44 | our_config.read('configs/ebo.cfg') 45 | 46 | batch_account_key = global_config.get('Batch', 'batchaccountkey') 47 | batch_account_name = global_config.get('Batch', 'batchaccountname') 48 | batch_service_url = global_config.get('Batch', 'batchserviceurl') 49 | 50 | storage_account_key = global_config.get('Storage', 'storageaccountkey') 51 | storage_account_name = global_config.get('Storage', 'storageaccountname') 52 | storage_account_suffix = global_config.get( 53 | 'Storage', 54 | 'storageaccountsuffix') 55 | 56 | pool_vm_size = our_config.get( 57 | 'DEFAULT', 58 | 'poolvmsize') 59 | pool_vm_count = our_config.getint( 60 | 'DEFAULT', 61 | 'poolvmcount') 62 | 63 | # remember: no space, file names split by ',' 64 | # app_file_names = our_config.get('APP', 'app').split(',') 65 | # app_file_names = [os.path.realpath(fn) for fn in app_file_names] 66 | # Print the settings we are running with 67 | common.helpers.print_configuration(global_config) 68 | common.helpers.print_configuration(our_config) 69 | 70 | credentials = batchauth.SharedKeyCredentials( 71 | batch_account_name, 72 | batch_account_key) 73 | batch_client = batch.BatchServiceClient( 74 | credentials, 75 | base_url=batch_service_url) 76 | 77 | # Retry 5 times -- default is 3 78 | batch_client.config.retry_policy.retries = 5 79 | 80 | self.storage_account_name = storage_account_name 81 | 82 | block_blob_client = azureblob.BlockBlobService( 83 | account_name=storage_account_name, 84 | account_key=storage_account_key, 85 | endpoint_suffix=storage_account_suffix) 86 | 87 | # create containers 88 | 89 | block_blob_client.create_container(self.app, fail_on_exist=False) 90 | block_blob_client.create_container(self.inp, fail_on_exist=False) 91 | block_blob_client.create_container(self.out, fail_on_exist=False) 92 | # app_files = upload_files(block_blob_client, self.app, app_file_names) 93 | 94 
| output_container_sas_token = get_container_sas_token( 95 | block_blob_client, 96 | self.out, 97 | azureblob.BlobPermissions.WRITE) 98 | 99 | self.out_sas_token = output_container_sas_token 100 | 101 | create_pool(batch_client, pool_id, pool_vm_size, pool_vm_count, None) 102 | 103 | # create necessary folders 104 | if not os.path.exists(data_dir): 105 | os.makedirs(data_dir) 106 | 107 | self.batch_client = batch_client 108 | self.block_blob_client = block_blob_client 109 | 110 | def install(self): 111 | pool_id = self.pool_id 112 | batch_client = self.batch_client 113 | block_blob_client = self.block_blob_client 114 | job_id = common.helpers.generate_unique_resource_name( 115 | 'install') 116 | run_commands(batch_client, block_blob_client, job_id, pool_id) 117 | common.helpers.wait_for_tasks_to_complete( 118 | batch_client, 119 | job_id, 120 | datetime.timedelta(minutes=25)) 121 | 122 | tasks = batch_client.task.list(job_id) 123 | task_ids = [task.id for task in tasks] 124 | 125 | common.helpers.print_task_output(batch_client, job_id, task_ids) 126 | 127 | def end(self): 128 | self.delete_containers() 129 | self.delpool() 130 | 131 | def delpool(self): 132 | print("Deleting pool: ", self.pool_id) 133 | self.batch_client.pool.delete(self.pool_id) 134 | 135 | def delete_containers(self): 136 | # pool_id, batch_client, block_blob_client 137 | # clean up 138 | self.block_blob_client.delete_container( 139 | self.inp, 140 | fail_not_exist=False) 141 | self.block_blob_client.delete_container( 142 | self.out, 143 | fail_not_exist=False) 144 | self.block_blob_client.delete_container( 145 | self.app, 146 | fail_not_exist=False) 147 | 148 | def reboot_failed_nodes(self): 149 | nodes = list(self.batch_client.compute_node.list(self.pool_id)) 150 | failed = [n.id for n in nodes if n.state == batchmodels.ComputeNodeState.start_task_failed] 151 | 152 | for node in failed: 153 | self.batch_client.compute_node.reboot(self.pool_id, node) 154 | 155 | def reboot(self): 156 | nodes = list(self.batch_client.compute_node.list(self.pool_id)) 157 | errored = [n.id for n in nodes if n.state == batchmodels.ComputeNodeState.unusable] 158 | working_nodes = [n.id for n in nodes if n not in errored] 159 | 160 | for node in working_nodes: 161 | self.batch_client.compute_node.reboot(self.pool_id, node) 162 | 163 | def map(self, parameters, job_id): 164 | # write parameters to files 165 | logging.info('In AzurePool map, job id [' + job_id + ']') 166 | batch_client = self.batch_client 167 | block_blob_client = self.block_blob_client 168 | job_id = common.helpers.generate_unique_resource_name( 169 | job_id) 170 | common.helpers.delete_blobs_from_container(block_blob_client, self.out) 171 | input_file_names = [os.path.join(self.data_dir, str(i) + '.pk') for i in xrange(len(parameters))] 172 | for i, p in enumerate(parameters): 173 | pickle.dump(p, open(input_file_names[i], 'wb')) 174 | # input_file_names = [os.path.realpath(fn) for fn in input_file_names] 175 | 176 | in_files = upload_files(block_blob_client, self.inp, input_file_names) 177 | 178 | # get app files again 179 | # remember: no blank line, one file each line 180 | app_file_names = get_list_from_file('configs/py_files') 181 | # app_file_names = [os.path.realpath(fn) for fn in app_file_names] 182 | app_files = upload_files(block_blob_client, self.app, app_file_names) 183 | 184 | submit_job_and_add_tasks(batch_client, block_blob_client, job_id, self.pool_id, in_files, self.out, app_files, 185 | self.storage_account_name, self.out_sas_token) 186 | 187 | 
common.helpers.wait_for_tasks_to_complete( 188 | batch_client, 189 | job_id, 190 | datetime.timedelta(minutes=20)) 191 | 192 | # GET outputs 193 | common.helpers.download_blobs_from_container(block_blob_client, self.out, './') 194 | 195 | # print(os.path.join(self.data_dir, str(0) + '_out.pk')) 196 | ret = [] 197 | for i in xrange(len(parameters)): 198 | fnm = os.path.join(self.data_dir, str(i) + '_out.pk') 199 | if os.path.isfile(fnm): 200 | ret.append(pickle.load(open(fnm))) 201 | else: 202 | logging.warning('In AzurePool map, job id [' + job_id + '], ignoring lost parameter ' + str(i)) 203 | if len(ret) == 0: 204 | try: 205 | tasks = batch_client.task.list(job_id) 206 | task_ids = [task.id for task in tasks] 207 | common.helpers.print_task_output(batch_client, job_id, task_ids) 208 | assert 0 == 1, 'No return from azure' 209 | except Exception as e: 210 | print('No return from azure and pring task output failed.') 211 | logging.error(e) 212 | raise e 213 | 214 | batch_client.job.delete(job_id) 215 | return ret 216 | 217 | 218 | def upload_files(block_blob_client, container_name, files): 219 | return [get_resource_file(block_blob_client, container_name, \ 220 | file_path, os.path.realpath(file_path)) for file_path in files] 221 | 222 | 223 | def get_resource_file(block_blob_client, container_name, blob_name, file_path): 224 | sas_url = common.helpers.upload_blob_and_create_sas( 225 | block_blob_client, 226 | container_name, 227 | blob_name, 228 | file_path, 229 | datetime.datetime.utcnow() + datetime.timedelta(hours=1)) 230 | logging.info('Uploading file {} from {} to container [{}]...'.format(blob_name, file_path, container_name)) 231 | return batchmodels.ResourceFile(file_path=blob_name, 232 | blob_source=sas_url) 233 | 234 | 235 | def get_list_from_file(file_nm): 236 | """ 237 | Obtains the list of lines from a file. 238 | :param str file_nm: The name of the file. 239 | :rtype: list 240 | :return: A list of the striped lines. 241 | """ 242 | with open(file_nm) as f: 243 | content = f.readlines() 244 | return [x.strip() for x in content] 245 | 246 | 247 | def create_pool(batch_client, pool_id, vm_size, vm_count, app_files): 248 | """Creates an Azure Batch pool with the specified id. 249 | 250 | :param batch_client: The batch client to use. 251 | :type batch_client: `batchserviceclient.BatchServiceClient` 252 | :param block_blob_client: The storage block blob client to use. 253 | :type block_blob_client: `azure.storage.blob.BlockBlobService` 254 | :param str pool_id: The id of the pool to create. 255 | :param str vm_size: vm size (sku) 256 | :param int vm_count: number of vms to allocate 257 | :param list app_files: The list of all the other scripts to upload. 
258 | """ 259 | # pick the latest supported 16.04 sku for UbuntuServer 260 | sku_to_use, image_ref_to_use = \ 261 | common.helpers.select_latest_verified_vm_image_with_node_agent_sku( 262 | batch_client, 'Canonical', 'UbuntuServer', '14.04') 263 | user = batchmodels.AutoUserSpecification( 264 | scope=batchmodels.AutoUserScope.pool, 265 | elevation_level=batchmodels.ElevationLevel.admin) 266 | task_commands = get_list_from_file('configs/start_commands') 267 | print(task_commands) 268 | pool = batchmodels.PoolAddParameter( 269 | id=pool_id, 270 | virtual_machine_configuration=batchmodels.VirtualMachineConfiguration( 271 | image_reference=image_ref_to_use, 272 | node_agent_sku_id=sku_to_use), 273 | vm_size=vm_size, 274 | target_dedicated=vm_count, 275 | start_task=batchmodels.StartTask( 276 | command_line=common.helpers.wrap_commands_in_shell('linux', task_commands), 277 | user_identity=batchmodels.UserIdentity(auto_user=user), 278 | resource_files=app_files, 279 | wait_for_success=True)) 280 | 281 | common.helpers.create_pool_if_not_exist(batch_client, pool) 282 | 283 | 284 | def run_commands(batch_client, block_blob_client, job_id, pool_id): 285 | """Run the start commands listed in the file "start_commands" on 286 | all the nodes of the Azure Batch service. 287 | 288 | :param batch_client: The batch client to use. 289 | :type batch_client: `batchserviceclient.BatchServiceClient` 290 | :param block_blob_client: The storage block blob client to use. 291 | :type block_blob_client: `azure.storage.blob.BlockBlobService` 292 | :param str job_id: The id of the job to create. 293 | :param str pool_id: The id of the pool to use. 294 | """ 295 | task_commands = get_list_from_file('configs/start_commands') 296 | logging.info(task_commands) 297 | user = batchmodels.AutoUserSpecification( 298 | scope=batchmodels.AutoUserScope.pool, 299 | elevation_level=batchmodels.ElevationLevel.admin) 300 | 301 | start = time.time() 302 | job = batchmodels.JobAddParameter( 303 | id=job_id, 304 | pool_info=batchmodels.PoolInformation(pool_id=pool_id)) 305 | 306 | batch_client.job.add(job) 307 | logging.info('job created in seconds {}'.format(time.time() - start)) 308 | 309 | start = time.time() 310 | nodes = list(batch_client.compute_node.list(pool_id)) 311 | tasks = [batchmodels.TaskAddParameter( 312 | id="EBOTask-{}".format(i), 313 | command_line=common.helpers.wrap_commands_in_shell('linux', task_commands), 314 | user_identity=batchmodels.UserIdentity(auto_user=user)) \ 315 | for i in xrange(len(nodes))] 316 | 317 | batch_client.task.add_collection(job.id, tasks) 318 | logging.info('task created in seconds {}'.format(time.time() - start)) 319 | 320 | 321 | def submit_job_and_add_tasks(batch_client, block_blob_client, job_id, pool_id, in_files, out_container_name, app_files, 322 | storage_account_name, out_sas_token): 323 | """Submits jobs to the Azure Batch service and adds 324 | tasks that runs a python script. 325 | 326 | :param batch_client: The batch client to use. 327 | :type batch_client: `batchserviceclient.BatchServiceClient` 328 | :param block_blob_client: The storage block blob client to use. 329 | :type block_blob_client: `azure.storage.blob.BlockBlobService` 330 | :param str job_id: The id of the job to create. 331 | :param str pool_id: The id of the pool to use. 332 | :param list in_files: The list of the file paths of the inputs. 333 | :param str out_container_name: The name of the output container. 334 | :param list app_files: The list of all the other scripts to upload. 
335 | :param str storage_account_name: The name of the storage account. 336 | :param str out_sas_token: A SAS token granting the specified 337 | permissions to the output container. 338 | """ 339 | start = time.time() 340 | job = batchmodels.JobAddParameter( 341 | id=job_id, 342 | pool_info=batchmodels.PoolInformation(pool_id=pool_id)) 343 | 344 | batch_client.job.add(job) 345 | logging.info('job created in seconds {}'.format(time.time() - start)) 346 | 347 | start = time.time() 348 | 349 | tasks = [batchmodels.TaskAddParameter( 350 | id="EBOTask-{}".format(i), 351 | command_line='python {} --filepath {} --storageaccount {} --storagecontainer {} --sastoken "{}"'.format( 352 | _TASK_FILE, 353 | in_file.file_path, 354 | storage_account_name, 355 | out_container_name, 356 | out_sas_token), 357 | resource_files=[in_file] + app_files) \ 358 | for i, in_file in enumerate(in_files)] 359 | 360 | cnt = 0 361 | tot_tasks = len(tasks) 362 | while cnt < tot_tasks: 363 | try: 364 | batch_client.task.add_collection(job.id, tasks[cnt:cnt + 100]) 365 | cnt += 100 366 | except Exception as e: 367 | print("Adding task failed... Going to try again in 5 seconds") 368 | logging.error(e) 369 | time.sleep(5) 370 | logging.info('task created in seconds {}'.format(time.time() - start)) 371 | 372 | 373 | def get_container_sas_token(block_blob_client, 374 | container_name, blob_permissions): 375 | """ 376 | Obtains a shared access signature granting the specified permissions to the 377 | container. 378 | :param block_blob_client: A blob service client. 379 | :type block_blob_client: `azure.storage.blob.BlockBlobService` 380 | :param str container_name: The name of the Azure Blob storage container. 381 | :param BlobPermissions blob_permissions: 382 | :rtype: str 383 | :return: A SAS token granting the specified permissions to the container. 384 | """ 385 | # Obtain the SAS token for the container, setting the expiry time and 386 | # permissions. In this case, no start time is specified, so the shared 387 | # access signature becomes valid immediately. 388 | container_sas_token = \ 389 | block_blob_client.generate_container_shared_access_signature( 390 | container_name, 391 | permission=blob_permissions, 392 | expiry=datetime.datetime.utcnow() + datetime.timedelta(hours=2)) 393 | return container_sas_token 394 | -------------------------------------------------------------------------------- /azure_tools/bo_wrapper.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | 5 | import azure.storage.blob as azureblob 6 | try: 7 | import cPickle as pickle 8 | except: 9 | import pickle 10 | 11 | from ebo_core.bo import bo 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--filepath', required=True, 16 | help='The path to the text file to process. 
The path' 17 | 'may include a compute node\'s environment' 18 | 'variables, such as' 19 | '$AZ_BATCH_NODE_SHARED_DIR/filename.txt') 20 | parser.add_argument('--storageaccount', required=True, 21 | help='The name the Azure Storage account that owns the' 22 | 'blob storage container to which to upload' 23 | 'results.') 24 | parser.add_argument('--storagecontainer', required=True, 25 | help='The Azure Blob storage container to which to' 26 | 'upload results.') 27 | parser.add_argument('--sastoken', required=True, 28 | help='The SAS token providing write access to the' 29 | 'Storage container.') 30 | args = parser.parse_args() 31 | input_file = os.path.realpath(args.filepath) 32 | output_file = '{}_out{}'.format( 33 | os.path.splitext(args.filepath)[0], 34 | os.path.splitext(args.filepath)[1]) 35 | parameter = pickle.load(open(input_file)) 36 | #print(parameter) 37 | 38 | b = bo(*parameter) 39 | res = b.run() 40 | 41 | #print(res) 42 | pickle.dump(res, open(output_file, 'wb')) 43 | 44 | print("bo_wrapper.py listing files:") 45 | for item in os.listdir('.'): 46 | print(item) 47 | 48 | # Create the blob client using the container's SAS token. 49 | # This allows us to create a client that provides write 50 | # access only to the container. 51 | blob_client = azureblob.BlockBlobService(account_name=args.storageaccount, sas_token=args.sastoken) 52 | 53 | output_file_path = os.path.realpath(output_file) 54 | 55 | print('Uploading file {} to container [{}]...'.format( 56 | output_file_path, 57 | args.storagecontainer)) 58 | 59 | blob_client.create_blob_from_path(args.storagecontainer, 60 | output_file, 61 | output_file_path) 62 | -------------------------------------------------------------------------------- /azure_tools/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zi-w/Ensemble-Bayesian-Optimization/4e6f9ed04833cc2e21b5906b1181bc067298f914/azure_tools/common/__init__.py -------------------------------------------------------------------------------- /azure_tools/common/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation 2 | # 3 | # All rights reserved. 4 | # 5 | # MIT License 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a 8 | # copy of this software and associated documentation files (the "Software"), 9 | # to deal in the Software without restriction, including without limitation 10 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 11 | # and/or sell copies of the Software, and to permit persons to whom the 12 | # Software is furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in 15 | # all copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 22 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 23 | # DEALINGS IN THE SOFTWARE. 
24 | 25 | from __future__ import print_function 26 | import datetime 27 | import io 28 | import os 29 | import time 30 | 31 | import azure.storage.blob as azureblob 32 | import azure.batch.models as batchmodels 33 | import logging 34 | 35 | _STANDARD_OUT_FILE_NAME = 'stdout.txt' 36 | _STANDARD_ERROR_FILE_NAME = 'stderr.txt' 37 | _SAMPLES_CONFIG_FILE_NAME = 'configuration.cfg' 38 | 39 | 40 | class TimeoutError(Exception): 41 | """An error which can occur if a timeout has expired. 42 | """ 43 | def __init__(self, message): 44 | self.message = message 45 | 46 | 47 | def decode_string(string, encoding=None): 48 | """Decode a string with specified encoding 49 | 50 | :type string: str or bytes 51 | :param string: string to decode 52 | :param str encoding: encoding of string to decode 53 | :rtype: str 54 | :return: decoded string 55 | """ 56 | if isinstance(string, str): 57 | return string 58 | if encoding is None: 59 | encoding = 'utf-8' 60 | if isinstance(string, bytes): 61 | return string.decode(encoding) 62 | raise ValueError('invalid string type: {}'.format(type(string))) 63 | 64 | 65 | def select_latest_verified_vm_image_with_node_agent_sku( 66 | batch_client, publisher, offer, sku_starts_with): 67 | """Select the latest verified image that Azure Batch supports given 68 | a publisher, offer and sku (starts with filter). 69 | 70 | :param batch_client: The batch client to use. 71 | :type batch_client: `batchserviceclient.BatchServiceClient` 72 | :param str publisher: vm image publisher 73 | :param str offer: vm image offer 74 | :param str sku_starts_with: vm sku starts with filter 75 | :rtype: tuple 76 | :return: (node agent sku id to use, vm image ref to use) 77 | """ 78 | # get verified vm image list and node agent sku ids from service 79 | node_agent_skus = batch_client.account.list_node_agent_skus() 80 | # pick the latest supported sku 81 | skus_to_use = [ 82 | (sku, image_ref) for sku in node_agent_skus for image_ref in sorted( 83 | sku.verified_image_references, key=lambda item: item.sku) 84 | if image_ref.publisher.lower() == publisher.lower() and 85 | image_ref.offer.lower() == offer.lower() and 86 | image_ref.sku.startswith(sku_starts_with) 87 | ] 88 | # skus are listed in reverse order, pick first for latest 89 | sku_to_use, image_ref_to_use = skus_to_use[0] 90 | return (sku_to_use.id, image_ref_to_use) 91 | 92 | 93 | def wait_for_tasks_to_complete(batch_client, job_id, timeout): 94 | """Waits for all the tasks in a particular job to complete. 95 | 96 | :param batch_client: The batch client to use. 97 | :type batch_client: `batchserviceclient.BatchServiceClient` 98 | :param str job_id: The id of the job to monitor. 99 | :param timeout: The maximum amount of time to wait. 100 | :type timeout: `datetime.timedelta` 101 | """ 102 | time_to_timeout_at = datetime.datetime.now() + timeout 103 | while datetime.datetime.now() < time_to_timeout_at: 104 | print("Checking if all tasks are complete...") 105 | try: 106 | tasks = batch_client.task.list(job_id) 107 | incomplete_tasks = [task for task in tasks if 108 | task.state != batchmodels.TaskState.completed] 109 | if not incomplete_tasks: 110 | return 111 | except Exception as e: 112 | print("Checking failed...") 113 | logging.error(e) 114 | time.sleep(5) 115 | 116 | raise TimeoutError("Timed out waiting for tasks to complete") 117 | 118 | 119 | def print_task_output(batch_client, job_id, task_ids, encoding=None): 120 | """Prints the stdout and stderr for each task specified. 121 | 122 | :param batch_client: The batch client to use. 
123 | :type batch_client: `batchserviceclient.BatchServiceClient` 124 | :param str job_id: The id of the job to monitor. 125 | :param task_ids: The collection of tasks to print the output for. 126 | :type task_ids: `list` 127 | :param str encoding: The encoding to use when downloading the file. 128 | """ 129 | for task_id in task_ids: 130 | file_text = read_task_file_as_string( 131 | batch_client, 132 | job_id, 133 | task_id, 134 | _STANDARD_OUT_FILE_NAME, 135 | encoding) 136 | print("{} content for task {}: ".format( 137 | _STANDARD_OUT_FILE_NAME, 138 | task_id)) 139 | print(file_text) 140 | 141 | file_text = read_task_file_as_string( 142 | batch_client, 143 | job_id, 144 | task_id, 145 | _STANDARD_ERROR_FILE_NAME, 146 | encoding) 147 | print("{} content for task {}: ".format( 148 | _STANDARD_ERROR_FILE_NAME, 149 | task_id)) 150 | print(file_text) 151 | 152 | 153 | def print_configuration(config): 154 | """Prints the configuration being used as a dictionary 155 | 156 | :param config: The configuration. 157 | :type config: `configparser.ConfigParser` 158 | """ 159 | configuration_dict = {s: dict(config.items(s)) for s in 160 | config.sections() + ['DEFAULT']} 161 | 162 | print("Configuration is:") 163 | print(configuration_dict) 164 | 165 | 166 | def _read_stream_as_string(stream, encoding): 167 | """Read stream as string 168 | 169 | :param stream: input stream generator 170 | :param str encoding: The encoding of the file. The default is utf-8. 171 | :return: The file content. 172 | :rtype: str 173 | """ 174 | output = io.BytesIO() 175 | try: 176 | for data in stream: 177 | output.write(data) 178 | if encoding is None: 179 | encoding = 'utf-8' 180 | return output.getvalue().decode(encoding) 181 | finally: 182 | output.close() 183 | raise RuntimeError('could not write data to stream or decode bytes') 184 | 185 | 186 | def read_task_file_as_string( 187 | batch_client, job_id, task_id, file_name, encoding=None): 188 | """Reads the specified file as a string. 189 | 190 | :param batch_client: The batch client to use. 191 | :type batch_client: `batchserviceclient.BatchServiceClient` 192 | :param str job_id: The id of the job. 193 | :param str task_id: The id of the task. 194 | :param str file_name: The name of the file to read. 195 | :param str encoding: The encoding of the file. The default is utf-8. 196 | :return: The file content. 197 | :rtype: str 198 | """ 199 | stream = batch_client.file.get_from_task(job_id, task_id, file_name) 200 | return _read_stream_as_string(stream, encoding) 201 | 202 | 203 | def read_compute_node_file_as_string( 204 | batch_client, pool_id, node_id, file_name, encoding=None): 205 | """Reads the specified file as a string. 206 | 207 | :param batch_client: The batch client to use. 208 | :type batch_client: `batchserviceclient.BatchServiceClient` 209 | :param str pool_id: The id of the pool. 210 | :param str node_id: The id of the node. 211 | :param str file_name: The name of the file to read. 212 | :param str encoding: The encoding of the file. The default is utf-8 213 | :return: The file content. 214 | :rtype: str 215 | """ 216 | stream = batch_client.file.get_from_compute_node( 217 | pool_id, node_id, file_name) 218 | return _read_stream_as_string(stream, encoding) 219 | 220 | 221 | def create_pool_if_not_exist(batch_client, pool): 222 | """Creates the specified pool if it doesn't already exist 223 | 224 | :param batch_client: The batch client to use. 225 | :type batch_client: `batchserviceclient.BatchServiceClient` 226 | :param pool: The pool to create. 
227 | :type pool: `batchserviceclient.models.PoolAddParameter` 228 | """ 229 | try: 230 | print("Attempting to create pool:", pool.id) 231 | batch_client.pool.add(pool) 232 | print("Created pool:", pool.id) 233 | except batchmodels.BatchErrorException as e: 234 | if e.error.code != "PoolExists": 235 | raise 236 | else: 237 | print("Pool {!r} already exists".format(pool.id)) 238 | 239 | 240 | def create_job(batch_service_client, job_id, pool_id): 241 | """ 242 | Creates a job with the specified ID, associated with the specified pool. 243 | 244 | :param batch_service_client: A Batch service client. 245 | :type batch_service_client: `azure.batch.BatchServiceClient` 246 | :param str job_id: The ID for the job. 247 | :param str pool_id: The ID for the pool. 248 | """ 249 | print('Creating job [{}]...'.format(job_id)) 250 | 251 | job = batchmodels.JobAddParameter( 252 | job_id, 253 | batchmodels.PoolInformation(pool_id=pool_id)) 254 | 255 | try: 256 | batch_service_client.job.add(job) 257 | except batchmodels.batch_error.BatchErrorException as err: 258 | print_batch_exception(err) 259 | if err.error.code != "JobExists": 260 | raise 261 | else: 262 | print("Job {!r} already exists".format(job_id)) 263 | 264 | 265 | def wait_for_all_nodes_state(batch_client, pool, node_state): 266 | """Waits for all nodes in pool to reach any specified state in set 267 | 268 | :param batch_client: The batch client to use. 269 | :type batch_client: `batchserviceclient.BatchServiceClient` 270 | :param pool: The pool containing the node. 271 | :type pool: `batchserviceclient.models.CloudPool` 272 | :param set node_state: node states to wait for 273 | :rtype: list 274 | :return: list of `batchserviceclient.models.ComputeNode` 275 | """ 276 | print('waiting for all nodes in pool {} to reach one of: {!r}'.format( 277 | pool.id, node_state)) 278 | i = 0 279 | while True: 280 | # refresh pool to ensure that there is no resize error 281 | pool = batch_client.pool.get(pool.id) 282 | if pool.resize_error is not None: 283 | raise RuntimeError( 284 | 'resize error encountered for pool {}: {!r}'.format( 285 | pool.id, pool.resize_error)) 286 | nodes = list(batch_client.compute_node.list(pool.id)) 287 | if (len(nodes) >= pool.target_dedicated and 288 | all(node.state in node_state for node in nodes)): 289 | return nodes 290 | i += 1 291 | if i % 3 == 0: 292 | print('waiting for {} nodes to reach desired state...'.format( 293 | pool.target_dedicated)) 294 | time.sleep(10) 295 | 296 | 297 | def create_container_and_create_sas( 298 | block_blob_client, container_name, permission, expiry=None, 299 | timeout=None): 300 | """Create a blob sas token 301 | 302 | :param block_blob_client: The storage block blob client to use. 303 | :type block_blob_client: `azure.storage.blob.BlockBlobService` 304 | :param str container_name: The name of the container to upload the blob to. 305 | :param expiry: The SAS expiry time. 
306 | :type expiry: `datetime.datetime` 307 | :param int timeout: timeout in minutes from now for expiry, 308 | will only be used if expiry is not specified 309 | :return: A SAS token 310 | :rtype: str 311 | """ 312 | if expiry is None: 313 | if timeout is None: 314 | timeout = 30 315 | expiry = datetime.datetime.utcnow() + datetime.timedelta( 316 | minutes=timeout) 317 | 318 | block_blob_client.create_container( 319 | container_name, 320 | fail_on_exist=False) 321 | 322 | return block_blob_client.generate_container_shared_access_signature( 323 | container_name=container_name, permission=permission, expiry=expiry) 324 | 325 | 326 | def create_sas_token( 327 | block_blob_client, container_name, blob_name, permission, expiry=None, 328 | timeout=None): 329 | """Create a blob sas token 330 | 331 | :param block_blob_client: The storage block blob client to use. 332 | :type block_blob_client: `azure.storage.blob.BlockBlobService` 333 | :param str container_name: The name of the container to upload the blob to. 334 | :param str blob_name: The name of the blob to upload the local file to. 335 | :param expiry: The SAS expiry time. 336 | :type expiry: `datetime.datetime` 337 | :param int timeout: timeout in minutes from now for expiry, 338 | will only be used if expiry is not specified 339 | :return: A SAS token 340 | :rtype: str 341 | """ 342 | if expiry is None: 343 | if timeout is None: 344 | timeout = 30 345 | expiry = datetime.datetime.utcnow() + datetime.timedelta( 346 | minutes=timeout) 347 | return block_blob_client.generate_blob_shared_access_signature( 348 | container_name, blob_name, permission=permission, expiry=expiry) 349 | 350 | 351 | def upload_blob_and_create_sas( 352 | block_blob_client, container_name, blob_name, file_name, expiry, 353 | timeout=None): 354 | """Uploads a file from local disk to Azure Storage and creates 355 | a SAS for it. 356 | 357 | :param block_blob_client: The storage block blob client to use. 358 | :type block_blob_client: `azure.storage.blob.BlockBlobService` 359 | :param str container_name: The name of the container to upload the blob to. 360 | :param str blob_name: The name of the blob to upload the local file to. 361 | :param str file_name: The name of the local file to upload. 362 | :param expiry: The SAS expiry time. 363 | :type expiry: `datetime.datetime` 364 | :param int timeout: timeout in minutes from now for expiry, 365 | will only be used if expiry is not specified 366 | :return: A SAS URL to the blob with the specified expiry time. 367 | :rtype: str 368 | """ 369 | block_blob_client.create_container( 370 | container_name, 371 | fail_on_exist=False) 372 | 373 | block_blob_client.create_blob_from_path( 374 | container_name, 375 | blob_name, 376 | file_name) 377 | 378 | sas_token = create_sas_token( 379 | block_blob_client, 380 | container_name, 381 | blob_name, 382 | permission=azureblob.BlobPermissions.READ, 383 | expiry=expiry, 384 | timeout=timeout) 385 | 386 | sas_url = block_blob_client.make_blob_url( 387 | container_name, 388 | blob_name, 389 | sas_token=sas_token) 390 | 391 | return sas_url 392 | 393 | 394 | def upload_file_to_container( 395 | block_blob_client, container_name, file_path, timeout): 396 | """ 397 | Uploads a local file to an Azure Blob storage container. 398 | 399 | :param block_blob_client: A blob service client. 400 | :type block_blob_client: `azure.storage.blob.BlockBlobService` 401 | :param str container_name: The name of the Azure Blob storage container. 402 | :param str file_path: The local path to the file. 
403 | :param int timeout: timeout in minutes from now for expiry, 404 | will only be used if expiry is not specified 405 | :rtype: `azure.batch.models.ResourceFile` 406 | :return: A ResourceFile initialized with a SAS URL appropriate for Batch 407 | tasks. 408 | """ 409 | blob_name = os.path.basename(file_path) 410 | print('Uploading file {} to container [{}]...'.format( 411 | file_path, container_name)) 412 | sas_url = upload_blob_and_create_sas( 413 | block_blob_client, container_name, blob_name, file_path, expiry=None, 414 | timeout=timeout) 415 | return batchmodels.ResourceFile( 416 | file_path=blob_name, blob_source=sas_url) 417 | 418 | 419 | def download_blob_from_container( 420 | block_blob_client, container_name, blob_name, directory_path): 421 | """ 422 | Downloads specified blob from the specified Azure Blob storage container. 423 | 424 | :param block_blob_client: A blob service client. 425 | :type block_blob_client: `azure.storage.blob.BlockBlobService` 426 | :param container_name: The Azure Blob storage container from which to 427 | download file. 428 | :param blob_name: The name of blob to be downloaded 429 | :param directory_path: The local directory to which to download the file. 430 | """ 431 | print('Downloading result file from container [{}]...'.format( 432 | container_name)) 433 | 434 | destination_file_path = os.path.join(directory_path, blob_name) 435 | 436 | block_blob_client.get_blob_to_path( 437 | container_name, blob_name, destination_file_path) 438 | 439 | print(' Downloaded blob [{}] from container [{}] to {}'.format( 440 | blob_name, container_name, destination_file_path)) 441 | 442 | print(' Download complete!') 443 | 444 | def download_blobs_from_container(block_blob_client, 445 | container_name, directory_path): 446 | """ 447 | Downloads all blobs from the specified Azure Blob storage container. 448 | :param block_blob_client: A blob service client. 449 | :type block_blob_client: `azure.storage.blob.BlockBlobService` 450 | :param container_name: The Azure Blob storage container from which to 451 | download files. 452 | :param directory_path: The local directory to which to download the files. 453 | """ 454 | logging.info('Downloading all files from container [{}]...'.format( 455 | container_name)) 456 | container_blobs = block_blob_client.list_blobs(container_name) 457 | for blob in container_blobs.items: 458 | destination_file_path = os.path.join(directory_path, blob.name) 459 | block_blob_client.get_blob_to_path(container_name, 460 | blob.name, 461 | destination_file_path) 462 | logging.info(' Downloaded blob [{}] from container [{}] to {}'.format( 463 | blob.name, 464 | container_name, 465 | destination_file_path)) 466 | logging.info(' Download complete!') 467 | 468 | def delete_blobs_from_container(block_blob_client, 469 | container_name): 470 | """ 471 | Downloads all blobs from the specified Azure Blob storage container. 472 | :param block_blob_client: A blob service client. 473 | :type block_blob_client: `azure.storage.blob.BlockBlobService` 474 | :param container_name: The Azure Blob storage container from which to 475 | download files. 476 | :param directory_path: The local directory to which to download the files. 
477 | """ 478 | logging.info('Deleting all files from container [{}]...'.format( 479 | container_name)) 480 | container_blobs = block_blob_client.list_blobs(container_name) 481 | for blob in container_blobs.items: 482 | block_blob_client.delete_blob(container_name, blob.name) 483 | logging.info(' Deleted blob [{}] from container [{}]'.format( 484 | blob.name, 485 | container_name)) 486 | logging.info(' Deletion complete!') 487 | 488 | def generate_unique_resource_name(resource_prefix): 489 | """Generates a unique resource name by appending a time 490 | string after the specified prefix. 491 | 492 | :param str resource_prefix: The resource prefix to use. 493 | :return: A string with the format "resource_prefix-