├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── __init__.py ├── colab_codelab.ipynb ├── compile_pb.sh ├── docs ├── executable_specs.md ├── executors.md ├── metadata_storage.md ├── oss_xm_scope.md ├── parameter_controller.md ├── tensorboard.md └── xm_launch_api_principles.md ├── examples ├── cifar10_tensorflow │ ├── cifar10.py │ ├── launcher.py │ └── requirements.txt ├── cifar10_tensorflow_k8s_multiworker │ ├── cifar10.py │ ├── launcher.py │ └── requirements.txt ├── cifar10_tensorflow_k8s_ps │ ├── cifar10.py │ ├── launcher.py │ └── requirements.txt ├── cifar10_tensorflow_k8s_tensorboard │ ├── cifar10.py │ ├── launcher.py │ └── requirements.txt ├── cifar10_tensorflow_tpu │ ├── cifar10.py │ ├── launcher.py │ └── requirements.txt ├── cifar10_torch │ ├── cifar10.py │ ├── launcher.py │ └── requirements.txt ├── cifar10_torch_xla │ ├── cifar10.py │ ├── launcher.py │ └── requirements.txt ├── dockerfile │ ├── Dockerfile │ ├── arg_printer.cc │ └── launcher.py ├── dopamine │ ├── launcher.py │ └── train.py ├── local_arg_printer │ ├── BUILD │ ├── WORKSPACE │ ├── arg_printer.cc │ └── launcher.py ├── local_container_gpu │ ├── cifar10.py │ ├── launcher.py │ └── requirements.txt ├── local_container_links │ ├── BUILD.bazel │ ├── README.md │ ├── WORKSPACE │ ├── launcher.py │ ├── requirements.txt │ └── server.py ├── parameter_controller │ ├── db_config.yaml │ ├── inner_job │ │ ├── requirements.txt │ │ └── wait_job.py │ ├── launcher.py │ └── requirements.txt └── vizier │ ├── README.md │ ├── launcher.py │ ├── polynomial.py │ └── requirements.txt ├── jupyter_codelab.ipynb ├── setup.py ├── setup_scripts ├── install_bazel.sh ├── install_docker.sh ├── install_gcloud.sh ├── install_python.sh ├── install_xmanager.sh ├── setup_all.sh └── setup_gcp.sh └── xmanager ├── bazel ├── client.py └── file_utils.py ├── cli ├── README.md └── cli.py ├── cloud ├── README.md ├── __init__.py ├── auth.py ├── auth_test.py ├── build_image.py ├── build_image_test.py ├── cloud_build.py ├── cloud_build_test.py ├── 
data │ └── wrapped_entrypoint.sh ├── docker_lib.py ├── kubernetes.py ├── kubernetes_test.py ├── utils.py ├── utils_test.py ├── vertex.py └── vertex_test.py ├── contrib ├── __init__.py ├── addressing.py ├── addressing_test.py ├── copybara.py ├── executor_selector.py ├── flow.py ├── framework_defaults.py ├── framework_defaults_test.py ├── gcs.py ├── gcs_test.py ├── parameter_controller.py ├── process_entry.py ├── tensorboard.py ├── tensorboard_test.py ├── tpu.py ├── xm_tensorflow.py └── xm_tensorflow_test.py ├── docker └── docker_adapter.py ├── generated ├── README.md ├── build_event_stream_pb2.py ├── command_line_pb2.py ├── data_pb2.py ├── failure_details_pb2.py ├── invocation_policy_pb2.py └── option_filters_pb2.py ├── module_lazy_loader ├── lazy_loader_module_attrs_test.py └── module_lazy_loader.py ├── vizier └── vizier_cloud │ ├── __init__.py │ ├── study_factory.py │ ├── vizier_controller.py │ ├── vizier_exploration.py │ └── vizier_worker.py ├── xm ├── README.md ├── __init__.py ├── async_packager.py ├── async_packager_test.py ├── compute_units.py ├── core.py ├── core_test.py ├── executables.py ├── executables_test.py ├── id_predictor.py ├── id_predictor_test.py ├── job_blocks.py ├── job_blocks_test.py ├── job_operators.py ├── job_operators_test.py ├── metadata_context.py ├── packagables.py ├── packagables_generator.py ├── packagables_test.py ├── resources.py ├── resources_test.py ├── utils.py └── utils_test.py ├── xm_flags.py ├── xm_local ├── README.md ├── __init__.py ├── executables.py ├── execution.py ├── execution_test.py ├── executors.py ├── executors_test.py ├── experiment.py ├── handles.py ├── multiplexer.py ├── packaging │ ├── bazel_tools.py │ ├── bazel_tools_test.py │ ├── cloud.py │ ├── local.py │ └── router.py ├── registry.py ├── status.py └── storage │ ├── alembic.ini │ ├── alembic │ ├── README.md │ ├── env.py │ ├── script.py.mako │ └── versions │ │ └── f45829405692_migrate_or_create.py │ ├── data.proto │ └── database.py └── xm_mock └── __init__.py 
/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | The project currently does not accept pull requests, but we'd be more than happy 4 | to get your feedback via [issues](https://github.com/deepmind/xmanager/issues)! 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | <https://cla.developers.google.com/> to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Community Guidelines 20 | 21 | This project follows 22 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 23 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | """XManager.""" 15 | -------------------------------------------------------------------------------- /compile_pb.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | set -euo pipefail 17 | readonly BAZEL_DIR="/tmp/bazel" 18 | readonly SOURCE_ROOT_DIR="$(realpath "$(dirname "$0")")" 19 | readonly GENERATED_DIR="$SOURCE_ROOT_DIR/generated" 20 | # Reuse an existing checkout on re-runs; with `set -e` a failed clone aborts the script. 21 | [[ -d "${BAZEL_DIR}/.git" ]] || git clone --depth=1 https://github.com/bazelbuild/bazel.git "${BAZEL_DIR}" 22 | 23 | protoc --proto_path="${BAZEL_DIR}" --python_out="${BAZEL_DIR}" \ 24 | "${BAZEL_DIR}/src/main/java/com/google/devtools/build/lib/buildeventstream/proto/build_event_stream.proto" \ 25 | "${BAZEL_DIR}/src/main/protobuf/command_line.proto" \ 26 | "${BAZEL_DIR}/src/main/protobuf/failure_details.proto" \ 27 | "${BAZEL_DIR}/src/main/protobuf/invocation_policy.proto" \ 28 | "${BAZEL_DIR}/src/main/protobuf/option_filters.proto" 29 | 30 | cp "${BAZEL_DIR}/src/main/java/com/google/devtools/build/lib/buildeventstream/proto/build_event_stream_pb2.py" "${GENERATED_DIR}/" 31 | cp "${BAZEL_DIR}/src/main/protobuf/command_line_pb2.py" "${GENERATED_DIR}/" 32 | cp "${BAZEL_DIR}/src/main/protobuf/failure_details_pb2.py" "${GENERATED_DIR}/" 33 | cp "${BAZEL_DIR}/src/main/protobuf/invocation_policy_pb2.py" "${GENERATED_DIR}/" 34 | cp "${BAZEL_DIR}/src/main/protobuf/option_filters_pb2.py"
"${GENERATED_DIR}/" 35 | 36 | protoc --proto_path="$SOURCE_ROOT_DIR" --python_out="$SOURCE_ROOT_DIR" \ 37 | "$SOURCE_ROOT_DIR/xm_local/storage/data.proto" 38 | mv "$SOURCE_ROOT_DIR/xm_local/storage/data_pb2.py" "${GENERATED_DIR}/" 39 | 40 | # Make generated imports local to xmanager.generated. 41 | find "${GENERATED_DIR}/" -name '*pb2*' -exec \ 42 | sed -i 's/^from src\.main.* import/from . import/' {} \; 43 | # Add NOLINT line. 44 | find "${GENERATED_DIR}/" -name '*pb2*' -exec \ 45 | sed -i 's/DO NOT EDIT!/DO NOT EDIT!\n# pylint: skip-file/' {} \; 46 | -------------------------------------------------------------------------------- /docs/executors.md: -------------------------------------------------------------------------------- 1 | # Executors 2 | 3 | ## Local 4 | 5 | The local executor declares that an executable will be run on the same machine 6 | from which the launch script is invoked. 7 | 8 | ```python 9 | xm_local.Local( 10 | docker_options=xm_local.DockerOptions(...), 11 | ) 12 | ``` 13 | 14 | Making GPUs available to local containers is possible by requesting 15 | the `local_gpu` resource through a requirements object. 16 | 17 | ```python 18 | xm_local.Local( 19 | xm.JobRequirements(local_gpu=1) 20 | ) 21 | ``` 22 | 23 | Note: Only local NVIDIA GPUs can be requested for the `Local` executor through 24 | the resource requirements object. Any other requirements will be ignored. 25 | 26 | ## Vertex AI (Cloud AI Platform) 27 | 28 | The `Vertex` executor declares that an executable will be run on the Vertex AI 29 | platform. 30 | 31 | The Vertex executor takes in a resource requirements object. 32 | 33 | ```python 34 | xm_local.Vertex( 35 | xm.JobRequirements( 36 | cpu=1, # Measured in vCPUs. 37 | ram=4 * xm.GiB, 38 | T4=1, # NVIDIA Tesla T4. 39 | ), 40 | ) 41 | ``` 42 | 43 | ```python 44 | xm_local.Vertex( 45 | xm.JobRequirements( 46 | cpu=1, # Measured in vCPUs. 47 | ram=4 * xm.GiB, 48 | TPU_V2=8, # TPU v2.
49 | ), 50 | ), 51 | ``` 52 | 53 | As of June 2021, the currently supported accelerator types are: 54 | 55 | * `P100` 56 | * `V100` 57 | * `P4` 58 | * `T4` 59 | * `A100` 60 | * `TPU_V2` 61 | * `TPU_V3` 62 | 63 | ### Vertex AI Specification 64 | 65 | The Vertex AI executor allows you to specify a remote image repository to push to. 66 | 67 | ```python 68 | xm_local.Vertex.Spec( 69 | push_image_tag='gcr.io/<project>/<image>:<tag>', 70 | ) 71 | ``` 72 | 73 | ## Kubernetes (experimental) 74 | 75 | The Kubernetes executor declares that an executable will be run on a Kubernetes 76 | cluster. As of October 2021, Kubernetes is not fully supported. 77 | 78 | The Kubernetes executor reads the cluster configuration from your local `kubeconfig`. The XManager 79 | command-line tool has helpers to set up a Google Kubernetes Engine (GKE) cluster. 80 | 81 | ```bash 82 | pip install caliban==0.4.1 83 | xmanager cluster create 84 | 85 | # cleanup 86 | xmanager cluster delete 87 | ``` 88 | 89 | You can store the GKE credentials in your `kubeconfig`: 90 | 91 | ```bash 92 | gcloud container clusters get-credentials <cluster-name> 93 | ``` 94 | 95 | ### Kubernetes Specification 96 | 97 | The Kubernetes executor allows you to specify a remote image repository to push to. 98 | 99 | ```python 100 | xm_local.Kubernetes.Spec( 101 | push_image_tag='gcr.io/<project>/<image>:<tag>', 102 | ) 103 | ``` 104 | -------------------------------------------------------------------------------- /docs/metadata_storage.md: -------------------------------------------------------------------------------- 1 | # Metadata Storage 2 | 3 | Experiment metadata is stored in a SQL database. By default, the database used 4 | is the SQLite database at `~/.xmanager/experiments.sqlite3`. If that does not 5 | suffice, the XManager client can also connect to a generic database based on a 6 | YAML configuration.
An example of such a configuration is given below: 7 | 8 | ```yaml 9 | # sample_db_config.yaml 10 | 11 | # Connector used - one of ['cloudsql', 'sqlite', 'generic'] 12 | sql_connector: 'generic' 13 | 14 | sql_connection_settings: 15 | # Backend used, e.g. 'mysql', 'postgresql' 16 | backend: 'mysql' 17 | # Driver used, e.g. 'pymysql', 'pg8000' 18 | driver: 'pymysql' 19 | # Username 20 | username: 'root' 21 | # Password 22 | password: 'metadata' 23 | # Host (or instance connection name for CloudSql) 24 | host: '127.0.0.1' 25 | # Port 26 | port: 3309 27 | # Database name 28 | db_name: 'metadata' 29 | ``` 30 | 31 | The `sql_connector` field specifies the connector to be used. It can be one of 32 | `sqlite`, `cloudsql`, `generic`. Generally, it's recommended to use the `sqlite` 33 | connector for a local DB at a different location than the default one, 34 | `cloudsql` for connecting to a CloudSQL database, and `generic` for any other 35 | database. 36 | 37 | The `cloudsql` connector runs the 38 | [CloudSQL Auth Proxy](https://cloud.google.com/sql/docs/mysql/sql-proxy) 39 | automatically, so it requires the host that runs the XManager client to have the 40 | required permissions (Cloud SQL Editor IAM role) and project APIs enabled (Cloud 41 | SQL Admin API). When using the `cloudsql` connector, one should use an instance 42 | connection name of the format 43 | `${PROJECT_ID}:${INSTANCE_LOCATION}:${INSTANCE_NAME}` in the `host` field of the 44 | YAML configuration. 45 | 46 | When using the `generic` connector, the fields in the `sql_connection_settings` 47 | follow the format of the `sqlalchemy` connection URL. Therefore, any 48 | combination of the fields above that forms a valid `sqlalchemy` connection 49 | URL can be used. 50 | 51 | This YAML config can be passed to XManager by using the 52 | `--xm_db_yaml_config_path` flag. Note that the path is relative to the launch 53 | script. If the flag is not set, the default SQLite database is used.
54 | 55 | The XManager client attempts to always keep the database to the latest version. 56 | For details, check 57 | [database migration docs](https://github.com/deepmind/xmanager/tree/main/xmanager/xm_local/storage/alembic/README.md). 58 | 59 | ## Slides 60 | 61 | Remote metadata execution and metadata storage are present in 62 | [these slides](https://storage.googleapis.com/gresearch/xmanager/remote_execution_slides.pdf). 63 | -------------------------------------------------------------------------------- /docs/oss_xm_scope.md: -------------------------------------------------------------------------------- 1 | # Scope of OSS XManager 2 | 3 | ## What's included 4 | 5 | Internally within Alphabet, XManager is a fully-featured ecosystem aimed to 6 | facilitate the experimentation process that is deeply integrated with our 7 | internal infrastructure. The XManager repository that you are seeing only 8 | contains a small subset of our ecosystem. Due to the amount of work needed and 9 | coupling with internal tools, we have decided to only open source the launch 10 | API: the means to describe what an experiment consists of and how to launch it. 11 | As a result, users can launch these experiments locally, on Google Cloud 12 | Platform or on Kubernetes. The API can be extended to support other Cloud 13 | platforms. 14 | 15 | ## The scope 16 | 17 | The project’s primary goal is to facilitate collaboration between internal 18 | researchers and the scientific community. The Open Sourced XManager API provides 19 | a way to share our internally-developed code, enabling everybody to reproduce 20 | results, build new knowledge on top of them, or put the code to work for the 21 | benefit of society. It also enables closer collaborations between Alphabet and 22 | external researchers. 23 | 24 | As we are focused on delivering our internal roadmap, we may not have bandwidth 25 | for external feature requests, even if they come with an implementation. 
26 | XManager was designed to be highly flexible and adaptable and has a 27 | lot of potential for further use cases. Therefore we don’t exclude the 28 | possibility of increasing its scope going forward. 29 | 30 | ## Can I use XManager for research unrelated to Google or DeepMind? 31 | 32 | Yes. XManager API is distributed under the Apache 2.0 licence, which gives you a lot 33 | of flexibility. We understand that our 34 | [contribution policy](https://github.com/deepmind/xmanager/blob/main/CONTRIBUTING.md) 35 | doesn't sound very reassuring, but we can suggest the following to counter the 36 | risks: 37 | 38 | * The licence allows you to have and maintain a fork with all the changes you 39 | need. 40 | * The API has been designed to be modular. Similar to how `xmanager.xm_local` 41 | extends `xmanager.xm` and provides support for orchestrating experiments 42 | from a local machine, you may have your own module tailoring the API to your 43 | needs. In fact this is how our internal version adds support for 44 | Alphabet-specific infrastructure. 45 | * The core `xmanager.xm` provides the basic building blocks of a research 46 | experiment that we think every implementation should be based on. And one of 47 | the things we've learned over the years is that research is a fast-moving field. 48 | It was in our best interest to make it generic, flexible, and extensible. 49 | 50 | ## What is not included 51 | 52 | The following was intentionally left outside of the XManager scope to allow it 53 | to be shared in a reasonable time frame. Note that while these features are 54 | important parts of the research infrastructure, we believe that many of them are 55 | better provided as a collection of well-integrated libraries / services 56 | rather than an all-in-one tool. 57 | 58 | * Web user interface. Vertex AI and TensorBoard provide good alternatives if 59 | you need a UI. 60 | * Metrics storage.
XManager API doesn't cover storing, retrieving, or 61 | conducting analysis of experiment metrics. 62 | * Computational resource sharing. `xm_local` supports running many experiments 63 | in the Cloud in parallel. But it doesn't enforce any policy for sharing 64 | resources between many researchers. 65 | * Continuous status tracking. The base XManager implementation does not track 66 | the running status of trials or jobs. It does not contain any mechanism for 67 | push-notification or notifying a user of success or failure. 68 | -------------------------------------------------------------------------------- /examples/cifar10_tensorflow/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Tensorflow Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Code based on https://www.tensorflow.org/tutorials/images/cnn.""" 15 | import os 16 | 17 | from absl import app 18 | from absl import flags 19 | import tensorflow as tf 20 | from tensorflow.keras import datasets 21 | from tensorflow.keras import layers 22 | from tensorflow.keras import models 23 | 24 | # When using Vertex TensorBoard, the log directory is provided through the 25 | # AIP_TENSORBOARD_LOG_DIR environment variable (empty string if unset).
26 | LOG_DIR = os.environ.get('AIP_TENSORBOARD_LOG_DIR', '') 27 | 28 | FLAGS = flags.FLAGS 29 | flags.DEFINE_integer('epochs', 5, 'epochs') 30 | flags.DEFINE_float('learning_rate', 0.001, 'learning rate') 31 | 32 | 33 | def main(_): 34 | (train_images, train_labels), (test_images, test_labels) = ( 35 | datasets.cifar10.load_data() 36 | ) 37 | 38 | # Normalize pixel values to be between 0 and 1 39 | train_images, test_images = train_images / 255.0, test_images / 255.0 40 | 41 | strategy = tf.distribute.MirroredStrategy() 42 | with strategy.scope(): 43 | model = models.Sequential() 44 | model.add( 45 | layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)) 46 | ) 47 | model.add(layers.MaxPooling2D((2, 2))) 48 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 49 | model.add(layers.MaxPooling2D((2, 2))) 50 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 51 | 52 | model.add(layers.Flatten()) 53 | model.add(layers.Dense(64, activation='relu')) 54 | model.add(layers.Dense(10)) 55 | 56 | optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate) 57 | model.compile( 58 | optimizer=optimizer, 59 | loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 60 | metrics=['accuracy'], 61 | ) 62 | 63 | callbacks = [] 64 | if LOG_DIR: 65 | callbacks = [ 66 | tf.keras.callbacks.TensorBoard( 67 | log_dir=LOG_DIR, 68 | histogram_freq=1, 69 | ), 70 | ] 71 | 72 | model.fit( 73 | train_images, 74 | train_labels, 75 | epochs=FLAGS.epochs, 76 | validation_data=(test_images, test_labels), 77 | callbacks=callbacks, 78 | verbose=2, 79 | ) 80 | 81 | 82 | if __name__ == '__main__': 83 | app.run(main) 84 | -------------------------------------------------------------------------------- /examples/cifar10_tensorflow/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | #
you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | r"""XManager launcher for CIFAR10. 15 | 16 | Usage: 17 | 18 | xmanager launch examples/cifar10_tensorflow/launcher.py -- \ 19 | --xm_wrap_late_bindings [--image_path=gcr.io/path/to/image/tag] 20 | """ 21 | import asyncio 22 | import itertools 23 | import os 24 | 25 | from absl import app 26 | from absl import flags 27 | from xmanager import xm 28 | from xmanager import xm_local 29 | from xmanager.cloud import vertex 30 | 31 | FLAGS = flags.FLAGS 32 | flags.DEFINE_string('tensorboard', None, 'Tensorboard instance.') 33 | flags.DEFINE_integer('gpus_per_node', 2, 'Number of GPUs per node.') 34 | 35 | 36 | def main(_): 37 | with xm_local.create_experiment(experiment_title='cifar10') as experiment: 38 | spec = xm.PythonContainer( 39 | # Package the current directory that this script is in.
40 | path='.', 41 | base_image='gcr.io/deeplearning-platform-release/tf2-gpu.2-6', 42 | entrypoint=xm.ModuleName('cifar10'), 43 | ) 44 | 45 | [executable] = experiment.package( 46 | [ 47 | xm.Packageable( 48 | executable_spec=spec, 49 | executor_spec=xm_local.Vertex.Spec(), 50 | args={}, 51 | ), 52 | ] 53 | ) 54 | 55 | learning_rates = [0.1, 0.001] 56 | trials = list( 57 | dict([('learning_rate', lr)]) 58 | for (lr,) in itertools.product(learning_rates) 59 | ) 60 | 61 | tensorboard = FLAGS.tensorboard 62 | if not tensorboard: 63 | tensorboard = vertex.get_default_client().get_or_create_tensorboard( 64 | 'cifar10' 65 | ) 66 | tensorboard = asyncio.get_event_loop().run_until_complete(tensorboard) 67 | 68 | for i, hyperparameters in enumerate(trials): 69 | output_dir = os.environ.get('GOOGLE_CLOUD_BUCKET_NAME', None) 70 | if output_dir: 71 | output_dir = os.path.join( 72 | output_dir, str(experiment.experiment_id), str(i) 73 | ) 74 | tensorboard_capability = xm_local.TensorboardCapability( 75 | name=tensorboard, base_output_directory=output_dir 76 | ) 77 | experiment.add( 78 | xm.Job( 79 | executable=executable, 80 | executor=xm_local.Vertex( 81 | tensorboard=tensorboard_capability, 82 | requirements=xm.JobRequirements(t4=FLAGS.gpus_per_node), 83 | ), 84 | args=hyperparameters, 85 | ) 86 | ) 87 | 88 | 89 | if __name__ == '__main__': 90 | app.run(main) 91 | -------------------------------------------------------------------------------- /examples/cifar10_tensorflow/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | tensorflow 3 | tensorflow-datasets 4 | -------------------------------------------------------------------------------- /examples/cifar10_tensorflow_k8s_multiworker/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Tensorflow Authors.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Code based on https://www.tensorflow.org/tutorials/images/cnn.""" 15 | import os 16 | 17 | from absl import app 18 | from absl import flags 19 | import tensorflow as tf 20 | from tensorflow.keras import datasets 21 | from tensorflow.keras import layers 22 | from tensorflow.keras import models 23 | 24 | # When using Vertex Tensorboard, the tensorboard will be present as a 25 | # environment variable. 
26 | LOG_DIR = os.environ.get('AIP_TENSORBOARD_LOG_DIR', '') 27 | 28 | FLAGS = flags.FLAGS 29 | flags.DEFINE_integer('epochs', 5, 'epochs') 30 | flags.DEFINE_float('learning_rate', 0.001, 'learning rate') 31 | 32 | 33 | def main(_): 34 | strategy = tf.distribute.MultiWorkerMirroredStrategy() 35 | with strategy.scope(): 36 | model = models.Sequential() 37 | model.add( 38 | layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)) 39 | ) 40 | model.add(layers.MaxPooling2D((2, 2))) 41 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 42 | model.add(layers.MaxPooling2D((2, 2))) 43 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 44 | 45 | model.add(layers.Flatten()) 46 | model.add(layers.Dense(64, activation='relu')) 47 | model.add(layers.Dense(10)) 48 | 49 | optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate) 50 | model.compile( 51 | optimizer=optimizer, 52 | loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 53 | metrics=['accuracy'], 54 | ) 55 | 56 | callbacks = [] 57 | if LOG_DIR: 58 | callbacks = [ 59 | tf.keras.callbacks.TensorBoard( 60 | log_dir=LOG_DIR, 61 | histogram_freq=1, 62 | ), 63 | ] 64 | 65 | (train_images, train_labels), (test_images, test_labels) = ( 66 | datasets.cifar10.load_data() 67 | ) 68 | # Normalize pixel values to be between 0 and 1 69 | train_images, test_images = train_images / 255.0, test_images / 255.0 70 | 71 | train_dataset = tf.data.Dataset.from_tensor_slices( 72 | (train_images, train_labels) 73 | ) 74 | validation_dataset = tf.data.Dataset.from_tensor_slices( 75 | (test_images, test_labels) 76 | ) 77 | 78 | train_dataset = train_dataset.batch(32) 79 | validation_dataset = validation_dataset.batch(32) 80 | 81 | options = tf.data.Options() 82 | options.experimental_distribute.auto_shard_policy = ( 83 | tf.data.experimental.AutoShardPolicy.DATA 84 | ) 85 | 86 | train_dataset.with_options(options) 87 | validation_dataset.with_options(options) 88 | 89 | model.fit( 90 | 
train_dataset, 91 | epochs=FLAGS.epochs, 92 | validation_data=validation_dataset, 93 | callbacks=callbacks, 94 | verbose=2, 95 | ) 96 | 97 | 98 | if __name__ == '__main__': 99 | app.run(main) 100 | -------------------------------------------------------------------------------- /examples/cifar10_tensorflow_k8s_multiworker/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | r"""XManager launcher for CIFAR10 using TF's MultiWorkerMirroredStrategy using the Kubernetes back-end. 15 | 16 | Usage: 17 | 18 | xmanager launch examples/cifar10_tensorflow_k8s_multiworker/launcher.py -- \ 19 | --xm_wrap_late_bindings [--image_path=gcr.io/path/to/image/tag] 20 | """ 21 | import itertools 22 | 23 | from absl import app 24 | from xmanager import xm 25 | from xmanager import xm_local 26 | from xmanager.contrib import xm_tensorflow 27 | 28 | 29 | def main(_): 30 | with xm_local.create_experiment( 31 | experiment_title='kubernetes_multiworker' 32 | ) as experiment: 33 | spec = xm.PythonContainer( 34 | # Package the current directory that this script is in. 
35 | path='.', 36 | base_image='gcr.io/deeplearning-platform-release/tf2-gpu.2-6', 37 | entrypoint=xm.ModuleName('cifar10'), 38 | ) 39 | 40 | [executable] = experiment.package( 41 | [ 42 | xm.Packageable( 43 | executable_spec=spec, 44 | executor_spec=xm_local.Kubernetes.Spec(), 45 | args={}, 46 | ), 47 | ] 48 | ) 49 | 50 | learning_rates = [0.001] 51 | trials = list( 52 | dict([('learning_rate', lr)]) 53 | for (lr,) in itertools.product(learning_rates) 54 | ) 55 | 56 | builder = xm_tensorflow.MultiWorkerMirroredStrategyBuilder( 57 | experiment=experiment, 58 | worker_executable=executable, 59 | worker_executor=xm_local.Kubernetes( 60 | requirements=xm.JobRequirements(t4=1) 61 | ), 62 | worker_name='worker', 63 | num_workers=3, 64 | ) 65 | 66 | for hyperparameters in trials: 67 | experiment.add(builder.gen_job_group(), args=hyperparameters) 68 | 69 | 70 | if __name__ == '__main__': 71 | app.run(main) 72 | -------------------------------------------------------------------------------- /examples/cifar10_tensorflow_k8s_multiworker/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | tensorflow 3 | tensorflow-datasets 4 | -------------------------------------------------------------------------------- /examples/cifar10_tensorflow_k8s_ps/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Tensorflow Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Code based on https://www.tensorflow.org/tutorials/images/cnn.""" 15 | import os 16 | 17 | from absl import app 18 | from absl import flags 19 | import tensorflow as tf 20 | from tensorflow.keras import datasets 21 | from tensorflow.keras import layers 22 | from tensorflow.keras import models 23 | 24 | FLAGS = flags.FLAGS 25 | flags.DEFINE_integer('epochs', 5, 'epochs') 26 | flags.DEFINE_float('learning_rate', 0.001, 'learning rate') 27 | 28 | 29 | def main(_): 30 | cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver() 31 | 32 | if cluster_resolver.task_type in ('worker', 'ps'): 33 | os.environ['GRPC_FAIL_FAST'] = 'use_caller' 34 | 35 | server = tf.distribute.Server( 36 | cluster_resolver.cluster_spec(), 37 | job_name=cluster_resolver.task_type, 38 | task_index=cluster_resolver.task_id, 39 | protocol=cluster_resolver.rpc_layer or 'grpc', 40 | start=True, 41 | ) 42 | server.join() 43 | 44 | (train_images, train_labels), _ = datasets.cifar10.load_data() 45 | 46 | def dataset_fn(input_context): 47 | dataset = tf.data.Dataset.from_tensor_slices( 48 | (train_images, train_labels) 49 | ).repeat() 50 | dataset = dataset.shard( 51 | input_context.num_input_pipelines, input_context.input_pipeline_id 52 | ) 53 | dataset = dataset.batch(64) 54 | dataset = dataset.prefetch(2) 55 | 56 | return dataset 57 | 58 | strategy = tf.distribute.experimental.ParameterServerStrategy( 59 | cluster_resolver 60 | ) 61 | with strategy.scope(): 62 | model = models.Sequential() 63 | model.add( 64 | layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)) 65 | ) 66 | model.add(layers.MaxPooling2D((2, 2))) 67 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 68 | model.add(layers.MaxPooling2D((2, 2))) 69 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 70 | 71 | model.add(layers.Flatten()) 72 | model.add(layers.Dense(64, 
activation='relu')) 73 | model.add(layers.Dense(10)) 74 | 75 | optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate) 76 | model.compile( 77 | optimizer=optimizer, 78 | loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 79 | metrics=['accuracy'], 80 | ) 81 | 82 | model.fit( 83 | tf.keras.utils.experimental.DatasetCreator(dataset_fn), 84 | steps_per_epoch=1500, 85 | epochs=FLAGS.epochs, 86 | ) 87 | 88 | 89 | if __name__ == '__main__': 90 | app.run(main) 91 | -------------------------------------------------------------------------------- /examples/cifar10_tensorflow_k8s_ps/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | r"""XManager launcher for CIFAR10 using TF's ParameterServerStrategy. 15 | 16 | Usage: 17 | 18 | xmanager launch examples/cifar10_tensorflow_k8s__ps/launcher.py -- \ 19 | --xm_wrap_late_bindings [--image_path=gcr.io/path/to/image/tag] 20 | """ 21 | import itertools 22 | 23 | from absl import app 24 | from xmanager import xm 25 | from xmanager import xm_local 26 | from xmanager.contrib import xm_tensorflow 27 | 28 | 29 | def main(_): 30 | with xm_local.create_experiment( 31 | experiment_title='kubernetes_multiworker' 32 | ) as experiment: 33 | spec = xm.PythonContainer( 34 | # Package the current directory that this script is in. 
35 | path='.', 36 | base_image='gcr.io/deeplearning-platform-release/tf2-gpu.2-6', 37 | entrypoint=xm.ModuleName('cifar10'), 38 | ) 39 | 40 | [executable] = experiment.package( 41 | [ 42 | xm.Packageable( 43 | executable_spec=spec, 44 | executor_spec=xm_local.Kubernetes.Spec(), 45 | args={}, 46 | ) 47 | ] 48 | ) 49 | 50 | learning_rates = [0.001] 51 | trials = list( 52 | dict([('learning_rate', lr)]) 53 | for (lr,) in itertools.product(learning_rates) 54 | ) 55 | 56 | builder = xm_tensorflow.ParameterServerStrategyBuilder( 57 | experiment=experiment, 58 | chief_executable=executable, 59 | chief_executor=xm_local.Kubernetes( 60 | requirements=xm.JobRequirements(t4=1) 61 | ), 62 | worker_executable=executable, 63 | worker_executor=xm_local.Kubernetes( 64 | requirements=xm.JobRequirements(t4=1) 65 | ), 66 | worker_name='worker', 67 | ps_executable=executable, 68 | ps_executor=xm_local.Kubernetes(), 69 | ps_name='ps', 70 | num_workers=2, 71 | num_ps=1, 72 | ) 73 | 74 | for hyperparameters in trials: 75 | experiment.add(builder.gen_job_group(), args=hyperparameters) 76 | 77 | 78 | if __name__ == '__main__': 79 | app.run(main) 80 | -------------------------------------------------------------------------------- /examples/cifar10_tensorflow_k8s_ps/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | tensorflow 3 | tensorflow-datasets 4 | -------------------------------------------------------------------------------- /examples/cifar10_tensorflow_k8s_tensorboard/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Tensorflow Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Code based on https://www.tensorflow.org/tutorials/images/cnn.""" 15 | from absl import app 16 | from absl import flags 17 | import tensorflow as tf 18 | from tensorflow.keras import datasets 19 | from tensorflow.keras import layers 20 | from tensorflow.keras import models 21 | 22 | _TENSORBOARD_LOG_DIR = flags.DEFINE_string('tensorboard_log_dir', None, '') 23 | 24 | FLAGS = flags.FLAGS 25 | flags.DEFINE_integer('epochs', 5, 'epochs') 26 | flags.DEFINE_float('learning_rate', 0.001, 'learning rate') 27 | 28 | 29 | def main(_): 30 | (train_images, train_labels), (test_images, test_labels) = ( 31 | datasets.cifar10.load_data() 32 | ) 33 | 34 | # Normalize pixel values to be between 0 and 1 35 | train_images, test_images = train_images / 255.0, test_images / 255.0 36 | 37 | strategy = tf.distribute.MirroredStrategy() 38 | with strategy.scope(): 39 | model = models.Sequential() 40 | model.add( 41 | layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)) 42 | ) 43 | model.add(layers.MaxPooling2D((2, 2))) 44 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 45 | model.add(layers.MaxPooling2D((2, 2))) 46 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 47 | 48 | model.add(layers.Flatten()) 49 | model.add(layers.Dense(64, activation='relu')) 50 | model.add(layers.Dense(10)) 51 | 52 | optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate) 53 | model.compile( 54 | optimizer=optimizer, 55 | loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 56 | 
metrics=['accuracy'], 57 | ) 58 | 59 | callbacks = [] 60 | if _TENSORBOARD_LOG_DIR.value: 61 | callbacks = [ 62 | tf.keras.callbacks.TensorBoard( 63 | log_dir=_TENSORBOARD_LOG_DIR.value, 64 | histogram_freq=1, 65 | ), 66 | ] 67 | 68 | model.fit( 69 | train_images, 70 | train_labels, 71 | epochs=FLAGS.epochs, 72 | validation_data=(test_images, test_labels), 73 | callbacks=callbacks, 74 | verbose=2, 75 | ) 76 | 77 | 78 | if __name__ == '__main__': 79 | app.run(main) 80 | -------------------------------------------------------------------------------- /examples/cifar10_tensorflow_k8s_tensorboard/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | r"""XManager launcher for running CIFAR10 on Kubernetes with Tensorboard auxiliary job. 
15 | 16 | Usage: 17 | 18 | xmanager launch examples/cifar10_tensorflow_k8s_tensorboard/launcher.py -- \ 19 | --tensorboard_log_dir=TENSORBOARD_LOG_DIR \ 20 | [--tensorboard_timeout_secs=TIMEOUT_SECS] 21 | """ 22 | import itertools 23 | import os 24 | 25 | from absl import app 26 | from absl import flags 27 | from xmanager import xm 28 | from xmanager import xm_local 29 | from xmanager.contrib import tensorboard 30 | 31 | _TENSORBOARD_LOG_DIR = flags.DEFINE_string( 32 | 'tensorboard_log_dir', 33 | None, 34 | 'Log directory to be used by workers and Tensorboard.', 35 | ) 36 | 37 | _TENSORBOARD_TIMEOUT_SECS = flags.DEFINE_integer( 38 | 'tensorboard_timeout_secs', 39 | 60 * 60, 40 | 'The amount of time the Tensorboard job should run for.', 41 | ) 42 | 43 | 44 | def main(_): 45 | with xm_local.create_experiment(experiment_title='cifar10') as experiment: 46 | spec = xm.PythonContainer( 47 | # Package the current directory that this script is in. 48 | path='.', 49 | base_image='gcr.io/deeplearning-platform-release/tf2-gpu.2-6', 50 | entrypoint=xm.ModuleName('cifar10'), 51 | ) 52 | 53 | [executable] = experiment.package( 54 | [ 55 | xm.Packageable( 56 | executable_spec=spec, 57 | executor_spec=xm_local.Kubernetes.Spec(), 58 | ), 59 | ] 60 | ) 61 | 62 | learning_rates = [0.1, 0.001] 63 | trials = list( 64 | dict([('learning_rate', lr)]) 65 | for (lr,) in itertools.product(learning_rates) 66 | ) 67 | 68 | log_dir = None 69 | if _TENSORBOARD_LOG_DIR.value: 70 | log_dir = ( 71 | f'{_TENSORBOARD_LOG_DIR.value}/{str(experiment.experiment_id)}/logs' 72 | ) 73 | 74 | if log_dir: 75 | tensorboard.add_tensorboard( 76 | experiment, 77 | log_dir, 78 | executor=xm_local.Kubernetes(), 79 | timeout_secs=_TENSORBOARD_TIMEOUT_SECS.value, 80 | ) 81 | 82 | for i, hyperparameters in enumerate(trials): 83 | output_dir = os.path.join(log_dir, str(i)) 84 | experiment.add( 85 | xm.Job( 86 | executable=executable, 87 | executor=xm_local.Kubernetes(), 88 | args=dict({'tensorboard_log_dir': 
output_dir, **hyperparameters}), 89 | ) 90 | ) 91 | 92 | 93 | if __name__ == '__main__': 94 | app.run(main) 95 | -------------------------------------------------------------------------------- /examples/cifar10_tensorflow_k8s_tensorboard/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | tensorflow 3 | tensorflow-datasets 4 | -------------------------------------------------------------------------------- /examples/cifar10_tensorflow_tpu/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Tensorflow Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Code based on https://www.tensorflow.org/tutorials/images/cnn.""" 15 | import os 16 | 17 | from absl import app 18 | from absl import flags 19 | import tensorflow as tf 20 | from tensorflow.keras import datasets 21 | from tensorflow.keras import layers 22 | from tensorflow.keras import models 23 | 24 | # When using Vertex Tensorboard, the tensorboard will be present as a 25 | # environment variable. 
26 | LOG_DIR = os.environ.get('AIP_TENSORBOARD_LOG_DIR', '') 27 | 28 | FLAGS = flags.FLAGS 29 | flags.DEFINE_integer('epochs', 5, 'epochs') 30 | flags.DEFINE_float('learning_rate', 0.001, 'learning rate') 31 | 32 | 33 | def main(_): 34 | (train_images, train_labels), (test_images, test_labels) = ( 35 | datasets.cifar10.load_data() 36 | ) 37 | 38 | # Normalize pixel values to be between 0 and 1 39 | train_images, test_images = train_images / 255.0, test_images / 255.0 40 | 41 | resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local') 42 | tf.config.experimental_connect_to_cluster(resolver) 43 | tf.tpu.experimental.initialize_tpu_system(resolver) 44 | strategy = tf.distribute.TPUStrategy(resolver) 45 | with strategy.scope(): 46 | model = models.Sequential() 47 | model.add( 48 | layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)) 49 | ) 50 | model.add(layers.MaxPooling2D((2, 2))) 51 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 52 | model.add(layers.MaxPooling2D((2, 2))) 53 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 54 | 55 | model.add(layers.Flatten()) 56 | model.add(layers.Dense(64, activation='relu')) 57 | model.add(layers.Dense(10)) 58 | 59 | optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate) 60 | model.compile( 61 | optimizer=optimizer, 62 | loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 63 | metrics=['accuracy'], 64 | ) 65 | 66 | callbacks = [] 67 | if LOG_DIR: 68 | callbacks = [ 69 | tf.keras.callbacks.TensorBoard( 70 | log_dir=LOG_DIR, 71 | histogram_freq=1, 72 | ), 73 | ] 74 | 75 | model.fit( 76 | train_images, 77 | train_labels, 78 | epochs=FLAGS.epochs, 79 | validation_data=(test_images, test_labels), 80 | callbacks=callbacks, 81 | verbose=2, 82 | ) 83 | 84 | 85 | if __name__ == '__main__': 86 | app.run(main) 87 | -------------------------------------------------------------------------------- /examples/cifar10_tensorflow_tpu/launcher.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | r"""XManager launcher for CIFAR10. 15 | 16 | Usage: 17 | 18 | xmanager launch examples/cifar10_tensorflow_tpu/launcher.py 19 | """ 20 | import asyncio 21 | import itertools 22 | import os 23 | 24 | from absl import app 25 | from absl import flags 26 | from xmanager import xm 27 | from xmanager import xm_local 28 | from xmanager.cloud import build_image 29 | from xmanager.cloud import vertex 30 | from xmanager.contrib import tpu 31 | 32 | FLAGS = flags.FLAGS 33 | flags.DEFINE_string('tensorboard', None, 'Tensorboard instance.') 34 | 35 | 36 | def main(_): 37 | with xm_local.create_experiment(experiment_title='cifar10') as experiment: 38 | directory = os.path.basename(os.path.dirname(os.path.realpath(__file__))) 39 | # pyformat: disable 40 | spec = xm.PythonContainer( 41 | # Package the current directory that this script is in. 
42 | path='.', 43 | # tpuvm requires Python3.8 and GLIBC_2.29, which requires at least 44 | # debian:11 or ubuntu:20.04 45 | base_image='ubuntu:20.04', 46 | docker_instructions=( 47 | ['RUN apt-get update --allow-releaseinfo-change && apt-get install -y python-is-python3 python3-pip wget'] + # pylint: disable=line-too-long 48 | build_image.default_steps(directory, use_deep_module=False) + 49 | tpu.tpuvm_docker_instructions()), 50 | entrypoint=xm.ModuleName('cifar10'), 51 | ) 52 | # pyformat: enable 53 | 54 | [executable] = experiment.package( 55 | [ 56 | xm.Packageable( 57 | executable_spec=spec, 58 | executor_spec=xm_local.Vertex.Spec(), 59 | args={}, 60 | ), 61 | ] 62 | ) 63 | 64 | learning_rates = [0.1, 0.001] 65 | trials = list( 66 | dict([('learning_rate', lr)]) 67 | for (lr,) in itertools.product(learning_rates) 68 | ) 69 | 70 | tensorboard = FLAGS.tensorboard 71 | if not tensorboard: 72 | tensorboard = vertex.get_default_client().get_or_create_tensorboard( 73 | 'cifar10' 74 | ) 75 | tensorboard = asyncio.get_event_loop().run_until_complete(tensorboard) 76 | 77 | for i, hyperparameters in enumerate(trials): 78 | output_dir = os.environ.get('GOOGLE_CLOUD_BUCKET_NAME', None) 79 | if output_dir: 80 | output_dir = os.path.join( 81 | output_dir, str(experiment.experiment_id), str(i) 82 | ) 83 | tensorboard_capability = xm_local.TensorboardCapability( 84 | name=tensorboard, base_output_directory=output_dir 85 | ) 86 | experiment.add( 87 | xm.Job( 88 | executable=executable, 89 | executor=xm_local.Vertex( 90 | requirements=xm.JobRequirements(tpu_v2=8), 91 | tensorboard=tensorboard_capability, 92 | ), 93 | args=hyperparameters, 94 | ) 95 | ) 96 | 97 | 98 | if __name__ == '__main__': 99 | app.run(main) 100 | -------------------------------------------------------------------------------- /examples/cifar10_tensorflow_tpu/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | # tensorflow # installed as part of 
tpuvm 3 | tensorflow-datasets 4 | -------------------------------------------------------------------------------- /examples/cifar10_torch/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | r"""XManager launcher for CIFAR10. 15 | 16 | Usage: 17 | 18 | xmanager launch examples/cifar10_torch/launcher.py -- \ 19 | --xm_wrap_late_bindings [--image_path=gcr.io/path/to/image/tag] 20 | """ 21 | 22 | import itertools 23 | 24 | from absl import app 25 | from absl import flags 26 | from xmanager import xm 27 | from xmanager import xm_local 28 | from xmanager.cloud import utils 29 | 30 | FLAGS = flags.FLAGS 31 | flags.DEFINE_string('image_path', None, 'Image path.') 32 | 33 | flags.DEFINE_integer('nodes', 1, 'Number of nodes.') 34 | flags.DEFINE_integer('gpus_per_node', 2, 'Number of GPUs per node.') 35 | 36 | 37 | @xm.run_in_asyncio_loop 38 | async def main(_): 39 | async with xm_local.create_experiment( 40 | experiment_title='cifar10' 41 | ) as experiment: 42 | if FLAGS.image_path: 43 | spec = xm.Container(image_path=FLAGS.image_path) 44 | else: 45 | spec = xm.PythonContainer( 46 | # Package the current directory that this script is in. 
47 | path='.', 48 | base_image='gcr.io/deeplearning-platform-release/pytorch-gpu.1-12', 49 | entrypoint=xm.ModuleName('cifar10'), 50 | ) 51 | 52 | [executable] = experiment.package( 53 | [ 54 | xm.Packageable( 55 | executable_spec=spec, 56 | executor_spec=xm_local.Vertex.Spec(), 57 | args={ 58 | # TODO: replace workerpool0 with the actual 59 | # name of the job when Vertex AI supports custom name worker 60 | # pools. 61 | 'master_addr_port': xm.ShellSafeArg( 62 | utils.get_workerpool_address('workerpool0') 63 | ), 64 | }, 65 | ), 66 | ] 67 | ) 68 | 69 | batch_sizes = [64, 1024] 70 | learning_rates = [0.1, 0.001] 71 | trials = list( 72 | dict([('batch_size', bs), ('learning_rate', lr)]) 73 | for (bs, lr) in itertools.product(batch_sizes, learning_rates) 74 | ) 75 | 76 | work_units = [] 77 | for hyperparameters in trials: 78 | job_group = xm.JobGroup() 79 | for i in range(FLAGS.nodes): 80 | hyperparameters = dict(hyperparameters) 81 | hyperparameters['world_size'] = FLAGS.nodes 82 | hyperparameters['rank'] = i 83 | job_group.jobs[f'node_{i}'] = xm.Job( 84 | executable=executable, 85 | executor=xm_local.Vertex( 86 | xm.JobRequirements(t4=FLAGS.gpus_per_node) 87 | ), 88 | args=hyperparameters, 89 | ) 90 | work_units.append(await experiment.add(job_group)) 91 | print('Waiting for async launches to return values...') 92 | for work_unit in work_units: 93 | await work_unit.wait_until_complete() 94 | print('Experiment completed.') 95 | 96 | 97 | if __name__ == '__main__': 98 | app.run(main) 99 | -------------------------------------------------------------------------------- /examples/cifar10_torch/requirements.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | absl-py 16 | numpy 17 | torch 18 | torchvision 19 | -------------------------------------------------------------------------------- /examples/cifar10_torch_xla/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | r"""XManager launcher for CIFAR10. 15 | 16 | Usage: 17 | 18 | xmanager launch examples/cifar10_torch/launcher.py -- \ 19 | --xm_wrap_late_bindings \ 20 | [--image_path=gcr.io/path/to/image/tag] \ 21 | [--platform=gpu] 22 | """ 23 | import itertools 24 | 25 | from absl import app 26 | from absl import flags 27 | from xmanager import xm 28 | from xmanager import xm_local 29 | 30 | FLAGS = flags.FLAGS 31 | flags.DEFINE_string('image_path', None, 'Image path.') 32 | flags.DEFINE_string('platform', 'cpu', 'cpu/gpu/tpu.') 33 | flags.DEFINE_integer('cores', 1, 'Number of cores. 
Use 8 if platform==tpu.') 34 | 35 | 36 | def main(_): 37 | with xm_local.create_experiment(experiment_title='cifar10') as experiment: 38 | if FLAGS.image_path: 39 | spec = xm.Container(image_path=FLAGS.image_path) 40 | else: 41 | # Package the current directory that this script is in. 42 | spec = xm.PythonContainer( 43 | path='.', 44 | # This base_image is experimental and works with cpu/gpu/tpu. 45 | # https://cloud.google.com/ai-platform/deep-learning-containers/docs/choosing-container 46 | base_image='gcr.io/deeplearning-platform-release/pytorch-xla.1-8', 47 | entrypoint=xm.ModuleName('cifar10'), 48 | ) 49 | 50 | [executable] = experiment.package( 51 | [ 52 | xm.Packageable( 53 | executable_spec=spec, 54 | executor_spec=xm_local.Vertex.Spec(), 55 | args={'platform': FLAGS.platform}, 56 | ), 57 | ] 58 | ) 59 | 60 | batch_sizes = [64, 1024] 61 | learning_rates = [0.1, 0.001] 62 | trials = list( 63 | dict([('batch_size', bs), ('learning_rate', lr)]) 64 | for (bs, lr) in itertools.product(batch_sizes, learning_rates) 65 | ) 66 | 67 | requirements = xm.JobRequirements() 68 | if FLAGS.platform == 'gpu': 69 | requirements = xm.JobRequirements(t4=FLAGS.cores) 70 | elif FLAGS.platform == 'tpu': 71 | requirements = xm.JobRequirements(tpu_v3=8) 72 | for hyperparameters in trials: 73 | jobs = {} 74 | jobs['coordinator'] = xm.Job( 75 | executable=executable, 76 | executor=xm_local.Vertex(requirements), 77 | args=hyperparameters, 78 | ) 79 | experiment.add(xm.JobGroup(**jobs)) 80 | break 81 | print('Waiting for async launches to return values...') 82 | print('Launch completed and successful.') 83 | 84 | 85 | if __name__ == '__main__': 86 | app.run(main) 87 | -------------------------------------------------------------------------------- /examples/cifar10_torch_xla/requirements.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 
"License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | absl-py 16 | numpy 17 | tensorflow 18 | torch 19 | torchvision 20 | -------------------------------------------------------------------------------- /examples/dockerfile/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM gcc:latest 2 | 3 | COPY arg_printer.cc arg_printer.cc 4 | RUN g++ -o arg_printer arg_printer.cc 5 | 6 | ENTRYPOINT ["./arg_printer"] 7 | -------------------------------------------------------------------------------- /examples/dockerfile/arg_printer.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021 DeepMind Technologies Limited 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include 16 | #include 17 | 18 | int main(int argc, char* argv[]) { 19 | std::cout << "Hello World!" 
<< std::endl; 20 | 21 | std::cout << std::getenv("FOO") << std::endl; 22 | for (int i = 0; i < argc; ++i) { 23 | std::cout << argv[i] << std::endl; 24 | } 25 | return 0; 26 | } 27 | -------------------------------------------------------------------------------- /examples/dockerfile/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """XManager launcher that runs an image built from a Dockerfile.""" 15 | 16 | from typing import Sequence 17 | 18 | from absl import app 19 | from xmanager import xm 20 | from xmanager import xm_local 21 | 22 | 23 | def main(argv: Sequence[str]) -> None: 24 | del argv 25 | 26 | with xm_local.create_experiment( 27 | experiment_title='Example using Dockerfile()' 28 | ) as experiment: 29 | executable_spec = xm.Dockerfile() 30 | [executable] = experiment.package( 31 | [ 32 | xm.Packageable( 33 | executable_spec=executable_spec, 34 | executor_spec=xm_local.Vertex.Spec(), 35 | ), 36 | ] 37 | ) 38 | experiment.add( 39 | xm.Job( 40 | executable=executable, 41 | executor=xm_local.Vertex(), 42 | env_vars={'FOO': 'bar'}, 43 | args=['--a=1', '--b=2', '--c=3', '--d=4'], 44 | ) 45 | ) 46 | 47 | 48 | if __name__ == '__main__': 49 | app.run(main) 50 | -------------------------------------------------------------------------------- /examples/dopamine/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | r"""Launcher for Dopamine. 
15 | 16 | Usage: 17 | xmanager launch examples/dopaminelauncher.py -- \ 18 | --gin_file=https://raw.githubusercontent.com/google/dopamine/master/dopamine/agents/dqn/configs/dqn_mountaincar.gin 19 | """ 20 | import asyncio 21 | import os 22 | 23 | from absl import app 24 | from absl import flags 25 | from xmanager import xm 26 | from xmanager import xm_local 27 | from xmanager.cloud import vertex 28 | 29 | FLAGS = flags.FLAGS 30 | flags.DEFINE_string( 31 | 'gin_file', 32 | 'https://raw.githubusercontent.com/google/dopamine/master/dopamine/agents/dqn/configs/dqn_mountaincar.gin', 33 | 'Gin file pulled from https://github.com/google/dopamine.', 34 | ) 35 | flags.DEFINE_string('tensorboard', None, 'Tensorboard instance.') 36 | 37 | 38 | def main(_): 39 | with xm_local.create_experiment(experiment_title='dopamine') as experiment: 40 | gin_file = os.path.basename(FLAGS.gin_file) 41 | add_instruction = f'ADD {FLAGS.gin_file} {gin_file}' 42 | if FLAGS.gin_file.startswith('http'): 43 | add_instruction = f'RUN wget -O ./{gin_file} {FLAGS.gin_file}' 44 | spec = xm.PythonContainer( 45 | docker_instructions=[ 46 | 'RUN apt update && apt install -y python3-opencv', 47 | 'RUN pip install dopamine-rl', 48 | 'COPY dopamine/ workdir', 49 | 'WORKDIR workdir', 50 | add_instruction, 51 | ], 52 | entrypoint=xm.ModuleName('train'), 53 | ) 54 | 55 | [executable] = experiment.package( 56 | [ 57 | xm.Packageable( 58 | executable_spec=spec, 59 | executor_spec=xm_local.Vertex.Spec(), 60 | args={ 61 | 'gin_files': gin_file, 62 | }, 63 | ), 64 | ] 65 | ) 66 | 67 | tensorboard = FLAGS.tensorboard 68 | if not tensorboard: 69 | tensorboard = vertex.get_default_client().get_or_create_tensorboard( 70 | 'cifar10' 71 | ) 72 | tensorboard = asyncio.get_event_loop().run_until_complete(tensorboard) 73 | output_dir = os.environ['GOOGLE_CLOUD_BUCKET_NAME'] 74 | output_dir = os.path.join(output_dir, str(experiment.experiment_id)) 75 | tensorboard_capability = xm_local.TensorboardCapability( 76 | 
name=tensorboard, base_output_directory=output_dir 77 | ) 78 | experiment.add( 79 | xm.Job( 80 | executable=executable, 81 | executor=xm_local.Vertex( 82 | xm.JobRequirements(t4=1), tensorboard=tensorboard_capability 83 | ), 84 | ) 85 | ) 86 | 87 | 88 | if __name__ == '__main__': 89 | app.run(main) 90 | -------------------------------------------------------------------------------- /examples/dopamine/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """The entry point for running a Dopamine agent.""" 15 | import os 16 | 17 | from absl import app 18 | from absl import flags 19 | from absl import logging 20 | 21 | from dopamine.discrete_domains import run_experiment 22 | import tensorflow as tf 23 | 24 | FLAGS = flags.FLAGS 25 | flags.DEFINE_multi_string( 26 | 'gin_files', 27 | [], 28 | ( 29 | 'List of paths to gin configuration files (e.g.' 30 | '"dopamine/agents/dqn/dqn.gin").' 31 | ), 32 | ) 33 | 34 | # When using Vertex Tensorboard, the tensorboard will be present as a 35 | # environment variable. 
36 | BASE_DIR = os.environ.get('AIP_TENSORBOARD_LOG_DIR', '/tmp/dopamine_runs') 37 | 38 | 39 | def main(unused_argv): 40 | logging.set_verbosity(logging.INFO) 41 | tf.compat.v1.disable_v2_behavior() 42 | run_experiment.load_gin_configs(FLAGS.gin_files, []) 43 | runner = run_experiment.create_runner(BASE_DIR) 44 | runner.run_experiment() 45 | 46 | 47 | if __name__ == '__main__': 48 | app.run(main) 49 | -------------------------------------------------------------------------------- /examples/local_arg_printer/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | licenses(["notice"]) 16 | 17 | load("@rules_cc//cc:defs.bzl", "cc_binary") 18 | 19 | cc_binary( 20 | name = "arg_printer", 21 | srcs = ["arg_printer.cc"], 22 | ) 23 | -------------------------------------------------------------------------------- /examples/local_arg_printer/WORKSPACE: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /examples/local_arg_printer/arg_printer.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021 DeepMind Technologies Limited 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | int main(int argc, char* argv[]) { 21 | std::ofstream ouf(std::getenv("OUTPUT_PATH")); 22 | ouf << argc << std::endl; 23 | for (int i = 0; i < argc; ++i) { 24 | ouf << argv[i] << std::endl; 25 | } 26 | // Sleep to demonstrate local jobs waiting. 
27 | std::this_thread::sleep_for(std::chrono::seconds(5)); 28 | return 0; 29 | } 30 | -------------------------------------------------------------------------------- /examples/local_arg_printer/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """XManager launcher that runs locally a binary built with Bazel. 15 | 16 | One must `cd` into xmanager/examples/local_arg_printer/ in order to run this 17 | example because Bazel needs to locate the WORKSPACE file. 
18 | """ 19 | 20 | from typing import Sequence 21 | 22 | from absl import app 23 | from xmanager import xm 24 | from xmanager import xm_local 25 | 26 | 27 | def main(argv: Sequence[str]) -> None: 28 | del argv 29 | 30 | with xm_local.create_experiment( 31 | experiment_title='local_arg_printer' 32 | ) as experiment: 33 | [executable] = experiment.package( 34 | [ 35 | xm.Packageable( 36 | executable_spec=xm.BazelBinary( 37 | label='//:arg_printer' 38 | ), 39 | executor_spec=xm_local.Local.Spec(), 40 | ), 41 | ] 42 | ) 43 | experiment.add( 44 | xm.Job( 45 | executable=executable, 46 | executor=xm_local.Local(), 47 | env_vars={'OUTPUT_PATH': '/tmp/local_arg_printer.txt'}, 48 | ) 49 | ) 50 | 51 | 52 | if __name__ == '__main__': 53 | app.run(main) 54 | -------------------------------------------------------------------------------- /examples/local_container_gpu/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Tensorflow Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Code based on https://www.tensorflow.org/tutorials/images/cnn.""" 15 | import os 16 | 17 | from absl import app 18 | from absl import flags 19 | import tensorflow as tf 20 | from tensorflow.keras import datasets 21 | from tensorflow.keras import layers 22 | from tensorflow.keras import models 23 | 24 | # When using Vertex Tensorboard, the tensorboard will be present as a 25 | # environment variable. 26 | LOG_DIR = os.environ.get('AIP_TENSORBOARD_LOG_DIR', '') 27 | 28 | FLAGS = flags.FLAGS 29 | flags.DEFINE_integer('epochs', 5, 'epochs') 30 | flags.DEFINE_float('learning_rate', 0.001, 'learning rate') 31 | 32 | 33 | def main(_): 34 | (train_images, train_labels), (test_images, test_labels) = ( 35 | datasets.cifar10.load_data() 36 | ) 37 | 38 | # Normalize pixel values to be between 0 and 1 39 | train_images, test_images = train_images / 255.0, test_images / 255.0 40 | 41 | strategy = tf.distribute.MirroredStrategy() 42 | with strategy.scope(): 43 | model = models.Sequential() 44 | model.add( 45 | layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)) 46 | ) 47 | model.add(layers.MaxPooling2D((2, 2))) 48 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 49 | model.add(layers.MaxPooling2D((2, 2))) 50 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 51 | 52 | model.add(layers.Flatten()) 53 | model.add(layers.Dense(64, activation='relu')) 54 | model.add(layers.Dense(10)) 55 | 56 | optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate) 57 | model.compile( 58 | optimizer=optimizer, 59 | loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 60 | metrics=['accuracy'], 61 | ) 62 | 63 | callbacks = [] 64 | if LOG_DIR: 65 | callbacks = [ 66 | tf.keras.callbacks.TensorBoard( 67 | log_dir=LOG_DIR, 68 | histogram_freq=1, 69 | ), 70 | ] 71 | 72 | model.fit( 73 | train_images, 74 | train_labels, 75 | epochs=FLAGS.epochs, 76 | validation_data=(test_images, test_labels), 77 | callbacks=callbacks, 78 | verbose=2, 
79 | ) 80 | 81 | 82 | if __name__ == '__main__': 83 | app.run(main) 84 | -------------------------------------------------------------------------------- /examples/local_container_gpu/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | r"""XManager local launcher for CIFAR10 using GPUs. 15 | 16 | Usage: 17 | 18 | xmanager launch examples/local_container_gpu/launcher.py -- \ 19 | --xm_wrap_late_bindings 20 | """ 21 | 22 | from absl import app 23 | from absl import flags 24 | from xmanager import xm 25 | from xmanager import xm_local 26 | 27 | _EXP_NAME = flags.DEFINE_string( 28 | 'exp_name', 'local-cifar10-gpu', 'Name of the experiment.', short_name='n' 29 | ) 30 | _INTERACTIVE = flags.DEFINE_bool( 31 | 'interactive', 32 | False, 33 | 'Launch the container and allow interactive access to it.', 34 | ) 35 | 36 | 37 | def main(argv) -> None: 38 | if len(argv) > 1: 39 | raise app.UsageError('Too many command-line arguments.') 40 | 41 | create_experiment = xm_local.create_experiment 42 | with create_experiment(experiment_title=_EXP_NAME.value) as experiment: 43 | docker_options = xm_local.DockerOptions(interactive=_INTERACTIVE.value) 44 | # Creating local executor with extra flag to track job's progress. 
45 | executor = xm_local.Local( 46 | xm.JobRequirements(local_gpu=2), 47 | experimental_stream_output=True, 48 | docker_options=docker_options, 49 | ) 50 | 51 | # Empty args means nothing is passed into the job. 52 | executable_args = {} 53 | (executable,) = experiment.package( 54 | [ 55 | xm.python_container( 56 | executor_spec=executor.Spec(), 57 | args=executable_args, 58 | # Package the current directory that this script is in. 59 | path='.', 60 | base_image='gcr.io/deeplearning-platform-release/tf2-gpu.2-6', 61 | entrypoint=xm.ModuleName('local_container_gpu.cifar10'), 62 | use_deep_module=True, 63 | ) 64 | ] 65 | ) 66 | job = xm.Job(executable, executor) 67 | 68 | experiment.add(job) 69 | 70 | 71 | if __name__ == '__main__': 72 | app.run(main) 73 | -------------------------------------------------------------------------------- /examples/local_container_gpu/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | tensorflow 3 | tensorflow-datasets 4 | -------------------------------------------------------------------------------- /examples/local_container_links/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | load("@io_bazel_rules_docker//container:container.bzl", "container_image") 16 | load("@python_deps//:requirements.bzl", "requirement") 17 | load("@subpar//:subpar.bzl", "par_binary") 18 | 19 | licenses(["notice"]) 20 | 21 | par_binary( 22 | name = "server", 23 | srcs = ["server.py"], 24 | deps = [ 25 | requirement("absl-py"), 26 | requirement("bottle"), 27 | requirement("bottle-redis"), 28 | ], 29 | ) 30 | 31 | container_image( 32 | name = "server_image", 33 | base = "@io_docker_index_library_python//image", 34 | entrypoint = ["/server.par"], 35 | files = [":server.par"], 36 | ) 37 | -------------------------------------------------------------------------------- /examples/local_container_links/README.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | launcher.py packages and executes (locally) two Docker containers: 4 | [Redis](https://hub.docker.com/_/redis), a key-value database, and a 5 | [Bottle](https://bottlepy.org/)-based HTTP server that listens on port 8080 and 6 | responds to requests at `/increment` by running 7 | [`INCR counter`](https://redis.io/commands/INCR) on Redis and sending back the 8 | result. 9 | 10 | ## Instructions 11 | 12 | 1. Clone the repository and install XManager via `pip install ./xmanager`. 13 | 2. Install Bazelisk via `npm install -g @bazel/bazelisk`. 14 | 3. Get into this directory and run `xmanager launch launcher.py -- --xm_bazel_command=bazelisk`. 15 | 4. Once it is idling, open `http://localhost:8080/increment` in the browser. 16 | -------------------------------------------------------------------------------- /examples/local_container_links/WORKSPACE: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") 16 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 17 | 18 | http_archive( 19 | name = "io_bazel_rules_docker", 20 | sha256 = "59d5b42ac315e7eadffa944e86e90c2990110a1c8075f1cd145f487e999d22b3", 21 | strip_prefix = "rules_docker-0.17.0", 22 | urls = ["https://github.com/bazelbuild/rules_docker/releases/download/v0.17.0/rules_docker-v0.17.0.tar.gz"], 23 | ) 24 | 25 | http_archive( 26 | name = "rules_python", 27 | sha256 = "934c9ceb552e84577b0faf1e5a2f0450314985b4d8712b2b70717dc679fdc01b", 28 | url = "https://github.com/bazelbuild/rules_python/releases/download/0.3.0/rules_python-0.3.0.tar.gz", 29 | ) 30 | 31 | git_repository( 32 | name = "subpar", 33 | remote = "https://github.com/google/subpar", 34 | tag = "2.0.0", 35 | ) 36 | 37 | load("@rules_python//python:pip.bzl", "pip_install") 38 | 39 | pip_install( 40 | name = "python_deps", 41 | requirements = "//:requirements.txt", 42 | ) 43 | 44 | load("@io_bazel_rules_docker//repositories:repositories.bzl", container_repositories = "repositories") 45 | 46 | container_repositories() 47 | 48 | load("@io_bazel_rules_docker//repositories:deps.bzl", container_deps = "deps") 49 | 50 | container_deps() 51 | 52 | load("@io_bazel_rules_docker//container:container.bzl", "container_pull") 53 | 54 | container_pull( 55 | name = "io_docker_index_library_python", 56 | digest = "sha256:11f3ccfbb8e809246f5993be99693966ffc99e1c7b632251fde27c0ce45b35f2", 57 | registry = "index.docker.io", 58 | 
repository = "library/python", 59 | tag = "3.9.6", 60 | ) 61 | -------------------------------------------------------------------------------- /examples/local_container_links/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """A launcher for server.py and Redis. 15 | 16 | See README.md for details. 17 | """ 18 | 19 | from typing import Sequence 20 | 21 | from absl import app 22 | from xmanager import xm 23 | from xmanager import xm_local 24 | 25 | 26 | def main(argv: Sequence[str]) -> None: 27 | del argv # Unused. 
28 | 29 | with xm_local.create_experiment( 30 | experiment_title='local_container_links' 31 | ) as experiment: 32 | [redis, server] = experiment.package([ 33 | xm.Packageable( 34 | executable_spec=xm.Container(image_path='redis'), 35 | executor_spec=xm_local.Local.Spec(), 36 | ), 37 | xm.Packageable( 38 | executable_spec=xm.BazelContainer(label='//:server_image.tar'), 39 | executor_spec=xm_local.Local.Spec(), 40 | ), 41 | ]) 42 | 43 | async def generator(work_unit): 44 | work_unit.add( 45 | xm.JobGroup( 46 | server=xm.Job( 47 | executable=server, 48 | executor=xm_local.Local( 49 | docker_options=xm_local.DockerOptions(ports={8080: 8080}) 50 | ), 51 | args={'redis_host': work_unit.get_full_job_name('redis')}, 52 | ), 53 | redis=xm.Job( 54 | name='redis', 55 | executable=redis, 56 | executor=xm_local.Local( 57 | docker_options=xm_local.DockerOptions( 58 | volumes={'/tmp/redis': '/data'} 59 | ) 60 | ), 61 | ), 62 | ) 63 | ) 64 | 65 | experiment.add(generator) 66 | 67 | 68 | if __name__ == '__main__': 69 | app.run(main) 70 | -------------------------------------------------------------------------------- /examples/local_container_links/requirements.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | absl-py 16 | bottle 17 | bottle-redis 18 | -------------------------------------------------------------------------------- /examples/local_container_links/server.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """An HTTP server incrementing a value in Redis.""" 15 | 16 | from typing import Sequence 17 | 18 | from absl import app 19 | from absl import flags 20 | import bottle 21 | import bottle.ext.redis 22 | 23 | redis_host = flags.DEFINE_string('redis_host', None, "Redis' host.") 24 | 25 | server = bottle.Bottle() 26 | 27 | 28 | @server.route('/increment') 29 | def increment(rdb): 30 | return str(rdb.incr('counter')) 31 | 32 | 33 | def main(argv: Sequence[str]) -> None: 34 | del argv # Unused. 35 | 36 | server.install(bottle.ext.redis.RedisPlugin(host=redis_host.value)) 37 | bottle.run(server, host='0.0.0.0', port=8080, debug=True) 38 | 39 | 40 | if __name__ == '__main__': 41 | app.run(main) 42 | -------------------------------------------------------------------------------- /examples/parameter_controller/db_config.yaml: -------------------------------------------------------------------------------- 1 | # Connector used - one of ['cloudsql', 'sqlite', 'generic'] 2 | sql_connector: '' 3 | 4 | sql_connection_settings: 5 | # Backend used, e.g. 
'mysql', 'postgresql' 6 | backend: '' 7 | # Username 8 | username: '' 9 | # Password 10 | password: '' 11 | # Host (or instance connection name for CloudSql) 12 | host: '' 13 | # Port 14 | port: 0 15 | # Database name 16 | db_name: '' -------------------------------------------------------------------------------- /examples/parameter_controller/inner_job/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py -------------------------------------------------------------------------------- /examples/parameter_controller/inner_job/wait_job.py: -------------------------------------------------------------------------------- 1 | """Job waiting.""" 2 | import time 3 | 4 | from absl import app 5 | from absl import flags 6 | 7 | _TIME_TO_SLEEP = flags.DEFINE_integer('time_to_sleep', 10, 'Time to sleep.') 8 | 9 | 10 | def main(_): 11 | print(f'Hello, waiting for {_TIME_TO_SLEEP.value}...') 12 | time.sleep(_TIME_TO_SLEEP.value) 13 | print('Done!') 14 | 15 | 16 | if __name__ == '__main__': 17 | app.run(main) 18 | -------------------------------------------------------------------------------- /examples/parameter_controller/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | r"""XManager launcher for a parameter controller example. 
15 | 16 | The given program launches a dummy job on Kubernetes, waits for its completion 17 | and then launches another job on VertexAI. This kind of workflow can be used 18 | to define pipelines. 19 | 20 | Usage: 21 | 22 | 23 | xmanager launch examples/parameter_controller/launcher.py -- \ 24 | --xm_db_yaml_config=db_config.yaml \ 25 | [--xm_k8s_service_account_name=...] \ 26 | [--xm_gcp_service_account_name=...] 27 | 28 | The content of `db_config.yaml` must be updated to match the connection details 29 | to the DB used. 30 | """ 31 | import os 32 | 33 | from absl import app 34 | from absl import flags 35 | from xmanager import xm 36 | from xmanager import xm_local 37 | from xmanager.contrib import parameter_controller 38 | 39 | 40 | def main(_): 41 | with xm_local.create_experiment(experiment_title='cifar10') as experiment: 42 | 43 | @parameter_controller.controller( 44 | executor=xm_local.Kubernetes(), 45 | controller_args={ 46 | 'xm_k8s_service_account_name': ( 47 | flags.FLAGS.xm_k8s_service_account_name 48 | ), 49 | 'xm_gcp_service_account_name': ( 50 | flags.FLAGS.xm_gcp_service_account_name 51 | ), 52 | }, 53 | controller_env_vars={ 54 | 'GOOGLE_CLOUD_BUCKET_NAME': os.environ['GOOGLE_CLOUD_BUCKET_NAME'], 55 | }, 56 | # Package contents of this directory inside parameter controller job 57 | package_path='.', 58 | ) 59 | async def parameter_controller_example(experiment: xm.Experiment): 60 | spec = xm.PythonContainer( 61 | # Package contents of job to be launched 62 | path='inner_job', 63 | base_image='python:3.9', 64 | entrypoint=xm.ModuleName('wait_job'), 65 | ) 66 | [vertex_executable, k8s_executable] = experiment.package([ 67 | xm.Packageable( 68 | executable_spec=spec, 69 | executor_spec=xm_local.Vertex.Spec(), 70 | args={'time_to_sleep': 10}, 71 | ), 72 | xm.Packageable( 73 | executable_spec=spec, 74 | executor_spec=xm_local.Kubernetes.Spec(), 75 | args={'time_to_sleep': 20}, 76 | ), 77 | ]) 78 | 79 | wu1 = await experiment.add(xm.Job(k8s_executable, 
xm_local.Kubernetes())) 80 | await wu1.wait_until_complete() 81 | 82 | wu2 = await experiment.add(xm.Job(vertex_executable, xm_local.Vertex())) 83 | await wu2.wait_until_complete() 84 | 85 | experiment.add(parameter_controller_example()) # pylint: disable=no-value-for-parameter 86 | 87 | 88 | if __name__ == '__main__': 89 | app.run(main) 90 | -------------------------------------------------------------------------------- /examples/parameter_controller/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | six 3 | sqlparse 4 | dm-launchpad[tensorflow] 5 | numpy 6 | cloudpickle 7 | cloud-sql-python-connector[pymysql] -------------------------------------------------------------------------------- /examples/vizier/README.md: -------------------------------------------------------------------------------- 1 | This is a launcher to express the use case for Vertex Vizier: 2 | 3 | https://cloud.google.com/vertex-ai/docs/vizier 4 | 5 | 6 | This is a dummy project for using hyperparameter optimization. 7 | 8 | In this project, we define a bivariate polynomial: 9 | 10 | ax^2 + by^2 + cxy + dx + ey + f 11 | 12 | with fixed args: a,b,c,d,e,f 13 | 14 | and hyper parameters: x,y 15 | -------------------------------------------------------------------------------- /examples/vizier/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | r"""XManager launcher for Polynomial. 15 | 16 | Usage: 17 | 18 | xmanager launch examples/vizier/launcher.py -- \ 19 | --xm_wrap_late_bindings 20 | """ 21 | 22 | from absl import app 23 | 24 | from google.cloud import aiplatform_v1beta1 as aip 25 | 26 | from xmanager import xm 27 | from xmanager import xm_local 28 | from xmanager.vizier import vizier_cloud 29 | 30 | 31 | def get_study_spec() -> aip.StudySpec: 32 | return aip.StudySpec( 33 | algorithm=aip.StudySpec.Algorithm.RANDOM_SEARCH, 34 | parameters=[ 35 | aip.StudySpec.ParameterSpec( 36 | parameter_id='x', 37 | double_value_spec=aip.StudySpec.ParameterSpec.DoubleValueSpec( 38 | min_value=-2.0, max_value=2.0 39 | ), 40 | ), 41 | aip.StudySpec.ParameterSpec( 42 | parameter_id='y', 43 | double_value_spec=aip.StudySpec.ParameterSpec.DoubleValueSpec( 44 | min_value=-2.0, max_value=2.0 45 | ), 46 | ), 47 | ], 48 | metrics=[ 49 | aip.StudySpec.MetricSpec( 50 | metric_id='loss', goal=aip.StudySpec.MetricSpec.GoalType.MINIMIZE 51 | ) 52 | ], 53 | ) 54 | 55 | 56 | def main(_): 57 | with xm_local.create_experiment(experiment_title='polynomial') as experiment: 58 | spec = xm.PythonContainer( 59 | # Package the current directory that this script is in. 
60 | path='.', 61 | base_image='gcr.io/deeplearning-platform-release/base-cpu', 62 | entrypoint=xm.ModuleName('polynomial'), 63 | ) 64 | 65 | [executable] = experiment.package( 66 | [ 67 | xm.Packageable( 68 | executable_spec=spec, 69 | executor_spec=xm_local.Vertex.Spec(), 70 | ), 71 | ] 72 | ) 73 | 74 | vizier_cloud.VizierExploration( 75 | experiment=experiment, 76 | job=xm.Job( 77 | executable=executable, 78 | executor=xm_local.Vertex(), 79 | ), 80 | study_factory=vizier_cloud.NewStudy(study_config=get_study_spec()), 81 | num_trials_total=3, 82 | num_parallel_trial_runs=2, 83 | ).launch() 84 | 85 | 86 | if __name__ == '__main__': 87 | app.run(main) 88 | -------------------------------------------------------------------------------- /examples/vizier/polynomial.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Compute the value of a bivariate quadratic polynomial. 15 | 16 | An example of finding the min values of x and y with auto hyperparam tuning. 17 | Set a, b, c, d, e, and f to constant args. Set x and y to be hyperparameters. 
18 | """ 19 | 20 | from absl import app 21 | from absl import flags 22 | 23 | from vizier_worker import VizierWorker 24 | 25 | FLAGS = flags.FLAGS 26 | flags.DEFINE_integer('a', 1, 'a in ax^2 + by^2 + cxy + dx + ey + f') 27 | flags.DEFINE_integer('b', 1, 'b in ax^2 + by^2 + cxy + dx + ey + f') 28 | flags.DEFINE_integer('c', 0, 'c in ax^2 + by^2 + cxy + dx + ey + f') 29 | flags.DEFINE_integer('d', 1, 'd in ax^2 + by^2 + cxy + dx + ey + f') 30 | flags.DEFINE_integer('e', 1, 'e in ax^2 + by^2 + cxy + dx + ey + f') 31 | flags.DEFINE_integer('f', 1, 'f in ax^2 + by^2 + cxy + dx + ey + f') 32 | 33 | flags.DEFINE_float('x', 0, 'The hyperparameter variable X.') 34 | flags.DEFINE_float('y', 0, 'The hyperparameter variable Y.') 35 | 36 | flags.DEFINE_string( 37 | 'trial_name', 38 | None, 39 | ( 40 | 'Identifying the current job trial that measurements ' 41 | 'will be submitted to with `add_trial_measurement`. Format: ' 42 | 'projects/{project}/locations/{location}/studies/{study}/trials/{trial}' 43 | ), 44 | ) 45 | 46 | 47 | def main(_): 48 | worker = VizierWorker(FLAGS.trial_name) 49 | 50 | # dummy training loop: "train" for one epoch 51 | metric_value = float( 52 | FLAGS.a * FLAGS.x * FLAGS.x 53 | + FLAGS.b * FLAGS.y * FLAGS.y 54 | + FLAGS.c * FLAGS.x * FLAGS.y 55 | + FLAGS.d * FLAGS.x 56 | + FLAGS.e * FLAGS.y 57 | + FLAGS.f 58 | ) 59 | 60 | worker.add_trial_measurement(1, {'loss': metric_value}) 61 | 62 | 63 | if __name__ == '__main__': 64 | app.run(main) 65 | -------------------------------------------------------------------------------- /examples/vizier/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | google-cloud-aiplatform 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 
2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Setup configuration specifying XManager dependencies.""" 15 | 16 | from setuptools import find_namespace_packages 17 | from setuptools import setup 18 | 19 | with open('README.md', 'r', encoding='utf-8') as fh: 20 | long_description = fh.read() 21 | 22 | setup( 23 | name='xmanager', 24 | version='0.7.1', 25 | description='A framework for managing machine learning experiments', 26 | long_description=long_description, 27 | long_description_content_type='text/markdown', 28 | author='DeepMind Technologies Limited', 29 | packages=find_namespace_packages(exclude=['examples.*']), 30 | package_data={'': ['*.sh', '*.sql', '*.ini', '*.mako']}, 31 | python_requires='>=3.10', 32 | install_requires=[ 33 | 'absl-py', 34 | 'alembic==1.4.3', 35 | 'async_generator', 36 | 'attrs', 37 | 'cloud-sql-python-connector', 38 | 'docker', 39 | 'etils[epath]', 40 | 'google-api-core', 41 | 'google-api-python-client', 42 | 'google-auth', 43 | 'google-cloud-aiplatform', 44 | 'google-cloud-storage', 45 | 'humanize', 46 | 'immutabledict', 47 | 'kubernetes', 48 | 'pyyaml', 49 | 'sqlalchemy==1.2.19', 50 | 'sqlparse', 51 | 'termcolor', 52 | ], 53 | entry_points={ 54 | 'console_scripts': [ 55 | 'xmanager = xmanager.cli.cli:entrypoint', 56 | ], 57 | }, 58 | # https://github.com/pypa/warehouse/blob/de4a2e5e2ec26d01bf7813da427ebc4725dccde9/warehouse/templates/packaging/detail.html#L20-L60 59 | project_urls={ 60 | 'Homepage': 'https://github.com/deepmind/xmanager', 61 
| 'Issue tracker': 'https://github.com/deepmind/xmanager/issues', 62 | }, 63 | ) 64 | -------------------------------------------------------------------------------- /setup_scripts/install_bazel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # 3 | # Installs the latest version of Bazel. 4 | 5 | # Copyright 2021 DeepMind Technologies Limited 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | sudo apt-get install -y npm 20 | [ -e /usr/bin/node ] || sudo ln -s /usr/bin/nodejs /usr/bin/node 21 | sudo npm install -g @bazel/bazelisk 22 | export PATH="$PATH:$(npm prefix -g)/bin" 23 | bazel --version -------------------------------------------------------------------------------- /setup_scripts/install_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # 3 | # Installs the latest version of Docker. 4 | 5 | # Copyright 2021 DeepMind Technologies Limited 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | sudo apt update 20 | sudo apt-get install -y curl 21 | curl -fsSL https://get.docker.com -o get-docker.sh 22 | sudo sh get-docker.sh -------------------------------------------------------------------------------- /setup_scripts/install_gcloud.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # 3 | # Installs the Google Cloud SDK. 4 | 5 | # Copyright 2021 DeepMind Technologies Limited 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | sudo apt-get update 20 | echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" \ 21 | | sudo tee -a /etc/apt/sources.list.d/google-cloud-sdk.list 22 | sudo apt-get install -y apt-transport-https ca-certificates gnupg curl 23 | curl https://packages.cloud.google.com/apt/doc/apt-key.gpg \ 24 | | sudo apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - 25 | sudo apt-get update 26 | sudo apt-get -y install google-cloud-sdk -------------------------------------------------------------------------------- /setup_scripts/install_python.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # 3 | # Installs Python 3.10, the minimum version required by XManager's setup.py.
4 | 5 | # Copyright 2021 DeepMind Technologies Limited 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | sudo apt-get update 20 | sudo apt-get install -y software-properties-common 21 | sudo add-apt-repository -y ppa:deadsnakes/ppa 22 | sudo apt-get update 23 | sudo apt-get install -y python3.10 24 | sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 25 | sudo apt-get install -y python3-pip python3-distutils 26 | python3 -m pip install --user --upgrade pip -------------------------------------------------------------------------------- /setup_scripts/install_xmanager.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # 3 | # Installs XManager. 4 | 5 | # Copyright 2021 DeepMind Technologies Limited 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License.
18 | 19 | sudo apt-get update 20 | sudo apt-get install -y git 21 | 22 | python3 -m pip install --user --upgrade setuptools 23 | python3 -m pip install --user git+https://github.com/deepmind/xmanager 24 | 25 | if [ -n "${BASH_VERSION}" ]; then 26 | echo "Adding ~/.local/bin to PATH in ~/.bashrc..." 27 | # TODO: Use sed to search and replace instead of simply appending 28 | echo 'export PATH="${PATH}:${HOME}/.local/bin"' >> ~/.bashrc 29 | export PATH="${PATH}:${HOME}/.local/bin" 30 | else 31 | echo "Add ~/.local/bin to your PATH to directly launch xmanager from the shell." 32 | fi -------------------------------------------------------------------------------- /setup_scripts/setup_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # 3 | # Installs XManager and all required dependencies. 4 | # Creates basic GCP configuration based on user preferences. 5 | 6 | # Copyright 2021 DeepMind Technologies Limited 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | chmod +x *.sh 21 | 22 | ./install_bazel.sh && 23 | ./install_docker.sh && 24 | ./install_gcloud.sh && 25 | ./install_python.sh && 26 | . ./install_xmanager.sh && 27 | .
./setup_gcp.sh -------------------------------------------------------------------------------- /xmanager/bazel/client.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """A module for communicating with the Bazel server.""" 15 | 16 | import abc 17 | from typing import List, Sequence, Tuple 18 | 19 | import attr 20 | 21 | 22 | class BazelService(abc.ABC): 23 | """An interface for Bazel operations.""" 24 | 25 | @abc.abstractmethod 26 | def fetch_kinds(self, labels: Sequence[str]) -> List[str]: 27 | """Fetches kinds of given targets. 28 | 29 | See https://docs.bazel.build/versions/main/query.html#output-label_kind. 30 | 31 | Args: 32 | labels: Labels of the targets to query. 33 | 34 | Returns: 35 | A list of kinds, for example, `['py_binary rule']`. 36 | """ 37 | raise NotImplementedError 38 | 39 | @abc.abstractmethod 40 | def build_targets( 41 | self, labels: Sequence[str], tail_args: Sequence[str] 42 | ) -> List[List[str]]: 43 | """Builds given targets and returns paths to their important outputs. 44 | 45 | Args: 46 | labels: Labels of the targets to build. 47 | tail_args: Arguments to append to the Bazel command. 48 | 49 | Returns: 50 | For each label returns a list of its important outputs. 
51 | """ 52 | raise NotImplementedError 53 | 54 | 55 | def _to_tuple(sequence: Sequence[str]) -> Tuple[str, ...]: 56 | """A standalone function to satisfy PyType.""" 57 | return tuple(sequence) 58 | 59 | 60 | @attr.s(auto_attribs=True, frozen=True) 61 | class BazelTarget: 62 | """A Bazel target to be built.""" 63 | 64 | label: str 65 | bazel_args: Tuple[str, ...] = attr.ib(converter=_to_tuple) 66 | -------------------------------------------------------------------------------- /xmanager/bazel/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """A collection of tools for files operations.""" 15 | 16 | import os 17 | import tempfile 18 | 19 | 20 | class TemporaryFilePath: 21 | """A context manager providing a temporary file path. 22 | 23 | Unlike NamedTemporaryFile, TemporaryFilePath closes the file when one enters 24 | the context. 
25 | """ 26 | 27 | _path: str 28 | 29 | def __enter__(self): 30 | fd, path = tempfile.mkstemp() 31 | os.close(fd) 32 | self._path = path 33 | return path 34 | 35 | def __exit__(self, error_type, error_value, traceback): 36 | os.remove(self._path) 37 | -------------------------------------------------------------------------------- /xmanager/cli/README.md: -------------------------------------------------------------------------------- 1 | # XManager CLI 2 | 3 | This directory contains the command-line interface for XManager. 4 | 5 | ## `launch` 6 | 7 | Runs a given launch script. 8 | 9 | ``` 10 | xmanager launch path/to/launch/script.py 11 | ``` 12 | 13 | ## GKE 14 | 15 | In order to use the Kubernetes executor, you must host a Kubernetes cluster or 16 | use a Kubernetes cluster hosted by a cloud provider. An easy way to quickly 17 | create a cluster is by using [caliban](https://caliban.readthedocs.io/). 18 | 19 | The `xmanager` CLI can create clusters by calling caliban. To create a GKE 20 | auto-scaling cluster to be used with the Kubernetes executor, run: 21 | 22 | ``` 23 | xmanager cluster create 24 | ``` 25 | 26 | This command is equivalent to running `caliban cluster create`. 27 | 28 | To delete this cluster: 29 | 30 | ``` 31 | xmanager cluster delete 32 | ``` 33 | 34 | This command is equivalent to running `caliban cluster delete`. 35 | 36 | 37 | -------------------------------------------------------------------------------- /xmanager/cli/cli.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Xmanager command-line interface.""" 15 | 16 | import errno 17 | import importlib 18 | import os 19 | import sys 20 | 21 | from absl import app 22 | 23 | _DEFAULT_ZONE = 'us-west1-b' 24 | _DEFAULT_CLUSTER_NAME = 'xmanager-via-caliban' 25 | 26 | 27 | def main(argv): 28 | if len(argv) < 3: 29 | raise app.UsageError('There must be at least 2 command-line arguments') 30 | cmd = argv[1] 31 | if cmd == 'launch': 32 | launch_script = argv[2] 33 | if not os.path.exists(launch_script): 34 | raise OSError(errno.ENOENT, f'File not found: {launch_script}') 35 | sys.path.insert(0, os.path.abspath(os.path.dirname(launch_script))) 36 | launch_module, _ = os.path.splitext(os.path.basename(launch_script)) 37 | m = importlib.import_module(launch_module) 38 | sys.path.pop(0) 39 | argv = [ 40 | launch_script, 41 | '--xm_launch_script={}'.format(launch_script), 42 | ] + argv[3:] 43 | app.run(m.main, argv=argv) 44 | elif cmd == 'cluster': 45 | caliban_gke = importlib.import_module('caliban.platform.gke.cli') 46 | caliban_gke_types = importlib.import_module('caliban.platform.gke.types') 47 | subcmd = argv[2] 48 | args = { 49 | 'dry_run': False, 50 | 'cluster_name': _DEFAULT_CLUSTER_NAME, 51 | 'zone': _DEFAULT_ZONE, 52 | 'release_channel': caliban_gke_types.ReleaseChannel.REGULAR, 53 | 'single_zone': True, 54 | } 55 | if subcmd == 'create': 56 | caliban_gke._cluster_create(args) # pylint: disable=protected-access 57 | elif subcmd == 'delete': 58 | caliban_gke._cluster_delete(args) # pylint: disable=protected-access 59 | else: 60 | raise 
app.UsageError( 61 | f'Subcommand `{cmd} {subcmd}` is not a supported subcommand' 62 | ) 63 | else: 64 | raise app.UsageError(f'Command `{cmd}` is not a supported command') 65 | 66 | 67 | def entrypoint(): 68 | app.run(main) 69 | 70 | 71 | if __name__ == '__main__': 72 | app.run(main) 73 | -------------------------------------------------------------------------------- /xmanager/cloud/README.md: -------------------------------------------------------------------------------- 1 | # `cloud` 2 | 3 | This directory contains tools and clients for various Google Cloud products. 4 | -------------------------------------------------------------------------------- /xmanager/cloud/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /xmanager/cloud/build_image_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from absl.testing import absltest 15 | from xmanager import xm 16 | from xmanager.cloud import build_image 17 | 18 | 19 | class BuildImageTest(absltest.TestCase): 20 | 21 | def create_container(self, entrypoint) -> xm.PythonContainer: 22 | return xm.PythonContainer(entrypoint=entrypoint) 23 | 24 | def test_get_entrypoint_commands_module_adds_suffix(self): 25 | project = self.create_container(xm.ModuleName('some.python.module')) 26 | entrypoint_commands = build_image._get_entrypoint_commands(project) 27 | self.assertEndsWith(entrypoint_commands, ' "$@"') 28 | 29 | def test_get_entrypoint_commands_adds_suffix(self): 30 | commands = ['echo "aaa"'] 31 | project = self.create_container(xm.CommandList(commands)) 32 | entrypoint_commands = build_image._get_entrypoint_commands(project) 33 | self.assertEndsWith(entrypoint_commands, ' "$@"') 34 | 35 | def test_get_entrypoint_commands_no_dup_plain_suffix(self): 36 | commands = ['echo "aaa" $@'] 37 | project = self.create_container(xm.CommandList(commands)) 38 | entrypoint_commands = build_image._get_entrypoint_commands(project) 39 | self.assertEndsWith(entrypoint_commands, ' $@') 40 | 41 | def test_get_entrypoint_commands_no_dup_quoted_suffix(self): 42 | commands = ['echo "aaa" "$@"'] 43 | project = self.create_container(xm.CommandList(commands)) 44 | entrypoint_commands = build_image._get_entrypoint_commands(project) 45 | self.assertEndsWith(entrypoint_commands, ' "$@"') 46 | self.assertNotEndsWith(entrypoint_commands, ' "$@" "$@"') 47 | 48 | def 
test_get_entrypoint_commands_dup_single_quoted_suffix(self): 49 | commands = ['echo "aaa" \'$@\''] 50 | project = self.create_container(xm.CommandList(commands)) 51 | entrypoint_commands = build_image._get_entrypoint_commands(project) 52 | self.assertEndsWith(entrypoint_commands, ' \'$@\' "$@"') 53 | 54 | 55 | if __name__ == '__main__': 56 | absltest.main() 57 | -------------------------------------------------------------------------------- /xmanager/cloud/cloud_build_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for cloud_build.""" 15 | from unittest import mock 16 | from absl.testing import absltest 17 | 18 | from xmanager.cloud import cloud_build 19 | 20 | 21 | class CloudBuildTest(absltest.TestCase): 22 | 23 | def test_build_request_body(self): 24 | client = cloud_build.Client( 25 | 'my-project', 26 | 'my-bucket', 27 | mock.Mock(), 28 | use_kaniko=False, 29 | use_cloud_build_cache=False, 30 | ) 31 | image = client._build_request_body('path/to/project', 'my-image', 'live') 32 | self.assertEqual( 33 | image, 34 | { 35 | 'images': ['my-image'], 36 | 'options': {'machineType': 'E2_HIGHCPU_32'}, 37 | 'source': { 38 | 'storageSource': { 39 | 'bucket': 'my-bucket', 40 | 'object': 'path/to/project', 41 | }, 42 | }, 43 | 'steps': [{ 44 | 'args': [ 45 | 'build', 46 | '-t', 47 | 'my-image:live', 48 | '-t', 49 | 'my-image:latest', 50 | '.', 51 | ], 52 | 'name': 'gcr.io/cloud-builders/docker', 53 | }], 54 | 'timeout': '1200s', 55 | }, 56 | ) 57 | 58 | def test_build_request_body_use_kaniko(self): 59 | client = cloud_build.Client( 60 | 'my-project', 61 | 'my-bucket', 62 | mock.Mock(), 63 | use_kaniko=True, 64 | use_cloud_build_cache=False, 65 | ) 66 | image = client._build_request_body('path/to/project', 'my-image', 'live') 67 | self.assertEqual( 68 | image, 69 | { 70 | 'source': { 71 | 'storageSource': { 72 | 'bucket': 'my-bucket', 73 | 'object': 'path/to/project', 74 | }, 75 | }, 76 | 'steps': [{ 77 | 'args': [ 78 | '--destination=my-image:live', 79 | '--destination=my-image:latest', 80 | '--cache=true', 81 | '--cache-ttl=336h', 82 | ], 83 | 'name': 'gcr.io/kaniko-project/executor:latest', 84 | }], 85 | 'timeout': '1200s', 86 | }, 87 | ) 88 | 89 | def test_build_request_body_use_build_cache(self): 90 | client = cloud_build.Client( 91 | 'my-project', 92 | 'my-bucket', 93 | mock.Mock(), 94 | use_kaniko=False, 95 | use_cloud_build_cache=True, 96 | ) 97 | image = client._build_request_body('path/to/project', 'my-image', 'live') 98 | self.assertEqual( 99 | image, 100 | { 
101 | 'images': ['my-image'], 102 | 'options': {'machineType': 'E2_HIGHCPU_32'}, 103 | 'source': { 104 | 'storageSource': { 105 | 'bucket': 'my-bucket', 106 | 'object': 'path/to/project', 107 | }, 108 | }, 109 | 'steps': [{ 110 | 'args': [ 111 | 'build', 112 | '-t', 113 | 'my-image:live', 114 | '-t', 115 | 'my-image:latest', 116 | '--cache-from', 117 | 'my-image:latest', 118 | '.', 119 | ], 120 | 'name': 'gcr.io/cloud-builders/docker', 121 | }], 122 | 'timeout': '1200s', 123 | }, 124 | ) 125 | 126 | 127 | if __name__ == '__main__': 128 | absltest.main() 129 | -------------------------------------------------------------------------------- /xmanager/cloud/data/wrapped_entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | python3 -c "import vertex_utils; vertex_utils.create_workerpool_address_env_vars_script('./map_xm_env_vars')" 18 | source ./map_xm_env_vars 19 | ARGS=($(python3 -c "import vertex_utils; import sys; vertex_utils.print_workerpool_address_args(sys.argv)" $@ | tr -d '[],')) 20 | ./entrypoint.sh ${ARGS[@]} 21 | -------------------------------------------------------------------------------- /xmanager/cloud/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for xmanager.cloud.utils.""" 15 | 16 | import os 17 | import tempfile 18 | import unittest 19 | 20 | from xmanager.cloud import utils 21 | 22 | _CLUSTER_SPEC = """{ 23 | "cluster": { 24 | "workerpool0": ["cmle-training-workerpool0-ab-0:2222"], 25 | "workerpool1": ["cmle-training-workerpool1-ab-0:2222", "cmle-training-workerpool1-ab-1:2222"], 26 | "workerpool2": ["cmle-training-workerpool2-ab-0:2222", "cmle-training-workerpool2-ab-1:2222"] 27 | }, 28 | "environment": "cloud", 29 | "task": { 30 | "type": "workerpool1", 31 | "index": 1, 32 | "trial": "" 33 | } 34 | }""".replace( 35 | '\n', ' ' 36 | ) 37 | 38 | 39 | class UtilsTest(unittest.TestCase): 40 | 41 | def tearDown(self): 42 | super(UtilsTest, self).tearDown() 43 | os.environ['CLUSTER_SPEC'] = '' 44 | 45 | def test_get_master_address_port(self): 46 | os.environ['CLUSTER_SPEC'] = _CLUSTER_SPEC 47 | address, port = utils.get_master_address_port() 48 | self.assertEqual(address, 'cmle-training-workerpool0-ab-0') 49 | self.assertEqual(port, '2222') 50 | 51 | def test_get_master_address_port_default(self): 52 | address, port = utils.get_master_address_port() 53 | self.assertEqual(address, '127.0.0.1') 54 | self.assertEqual(port, '29500') 55 | 56 | def test_get_world_size_rank(self): 57 | os.environ['CLUSTER_SPEC'] = _CLUSTER_SPEC 58 | world_size, rank = utils.get_world_size_rank() 59 | self.assertEqual(world_size, 5) 60 | self.assertEqual(rank, 2) 61 | 62 | def test_get_world_size_rank_default(self): 63 | world_size, rank = utils.get_world_size_rank() 64 | self.assertEqual(world_size, 1) 65 | self.assertEqual(rank, 0) 66 | 67 | def test_wrap_and_unwrap_addresses(self): 68 | arg = '--master=' + utils.get_workerpool_address('workerpool0') 69 | self.assertEqual(arg, '--master=%objectname(workerpool0)%') 70 | os.environ['CLUSTER_SPEC'] = _CLUSTER_SPEC 71 | self.assertEqual( 72 | utils.map_workerpool_address_args([arg]), 73 | ['--master=cmle-training-workerpool0-ab-0:2222'], 74 | ) 75 | 76 | def 
test_create_workerpool_address_env_vars_script(self): 77 | os.environ['MY_WORKER'] = utils.get_workerpool_address('workerpool0') 78 | os.environ['CLUSTER_SPEC'] = _CLUSTER_SPEC 79 | t = tempfile.NamedTemporaryFile() 80 | utils.create_workerpool_address_env_vars_script(t.name) 81 | expected = """ 82 | #!/bin/bash 83 | 84 | export MY_WORKER=cmle-training-workerpool0-ab-0:2222 85 | """ 86 | with open(t.name) as f: 87 | self.assertEqual(f.read(), expected.strip()) 88 | 89 | 90 | if __name__ == '__main__': 91 | unittest.main() 92 | -------------------------------------------------------------------------------- /xmanager/contrib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /xmanager/contrib/addressing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Module for getting the address of XManager jobs. 15 | 16 | Addresses in XManager can be statically evaluated because the experiment ID is 17 | known. Addressing should not involve tokens or late-bindings. 18 | """ 19 | 20 | 21 | def k8s_pod_domain( 22 | job_name: str, 23 | experiment_id: int, 24 | work_unit_id: int, 25 | service: str = 'experiments', 26 | namespace: str = 'default', 27 | ) -> str: 28 | """Returns the Kubernetes pod address of a job. 29 | 30 | Args: 31 | job_name: Job name. 32 | experiment_id: Experiment ID. 33 | work_unit_id: Work unit ID 34 | service: Name of the service for the job. Defaults to 'experiments' 35 | namespace: Namespace of the job. Defaults to 'default' 36 | """ 37 | return ( 38 | f'{experiment_id}-{work_unit_id}-{job_name}' 39 | f'.{service}.{namespace}.svc.cluster.local:2222' 40 | ) 41 | -------------------------------------------------------------------------------- /xmanager/contrib/addressing_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for addressing.""" 15 | 16 | from absl.testing import absltest 17 | from xmanager.contrib import addressing 18 | 19 | 20 | class AddressingTest(absltest.TestCase): 21 | 22 | def test_k8s_pod_domain(self): 23 | address = addressing.k8s_pod_domain( 24 | job_name='cifar10', 25 | experiment_id=123, 26 | work_unit_id=4, 27 | service='best_service', 28 | namespace='best_namespace', 29 | ) 30 | 31 | self.assertEqual( 32 | address, 33 | '123-4-cifar10.best_service.best_namespace.svc.cluster.local:2222', 34 | ) 35 | 36 | 37 | if __name__ == '__main__': 38 | absltest.main() 39 | -------------------------------------------------------------------------------- /xmanager/contrib/copybara.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Module for transforming source code using copybara. 
15 | 16 | XManager primarily uses Copybara to run folder-to-folder workflows in the form: 17 | 18 | core.workflow( 19 | name = "folder_to_folder", 20 | origin = folder.origin(), 21 | destination = folder.destination(), 22 | ... 23 | ) 24 | 25 | Copybara allows for iterative local development on multiple platforms without 26 | needing to add/remove platform-specific code modifications. This allows you to 27 | preprocess source code so that it can be run on different platforms with 28 | different executors, e.g.: 29 | 30 | local_version = run_workflow(config, 'local', path) 31 | vertex_version = run_workflow(config, 'vertex', path) 32 | 33 | local_spec = xm.PythonContainer(path=local_version, **kwargs) 34 | vertex_spec = xm.PythonContainer(path=vertex_version, **kwargs) 35 | 36 | [local_executable, vertex_executable] = experiment.package([ 37 | xm.Packageable( 38 | executable_spec=local_spec, 39 | executor_spec=xm_local.Local.Spec()), 40 | xm.Packageable( 41 | executable_spec=vertex_spec, 42 | executor_spec=xm_local.Vertex.Spec())]) 43 | 44 | Copybara has no release process, so you must compile copybara yourself: 45 | https://github.com/google/copybara 46 | """ 47 | import os 48 | import subprocess 49 | import tempfile 50 | from typing import Optional 51 | 52 | # Set this to the path of your compiled copybara binary, e.g. 53 | # COPYBARA_BIN = 'bazel-bin/java/com/google/copybara/copybara_deploy.jar' 54 | COPYBARA_BIN = 'copybara' 55 | 56 | 57 | def run_workflow( 58 | config: str, 59 | workflow: str, 60 | origin_folder: str, 61 | destination_folder: Optional[str] = None, 62 | config_root: Optional[str] = None, 63 | ) -> str: 64 | """Run a workflow in a copybara config to transform origin to destination. 65 | 66 | Args: 67 | config: Path to the Copybara config. 68 | workflow: Name of a workflow in the copybara config. 69 | origin_folder: The origin folder to use as input. This will be passed to 70 | Copybara via the source_ref argument. 71 | destination_folder: The folder to write the output to. If None or empty, a new temporary directory is created.
72 | config_root: Configuration root path to be used for resolving absolute 73 | config labels like '//foo/bar'. 74 | 75 | Returns: 76 | The output destination folder. 77 | """ 78 | origin_folder = os.path.abspath(origin_folder) 79 | if not destination_folder: 80 | destination_folder = tempfile.mkdtemp() 81 | command = [ 82 | COPYBARA_BIN, 83 | config, 84 | workflow, 85 | '--ignore-noop', 86 | origin_folder, 87 | '--folder-dir=' + destination_folder, 88 | ] 89 | if config_root: 90 | command += ['--config-root=' + config_root] 91 | print('Copybara command: ', command) 92 | subprocess.run(command, check=True) 93 | return destination_folder 94 | -------------------------------------------------------------------------------- /xmanager/contrib/framework_defaults_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from absl.testing import absltest 16 | from absl.testing import parameterized 17 | from xmanager import xm 18 | from xmanager.contrib import framework_defaults 19 | 20 | MLFramework = framework_defaults.MLFramework 21 | 22 | 23 | class FrameworkDefaultsTest(parameterized.TestCase): 24 | 25 | def test_known_frameworks(self): 26 | self.assertEqual( 27 | framework_defaults._get_framework('torch'), MLFramework.PYTORCH 28 | ) 29 | self.assertEqual( 30 | framework_defaults._get_framework('pytorch'), MLFramework.PYTORCH 31 | ) 32 | self.assertEqual(framework_defaults._get_framework('tf'), MLFramework.TF2) 33 | self.assertEqual(framework_defaults._get_framework('tf1'), MLFramework.TF1) 34 | self.assertEqual(framework_defaults._get_framework('tf2'), MLFramework.TF2) 35 | self.assertEqual( 36 | framework_defaults._get_framework('tensorflow 2.x'), MLFramework.TF2 37 | ) 38 | self.assertEqual(framework_defaults._get_framework('jax'), MLFramework.JAX) 39 | self.assertEqual(framework_defaults._get_framework('flax'), MLFramework.JAX) 40 | 41 | def test_unknown_frameworks(self): 42 | self.assertEqual( 43 | framework_defaults._get_framework('huggingface'), MLFramework.UNKNOWN 44 | ) 45 | self.assertEqual( 46 | framework_defaults._get_framework('objax'), MLFramework.UNKNOWN 47 | ) 48 | self.assertEqual( 49 | framework_defaults._get_framework('not a framework name'), 50 | MLFramework.UNKNOWN, 51 | ) 52 | 53 | @parameterized.named_parameters( 54 | ('cpu', None), 55 | ('gpu', xm.ResourceType.V100), 56 | ('tpu', xm.ResourceType.TPU_V3), 57 | ) 58 | def test_jax_base_image(self, accelerator): 59 | base_image = framework_defaults.base_image(MLFramework.JAX, accelerator) 60 | self.assertStartsWith(base_image, 'gcr.io/deeplearning-platform-release/') 61 | # Jax uses CUDA images. 
62 | self.assertIn('cu', base_image) 63 | 64 | @parameterized.named_parameters( 65 | ('cpu', None), 66 | ('gpu', xm.ResourceType.V100), 67 | ('tpu', xm.ResourceType.TPU_V3), 68 | ) 69 | def test_tf2_base_image(self, accelerator): 70 | base_image = framework_defaults.base_image(MLFramework.TF2, accelerator) 71 | self.assertStartsWith(base_image, 'gcr.io/deeplearning-platform-release/') 72 | self.assertIn('tf2-', base_image) 73 | 74 | @parameterized.named_parameters( 75 | ('cpu', None), 76 | ('gpu', xm.ResourceType.V100), 77 | ('tpu', xm.ResourceType.TPU_V3), 78 | ) 79 | def test_torch_base_image(self, accelerator): 80 | base_image = framework_defaults.base_image(MLFramework.PYTORCH, accelerator) 81 | self.assertStartsWith(base_image, 'gcr.io/') 82 | self.assertIn('pytorch', base_image) 83 | if accelerator in xm.TpuType: 84 | self.assertIn('tpu', base_image) 85 | 86 | @parameterized.named_parameters( 87 | ('cpu', None), 88 | ('gpu', xm.ResourceType.V100), 89 | ('tpu', xm.ResourceType.TPU_V3), 90 | ) 91 | def test_unsupported_tf1_base_image(self, accelerator): 92 | base_image = framework_defaults.base_image(MLFramework.TF1, accelerator) 93 | self.assertStartsWith(base_image, 'gcr.io/deeplearning-platform-release/') 94 | self.assertIn('tf', base_image) 95 | 96 | @parameterized.named_parameters( 97 | ('cpu', None), 98 | ('gpu', xm.ResourceType.V100), 99 | ('tpu', xm.ResourceType.TPU_V3), 100 | ) 101 | def test_unknown_base_image(self, accelerator): 102 | base_image = framework_defaults.base_image(MLFramework.UNKNOWN, accelerator) 103 | self.assertStartsWith(base_image, 'gcr.io/deeplearning-platform-release/') 104 | self.assertIn('base', base_image) 105 | 106 | 107 | if __name__ == '__main__': 108 | absltest.main() 109 | -------------------------------------------------------------------------------- /xmanager/contrib/process_entry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 
All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Entry point of a user-defined Python function.""" 16 | 17 | 18 | import contextlib 19 | import functools 20 | import os 21 | import sys 22 | 23 | from absl import app 24 | from absl import flags 25 | from absl import logging 26 | import cloudpickle 27 | 28 | 29 | FLAGS = flags.FLAGS 30 | flags.DEFINE_string( 31 | 'data_file', '', 'Pickle file location with entry points for all nodes' 32 | ) 33 | flags.DEFINE_string( 34 | 'init_file', 35 | '', 36 | 'Pickle file location containing initialization module ' 37 | 'executed for each node prior to an entry point', 38 | ) 39 | flags.DEFINE_string('flags_to_populate', '{}', 'obsolete') 40 | 41 | 42 | def _parse_process_entry_flags(all_argv: list[str]) -> list[str]: 43 | """Parse and consume all flags for the entry script; return the rest.""" 44 | # unconsumed_argv will still include all_argv[0], which is expected to be 45 | # the program name and is ignored by flag parsing. 46 | unconsumed_argv = FLAGS(all_argv, known_only=True) 47 | 48 | # JAX doesn't use absl flags and so we need to forward absl flags to JAX 49 | # explicitly. Here's a heuristic to detect JAX flags and forward them. 
50 | if any(arg.startswith('--jax_') for arg in sys.argv): 51 | try: 52 | # pytype:disable=import-error 53 | # pylint:disable=g-import-not-at-top 54 | import jax 55 | # pytype:enable=import-error 56 | # pylint:enable=g-import-not-at-top 57 | jax.config.parse_flags_with_absl() 58 | except ImportError: 59 | pass 60 | 61 | return unconsumed_argv 62 | 63 | 64 | def main(argv: list[str], process_argv: list[str]): 65 | # See `parse_flags_and_run()` for why arguments are passed in `process_argv` 66 | # instead. 67 | assert len(argv) == 1 68 | del argv 69 | 70 | # Allow for importing modules from the current directory. 71 | sys.path.append(os.getcwd()) 72 | data_file = FLAGS.data_file 73 | init_file = FLAGS.init_file 74 | 75 | if os.environ.get('TF_CONFIG', None): 76 | # For GCP runtime log to STDOUT so that logs are not reported as errors. 77 | logging.get_absl_handler().python_handler.stream = sys.stdout 78 | 79 | if init_file: 80 | init_function = cloudpickle.load(open(init_file, 'rb')) 81 | init_function() 82 | functions = cloudpickle.load(open(data_file, 'rb')) 83 | 84 | # Now that the code that we intend to run has been unpickled, that should 85 | # have caused the registration of any remaining flags that the program needs. 86 | [unused_program_name, *unconsumed_argv] = FLAGS(process_argv, known_only=True) 87 | if unconsumed_argv: 88 | logging.warning( 89 | 'The following command-line arguments were passed to the ' 90 | 'program but are not used by anything that it imports: %s', 91 | unconsumed_argv, 92 | ) 93 | 94 | with contextlib.suppress(): # no-op context manager 95 | # Currently only one function is supported. 96 | functions[0]() 97 | 98 | 99 | def parse_flags_and_run(): 100 | # Parse flags for this module and the things it has already imported. 101 | # Pass whatever flags are left over to main() through a side channel, so that 102 | # app.run() doesn't try to parse them before we have set the scene. 
103 | [program_name, *process_argv] = _parse_process_entry_flags(sys.argv) 104 | app.run( 105 | functools.partial(main, process_argv=[program_name, *process_argv]), 106 | argv=[program_name], 107 | ) 108 | 109 | 110 | if __name__ == '__main__': 111 | parse_flags_and_run() 112 | -------------------------------------------------------------------------------- /xmanager/contrib/tensorboard.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Helper methods for running TensorBoard from the client.""" 15 | 16 | from typing import Any, Mapping, Optional 17 | 18 | from xmanager import xm 19 | 20 | 21 | class TensorboardProvider: 22 | """Generates the packageable and job arguments for TensorBoard jobs.""" 23 | 24 | DEFAULT_TENSORBOARD_PORT = 6006 25 | 26 | @staticmethod 27 | def get_tensorboard_packageable(timeout_secs: int) -> xm.PythonContainer: 28 | """Creates a container spec that runs a TensorBoard server. 29 | 30 | Args: 31 | timeout_secs: Seconds to run the server for. Note that a value of 0 32 | disables the associated timeout. 33 | 34 | Raises: 35 | RuntimeError: `timeout_secs` is negative. 36 | 37 | Returns: 38 | Spec of a container running a TensorBoard server for the specified 39 | period of time.
40 | """ 41 | if timeout_secs < 0: 42 | raise RuntimeError('`timeout_secs` must be a nonnegative number') 43 | 44 | return xm.PythonContainer( 45 | base_image='tensorflow/tensorflow', 46 | entrypoint=xm.CommandList([f'timeout {timeout_secs}s tensorboard']), 47 | ) 48 | 49 | @staticmethod 50 | def get_tensorboard_job_args( 51 | log_dir: str, 52 | port: Optional[int] = None, 53 | additional_args: Optional[Mapping[str, Any]] = None, 54 | ) -> Mapping[str, Any]: 55 | """Returns the arguments to start a TensorBoard job; `additional_args` override the defaults.""" 56 | args = { 57 | 'logdir': log_dir, 58 | 'port': port or TensorboardProvider.DEFAULT_TENSORBOARD_PORT, 59 | # Allows accessing visualisations from a Docker container running locally. 60 | 'host': '0.0.0.0', 61 | # TensorBoard enables this by default. 62 | # Since it doesn't seem to work well with GKE Workload Identity, 63 | # we set it to false for now. Can be overridden through the 64 | # `additional_args` param. 65 | # 66 | # https://github.com/tensorflow/tensorboard/issues/4784#issuecomment-868945650 67 | 'load_fast': 'false', 68 | } 69 | if additional_args: 70 | args.update(additional_args) 71 | 72 | return args 73 | 74 | 75 | def add_tensorboard( 76 | experiment: xm.Experiment, 77 | logdir: str, 78 | executor: xm.Executor, 79 | timeout_secs: int = 60 * 60 * 24, 80 | args: Optional[Mapping[str, Any]] = None, 81 | ) -> None: 82 | """Packages a TensorBoard server and adds it to `experiment` as an auxiliary job.""" 83 | provider = TensorboardProvider 84 | [executable] = experiment.package( 85 | [ 86 | xm.Packageable( 87 | provider.get_tensorboard_packageable(timeout_secs=timeout_secs), 88 | executor.Spec(), 89 | ) 90 | ] 91 | ) 92 | 93 | job = xm.Job( 94 | executable, 95 | executor, 96 | args=provider.get_tensorboard_job_args(logdir, additional_args=args), 97 | name='tensorboard', 98 | ) 99 | 100 | # TODO: Add support for `termination_delay_secs`.
101 | experiment.add(xm.AuxiliaryUnitJob(job, termination_delay_secs=0)) 102 | -------------------------------------------------------------------------------- /xmanager/contrib/tpu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Helper module for using TPUs.""" 15 | from typing import List 16 | 17 | 18 | # pylint: disable=line-too-long 19 | def tpuvm_docker_instructions() -> List[str]: 20 | return [ 21 | ( 22 | 'RUN wget' 23 | ' https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/20210525/libtpu.so' 24 | ' -O /lib/libtpu.so' 25 | ), 26 | 'RUN chmod 700 /lib/libtpu.so', 27 | ( 28 | 'RUN wget ' 29 | 'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/20210525/tf_nightly-2.6.0-cp38-cp38-linux_x86_64.whl' 30 | ), 31 | 'RUN pip3 install tf_nightly-2.6.0-cp38-cp38-linux_x86_64.whl', 32 | 'RUN rm tf_nightly-2.6.0-cp38-cp38-linux_x86_64.whl', 33 | ] 34 | -------------------------------------------------------------------------------- /xmanager/generated/README.md: -------------------------------------------------------------------------------- 1 | # `generated` 2 | 3 | This directory contains generated source files, e.g. compiled protobufs. 
4 | -------------------------------------------------------------------------------- /xmanager/generated/command_line_pb2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The Bazel Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyformat: disable 16 | # Generated by the protocol buffer compiler. DO NOT EDIT! 17 | # pylint: skip-file 18 | # source: src/main/protobuf/command_line.proto 19 | """Generated protocol buffer code.""" 20 | from google.protobuf.internal import builder as _builder 21 | from google.protobuf import descriptor as _descriptor 22 | from google.protobuf import descriptor_pool as _descriptor_pool 23 | from google.protobuf import symbol_database as _symbol_database 24 | # @@protoc_insertion_point(imports) 25 | 26 | _sym_db = _symbol_database.Default() 27 | 28 | 29 | from . 
import option_filters_pb2 as src_dot_main_dot_protobuf_dot_option__filters__pb2 30 | 31 | 32 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/main/protobuf/command_line.proto\x12\x0c\x63ommand_line\x1a&src/main/protobuf/option_filters.proto\"]\n\x0b\x43ommandLine\x12\x1a\n\x12\x63ommand_line_label\x18\x01 \x01(\t\x12\x32\n\x08sections\x18\x02 \x03(\x0b\x32 .command_line.CommandLineSection\"\x9b\x01\n\x12\x43ommandLineSection\x12\x15\n\rsection_label\x18\x01 \x01(\t\x12-\n\nchunk_list\x18\x02 \x01(\x0b\x32\x17.command_line.ChunkListH\x00\x12/\n\x0boption_list\x18\x03 \x01(\x0b\x32\x18.command_line.OptionListH\x00\x42\x0e\n\x0csection_type\"\x1a\n\tChunkList\x12\r\n\x05\x63hunk\x18\x01 \x03(\t\"2\n\nOptionList\x12$\n\x06option\x18\x01 \x03(\x0b\x32\x14.command_line.Option\"\xac\x01\n\x06Option\x12\x15\n\rcombined_form\x18\x01 \x01(\t\x12\x13\n\x0boption_name\x18\x02 \x01(\t\x12\x14\n\x0coption_value\x18\x03 \x01(\t\x12-\n\x0b\x65\x66\x66\x65\x63t_tags\x18\x04 \x03(\x0e\x32\x18.options.OptionEffectTag\x12\x31\n\rmetadata_tags\x18\x05 \x03(\x0e\x32\x1a.options.OptionMetadataTagB-\n+com.google.devtools.build.lib.runtime.protob\x06proto3') 33 | 34 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) 35 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'src.main.protobuf.command_line_pb2', globals()) 36 | if _descriptor._USE_C_DESCRIPTORS == False: 37 | 38 | DESCRIPTOR._options = None 39 | DESCRIPTOR._serialized_options = b'\n+com.google.devtools.build.lib.runtime.proto' 40 | _COMMANDLINE._serialized_start=94 41 | _COMMANDLINE._serialized_end=187 42 | _COMMANDLINESECTION._serialized_start=190 43 | _COMMANDLINESECTION._serialized_end=345 44 | _CHUNKLIST._serialized_start=347 45 | _CHUNKLIST._serialized_end=373 46 | _OPTIONLIST._serialized_start=375 47 | _OPTIONLIST._serialized_end=425 48 | _OPTION._serialized_start=428 49 | _OPTION._serialized_end=600 50 | # @@protoc_insertion_point(module_scope) 51 | 
-------------------------------------------------------------------------------- /xmanager/generated/data_pb2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyformat: disable 16 | # Generated by the protocol buffer compiler. DO NOT EDIT! 17 | # pylint: skip-file 18 | # source: xm_local/storage/data.proto 19 | """Generated protocol buffer code.""" 20 | from google.protobuf.internal import builder as _builder 21 | from google.protobuf import descriptor as _descriptor 22 | from google.protobuf import descriptor_pool as _descriptor_pool 23 | from google.protobuf import symbol_database as _symbol_database 24 | # @@protoc_insertion_point(imports) 25 | 26 | _sym_db = _symbol_database.Default() 27 | 28 | 29 | 30 | 31 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1bxm_local/storage/data.proto\x12\x08xmanager\"\x8a\x01\n\x03Job\x12#\n\x05local\x18\x01 \x01(\x0b\x32\x12.xmanager.LocalJobH\x00\x12\'\n\x04\x63\x61ip\x18\x02 \x01(\x0b\x32\x17.xmanager.AIPlatformJobH\x00\x12-\n\nkubernetes\x18\x03 \x01(\x0b\x32\x17.xmanager.KubernetesJobH\x00\x42\x06\n\x04kind\"7\n\x08LocalJob\x12\x0b\n\x03pid\x18\x01 \x01(\t\x12\x0b\n\x03\x63md\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\"&\n\rAIPlatformJob\x12\x15\n\rresource_name\x18\x01 
\x01(\t\"4\n\rKubernetesJob\x12\x11\n\tnamespace\x18\x01 \x01(\t\x12\x10\n\x08job_name\x18\x02 \x01(\tb\x06proto3') 32 | 33 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) 34 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'xm_local.storage.data_pb2', globals()) 35 | if _descriptor._USE_C_DESCRIPTORS == False: 36 | 37 | DESCRIPTOR._options = None 38 | _JOB._serialized_start=42 39 | _JOB._serialized_end=180 40 | _LOCALJOB._serialized_start=182 41 | _LOCALJOB._serialized_end=237 42 | _AIPLATFORMJOB._serialized_start=239 43 | _AIPLATFORMJOB._serialized_end=277 44 | _KUBERNETESJOB._serialized_start=279 45 | _KUBERNETESJOB._serialized_end=331 46 | # @@protoc_insertion_point(module_scope) 47 | -------------------------------------------------------------------------------- /xmanager/generated/invocation_policy_pb2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The Bazel Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyformat: disable 16 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
17 | # pylint: skip-file 18 | # source: src/main/protobuf/invocation_policy.proto 19 | """Generated protocol buffer code.""" 20 | from google.protobuf.internal import builder as _builder 21 | from google.protobuf import descriptor as _descriptor 22 | from google.protobuf import descriptor_pool as _descriptor_pool 23 | from google.protobuf import symbol_database as _symbol_database 24 | # @@protoc_insertion_point(imports) 25 | 26 | _sym_db = _symbol_database.Default() 27 | 28 | 29 | 30 | 31 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)src/main/protobuf/invocation_policy.proto\x12\x17\x62laze.invocation_policy\"N\n\x10InvocationPolicy\x12:\n\rflag_policies\x18\x01 \x03(\x0b\x32#.blaze.invocation_policy.FlagPolicy\"\xb4\x02\n\nFlagPolicy\x12\x11\n\tflag_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ommands\x18\x02 \x03(\t\x12\x36\n\tset_value\x18\x03 \x01(\x0b\x32!.blaze.invocation_policy.SetValueH\x00\x12:\n\x0buse_default\x18\x04 \x01(\x0b\x32#.blaze.invocation_policy.UseDefaultH\x00\x12\x42\n\x0f\x64isallow_values\x18\x05 \x01(\x0b\x32\'.blaze.invocation_policy.DisallowValuesH\x00\x12<\n\x0c\x61llow_values\x18\x06 \x01(\x0b\x32$.blaze.invocation_policy.AllowValuesH\x00\x42\x0b\n\toperation\"\xc6\x01\n\x08SetValue\x12\x12\n\nflag_value\x18\x01 \x03(\t\x12<\n\x08\x62\x65havior\x18\x04 \x01(\x0e\x32*.blaze.invocation_policy.SetValue.Behavior\"\\\n\x08\x42\x65havior\x12\r\n\tUNDEFINED\x10\x00\x12\x13\n\x0f\x41LLOW_OVERRIDES\x10\x01\x12\n\n\x06\x41PPEND\x10\x02\x12 \n\x1c\x46INAL_VALUE_IGNORE_OVERRIDES\x10\x03J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04\"\x0c\n\nUseDefault\"\x97\x01\n\x0e\x44isallowValues\x12\x19\n\x11\x64isallowed_values\x18\x01 \x03(\t\x12\x13\n\tnew_value\x18\x03 \x01(\tH\x00\x12:\n\x0buse_default\x18\x04 \x01(\x0b\x32#.blaze.invocation_policy.UseDefaultH\x00\x42\x13\n\x11replacement_valueJ\x04\x08\x02\x10\x03\"\x91\x01\n\x0b\x41llowValues\x12\x16\n\x0e\x61llowed_values\x18\x01 \x03(\t\x12\x13\n\tnew_value\x18\x03 
\x01(\tH\x00\x12:\n\x0buse_default\x18\x04 \x01(\x0b\x32#.blaze.invocation_policy.UseDefaultH\x00\x42\x13\n\x11replacement_valueJ\x04\x08\x02\x10\x03\x42-\n+com.google.devtools.build.lib.runtime.proto') 32 | 33 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) 34 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'src.main.protobuf.invocation_policy_pb2', globals()) 35 | if _descriptor._USE_C_DESCRIPTORS == False: 36 | 37 | DESCRIPTOR._options = None 38 | DESCRIPTOR._serialized_options = b'\n+com.google.devtools.build.lib.runtime.proto' 39 | _INVOCATIONPOLICY._serialized_start=70 40 | _INVOCATIONPOLICY._serialized_end=148 41 | _FLAGPOLICY._serialized_start=151 42 | _FLAGPOLICY._serialized_end=459 43 | _SETVALUE._serialized_start=462 44 | _SETVALUE._serialized_end=660 45 | _SETVALUE_BEHAVIOR._serialized_start=556 46 | _SETVALUE_BEHAVIOR._serialized_end=648 47 | _USEDEFAULT._serialized_start=662 48 | _USEDEFAULT._serialized_end=674 49 | _DISALLOWVALUES._serialized_start=677 50 | _DISALLOWVALUES._serialized_end=828 51 | _ALLOWVALUES._serialized_start=831 52 | _ALLOWVALUES._serialized_end=976 53 | # @@protoc_insertion_point(module_scope) 54 | -------------------------------------------------------------------------------- /xmanager/generated/option_filters_pb2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The Bazel Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyformat: disable 16 | # Generated by the protocol buffer compiler. DO NOT EDIT! 17 | # pylint: skip-file 18 | # source: src/main/protobuf/option_filters.proto 19 | """Generated protocol buffer code.""" 20 | from google.protobuf.internal import builder as _builder 21 | from google.protobuf import descriptor as _descriptor 22 | from google.protobuf import descriptor_pool as _descriptor_pool 23 | from google.protobuf import symbol_database as _symbol_database 24 | # @@protoc_insertion_point(imports) 25 | 26 | _sym_db = _symbol_database.Default() 27 | 28 | 29 | 30 | 31 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&src/main/protobuf/option_filters.proto\x12\x07options*\xea\x02\n\x0fOptionEffectTag\x12\x0b\n\x07UNKNOWN\x10\x00\x12\t\n\x05NO_OP\x10\x01\x12\x1b\n\x17LOSES_INCREMENTAL_STATE\x10\x02\x12\x12\n\x0e\x43HANGES_INPUTS\x10\x03\x12\x13\n\x0f\x41\x46\x46\x45\x43TS_OUTPUTS\x10\x04\x12\x18\n\x14\x42UILD_FILE_SEMANTICS\x10\x05\x12 \n\x1c\x42\x41ZEL_INTERNAL_CONFIGURATION\x10\x06\x12\x18\n\x14LOADING_AND_ANALYSIS\x10\x07\x12\r\n\tEXECUTION\x10\x08\x12\'\n#HOST_MACHINE_RESOURCE_OPTIMIZATIONS\x10\t\x12\x15\n\x11\x45\x41GERNESS_TO_EXIT\x10\n\x12\x14\n\x10\x42\x41ZEL_MONITORING\x10\x0b\x12\x13\n\x0fTERMINAL_OUTPUT\x10\x0c\x12\x18\n\x14\x41\x43TION_COMMAND_LINES\x10\r\x12\x0f\n\x0bTEST_RUNNER\x10\x0e*\xb2\x01\n\x11OptionMetadataTag\x12\x10\n\x0c\x45XPERIMENTAL\x10\x00\x12\x17\n\x13INCOMPATIBLE_CHANGE\x10\x01\x12\x0e\n\nDEPRECATED\x10\x02\x12\n\n\x06HIDDEN\x10\x03\x12\x0c\n\x08INTERNAL\x10\x04\x12\x1b\n\x17\x45XPLICIT_IN_OUTPUT_PATH\x10\x06\"\x04\x08\x05\x10\x05*%TRIGGERED_BY_ALL_INCOMPATIBLE_CHANGESB*\n(com.google.devtools.common.options.protob\x06proto3') 32 | 33 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) 34 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'src.main.protobuf.option_filters_pb2', 
globals()) 35 | if _descriptor._USE_C_DESCRIPTORS == False: 36 | 37 | DESCRIPTOR._options = None 38 | DESCRIPTOR._serialized_options = b'\n(com.google.devtools.common.options.proto' 39 | _OPTIONEFFECTTAG._serialized_start=52 40 | _OPTIONEFFECTTAG._serialized_end=414 41 | _OPTIONMETADATATAG._serialized_start=417 42 | _OPTIONMETADATATAG._serialized_end=595 43 | # @@protoc_insertion_point(module_scope) 44 | -------------------------------------------------------------------------------- /xmanager/module_lazy_loader/lazy_loader_module_attrs_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Test for custom lazy loader definition for xmanager sub-package __init__.py files.""" 15 | 16 | import unittest 17 | from xmanager.module_lazy_loader import module_lazy_loader 18 | 19 | 20 | class LazyLoaderModuleAttrsTest(unittest.TestCase): 21 | test_lazy_loader = module_lazy_loader.XManagerLazyLoader( 22 | __name__, 23 | apis=[ 24 | module_lazy_loader.XManagerAPI( 25 | module="xmanager.module_lazy_loader.module_lazy_loader", 26 | symbol="XManagerAPI", 27 | alias="boo", 28 | ), 29 | module_lazy_loader.XManagerAPI( 30 | module="xmanager.module_lazy_loader.module_lazy_loader", 31 | symbol="XManagerAPI", 32 | ), 33 | module_lazy_loader.XManagerAPI( 34 | module="xmanager.module_lazy_loader.module_lazy_loader", 35 | alias="baz", 36 | ), 37 | module_lazy_loader.XManagerAPI( 38 | module="xmanager.module_lazy_loader.module_lazy_loader", 39 | ), 40 | ], 41 | ) 42 | 43 | def test_all(self): 44 | self.assertCountEqual( 45 | self.test_lazy_loader.get_module_all(), 46 | ["boo", "XManagerAPI", "baz", "module_lazy_loader"], 47 | ) 48 | 49 | def test_dir(self): 50 | self.assertCountEqual( 51 | self.test_lazy_loader.get_module_dir()(), 52 | ["boo", "XManagerAPI", "baz", "module_lazy_loader"], 53 | ) 54 | 55 | def test_getattr(self): 56 | local_getattr = self.test_lazy_loader.get_module_getattr() 57 | self.assertEqual(local_getattr("boo"), module_lazy_loader.XManagerAPI) 58 | self.assertEqual( 59 | local_getattr("XManagerAPI"), module_lazy_loader.XManagerAPI 60 | ) 61 | self.assertEqual(local_getattr("baz"), module_lazy_loader) 62 | self.assertEqual(local_getattr("module_lazy_loader"), module_lazy_loader) 63 | self.assertRaises(AttributeError, local_getattr, "this_attr_does_not_exist") 64 | 65 | 66 | if __name__ == "__main__": 67 | unittest.main() 68 | -------------------------------------------------------------------------------- /xmanager/vizier/vizier_cloud/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 
2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Vizier API for launching Vertex-Vizier explored Experiment for OSS.""" 15 | 16 | from xmanager.vizier.vizier_cloud.study_factory import NewStudy 17 | from xmanager.vizier.vizier_cloud.vizier_exploration import VizierExploration 18 | -------------------------------------------------------------------------------- /xmanager/vizier/vizier_cloud/study_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
"""Factory classes for generating study of cloud Vertex Vizier."""

import abc
from typing import Optional

from google.cloud import aiplatform_v1beta1 as aip

from xmanager.cloud import auth

_DEFAULT_LOCATION = 'us-central1'


class StudyFactory(abc.ABC):
  """Abstract generator of a Vertex Vizier study.

  Holds the study spec, the trial budget, a display name and a Vizier client
  bound to a regional endpoint. Subclasses implement `study()`.
  """

  vz_client: aip.VizierServiceClient
  study_config: aip.StudySpec
  num_trials_total: int
  display_name: str

  # TODO: Once vertex pyvizier is available, we should replace
  # aip.StudySpec with it.
  # display_name and num_trials_total are supposed to be set into the study
  # config, which is not supported by aip.StudySpec currently. But should be
  # settable when pyvizier.StudyConfig is available.
  def __init__(
      self,
      study_config: aip.StudySpec,
      num_trials_total: int,
      display_name: str,
      location: str,
  ) -> None:
    """Stores the study parameters and creates a regional Vizier client.

    Args:
      study_config: Spec describing the study.
      num_trials_total: Total number of trials to explore.
      display_name: Human-readable name of the study.
      location: Cloud region, used to build the
        `{location}-aiplatform.googleapis.com` API endpoint.
    """
    super().__init__()
    self.study_config = study_config
    self.num_trials_total = num_trials_total
    self.display_name = display_name
    endpoint = f'{location}-aiplatform.googleapis.com'
    self.vz_client = aip.VizierServiceClient(
        client_options={'api_endpoint': endpoint}
    )

  @abc.abstractmethod
  def study(self) -> str:
    """Returns the name of the study to explore."""
    raise NotImplementedError


class NewStudy(StudyFactory):
  """Study generator that creates a brand-new study from the given spec."""

  project: str
  location: str

  # `num_trials_total` is a required field. Default it to 0 to unbreak the
  # soon-to-deprecate VizierExploration users.
  # `display_name` is optional for user to customize, if not set, XM will
  # set it with experiment information
  def __init__(
      self,
      study_config: aip.StudySpec,
      num_trials_total: int = 0,
      display_name: Optional[str] = None,
      project: Optional[str] = None,
      location: Optional[str] = None,
  ) -> None:
    """Resolves project/location defaults and initializes the factory.

    Args:
      study_config: Spec describing the study.
      num_trials_total: Total number of trials to explore.
      display_name: Optional study name; an empty string is stored when unset.
      project: GCP project; falls back to `auth.get_project_name()`.
      location: Cloud region; falls back to `_DEFAULT_LOCATION`.
    """
    self.project = project if project else auth.get_project_name()
    self.location = location if location else _DEFAULT_LOCATION

    name = display_name if display_name else ''
    super().__init__(study_config, num_trials_total, name, self.location)

  def study(self) -> str:
    """Creates a new study in Vertex Vizier and returns its name."""
    parent = f'projects/{self.project}/locations/{self.location}'
    new_study = aip.Study(
        display_name=self.display_name, study_spec=self.study_config
    )
    created = self.vz_client.create_study(parent=parent, study=new_study)
    return created.name
# ------------------------------------------------------------------------------
# /xmanager/vizier/vizier_cloud/vizier_exploration.py:
# ------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interface for launching Vizier Explorations using Vertex Vizier."""

from typing import Any, Dict

from xmanager import xm
from xmanager.vizier.vizier_cloud import study_factory as sf
from xmanager.vizier.vizier_cloud import vizier_controller

_DEFAULT_LOCATION = 'us-central1'


# TODO: Add vizier_controller as auxiliary Job generator.
class VizierExploration:
  """An API for launching experiment as a Vizier-based Exploration.

  Wraps a `vizier_controller.VizierController` whose work-unit generator adds
  `job` with the Vizier parameters passed as the job's `args`.
  """

  def __init__(
      self,
      experiment: xm.Experiment,
      job: xm.JobType,
      study_factory: sf.StudyFactory,
      num_trials_total: int,
      num_parallel_trial_runs: int,
  ) -> None:
    """Create a VizierExploration.

    Args:
      experiment: the experiment who does the exploration.
      job: a job to run.
      study_factory: the VizierStudyFactory used to create or load the study.
      num_trials_total: total number of trials the experiment want to explore.
      num_parallel_trial_runs: number of parallel runs evaluating the trials.
    """

    async def populate_work_unit(
        work_unit: xm.WorkUnit, vizier_params: Dict[str, Any]
    ):
      # Forward the Vizier-chosen parameters to the job as its arguments.
      work_unit.add(job, self._to_job_params(vizier_params))

    # Studies without a display name are labelled after the experiment ID.
    if not study_factory.display_name:
      study_factory.display_name = f'X{experiment.experiment_id}'

    self._controller = vizier_controller.VizierController(
        experiment,
        populate_work_unit,
        study_factory.vz_client,
        study_factory.study(),
        num_trials_total,
        num_parallel_trial_runs,
    )

  def _to_job_params(self, vizier_params: Dict[str, Any]) -> Dict[str, Any]:
    """Wraps flat Vizier parameters as the `args` of a single job."""
    # TODO: unflatten parameters for JobGroup case (currently this
    # works for xm.Job).
    # For example: transform
    # {'learner.args.learning_rate': 0.1}
    # to
    # {'learner': {'args': {'learning_rate': 0.1}}}
    return dict(args=vizier_params)

  def launch(self, **kwargs) -> None:
    """Starts the exploration by forwarding `kwargs` to the controller's run."""
    self._controller.run(**kwargs)
# ------------------------------------------------------------------------------
# /xmanager/vizier/vizier_cloud/vizier_worker.py:
# ------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run Job as a Vizier worker to manage WorkUnit Vizier interaction."""
import re
from typing import Dict, Optional

from absl import logging
from google.cloud import aiplatform_v1beta1 as aip

_TRIAL_NAME_REGEX = (
    r'projects\/[^\/]+\/locations\/[^\/]+\/studies\/[^\/]+\/trials\/[^\/]+'
)


class VizierWorker:
  """Worker that manages interaction between a job and Vizier."""

  def __init__(self, trial_name: str) -> None:
    """Validates `trial_name` and creates a Vizier client for its location.

    Args:
      trial_name: Trial resource name of the form
        projects/{project}/locations/{location}/studies/{study}/trials/{trial}.

    Raises:
      ValueError: If `trial_name` is not exactly of that form. ValueError is a
        subclass of Exception, so existing `except Exception` callers still
        catch it.
    """
    # re.match only anchors at the start of the string, which let names with
    # extra trailing path segments through; fullmatch enforces the exact form.
    if not re.fullmatch(_TRIAL_NAME_REGEX, trial_name):
      raise ValueError(
          'The trial_name must be in the form: '
          'projects/{project}/locations/{location}/'
          'studies/{study}/trials/{trial}'
      )

    self._trial_name = trial_name

    # Segment 3 of 'projects/{p}/locations/{l}/...' is the location.
    location = trial_name.split('/')[3]
    self._vz_client = aip.VizierServiceClient(
        client_options={
            'api_endpoint': f'{location}-aiplatform.googleapis.com',
        }
    )

  def add_trial_measurement(self, step: int, metrics: Dict[str, float]) -> None:
    """Add trial measurements to Vizier.

    Args:
      step: Step count recorded with the measurement.
      metrics: Mapping from metric ID to its value at `step`.
    """
    self._vz_client.add_trial_measurement(
        request=aip.AddTrialMeasurementRequest(
            trial_name=self._trial_name,
            measurement=aip.Measurement(
                step_count=step,
                metrics=[
                    aip.Measurement.Metric(metric_id=k, value=v)
                    for k, v in metrics.items()
                ],
            ),
        )
    )
    logging.info('Step %d Metric %s is reported', step, metrics)

  def complete_trial(self, infeasible_reason: Optional[str] = None) -> None:
    """Complete a trial.

    Args:
      infeasible_reason: When not None, the trial is reported as infeasible
        with this reason; otherwise it is reported as feasible.
    """
    self._vz_client.complete_trial(
        request=aip.CompleteTrialRequest(
            name=self._trial_name,
            trial_infeasible=infeasible_reason is not None,
            infeasible_reason=infeasible_reason,
        )
    )
    logging.info('Trial %s is completed', self._trial_name)
# ------------------------------------------------------------------------------
# /xmanager/xm/README.md:
-------------------------------------------------------------------------------- 1 | # `xm` 2 | 3 | This directory contains the Launch API base classes and common implementations 4 | that various experiment schedulers (local or provided as a service) and 5 | executors (Docker, Kubernetes, Vertex AI etc) implement. 6 | -------------------------------------------------------------------------------- /xmanager/xm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """XManager client API. 15 | 16 | Provides XManager public API for configuring and launching experiments. 
17 | """ 18 | 19 | from xmanager.xm import job_operators 20 | from xmanager.xm.compute_units import * 21 | from xmanager.xm.core import AuxiliaryUnitJob 22 | from xmanager.xm.core import AuxiliaryUnitRole 23 | from xmanager.xm.core import DebugInterrupt 24 | from xmanager.xm.core import Experiment 25 | from xmanager.xm.core import ExperimentUnit 26 | from xmanager.xm.core import ExperimentUnitError 27 | from xmanager.xm.core import ExperimentUnitFailedError 28 | from xmanager.xm.core import ExperimentUnitNotCompletedError 29 | from xmanager.xm.core import ExperimentUnitRole 30 | from xmanager.xm.core import ExperimentUnitStatus 31 | from xmanager.xm.core import Importance 32 | from xmanager.xm.core import LaunchedJob 33 | from xmanager.xm.core import NotFoundError 34 | from xmanager.xm.core import ReloadError 35 | from xmanager.xm.core import WorkUnit 36 | from xmanager.xm.core import WorkUnitCompletedAwaitable 37 | from xmanager.xm.core import WorkUnitRole 38 | from xmanager.xm.executables import BazelBinary 39 | from xmanager.xm.executables import BazelContainer 40 | from xmanager.xm.executables import Binary 41 | from xmanager.xm.executables import BinaryDependency 42 | from xmanager.xm.executables import CommandList 43 | from xmanager.xm.executables import Container 44 | from xmanager.xm.executables import Dockerfile 45 | from xmanager.xm.executables import ModuleName 46 | from xmanager.xm.executables import PythonContainer 47 | from xmanager.xm.job_blocks import Constraint 48 | from xmanager.xm.job_blocks import Executable 49 | from xmanager.xm.job_blocks import ExecutableSpec 50 | from xmanager.xm.job_blocks import Executor 51 | from xmanager.xm.job_blocks import ExecutorSpec 52 | from xmanager.xm.job_blocks import get_args_for_all_jobs 53 | from xmanager.xm.job_blocks import Job 54 | from xmanager.xm.job_blocks import JobConfig 55 | from xmanager.xm.job_blocks import JobGeneratorType 56 | from xmanager.xm.job_blocks import JobGroup 57 | from 
xmanager.xm.job_blocks import JobType 58 | from xmanager.xm.job_blocks import merge_args 59 | from xmanager.xm.job_blocks import Packageable 60 | from xmanager.xm.job_blocks import SequentialArgs 61 | from xmanager.xm.job_blocks import UserArgs 62 | from xmanager.xm.metadata_context import ContextAnnotations 63 | from xmanager.xm.metadata_context import MetadataContext 64 | from xmanager.xm.packagables import bazel_binary 65 | from xmanager.xm.packagables import bazel_container 66 | from xmanager.xm.packagables import binary 67 | from xmanager.xm.packagables import container 68 | from xmanager.xm.packagables import dockerfile_container 69 | from xmanager.xm.packagables import python_container 70 | from xmanager.xm.resources import AcceleratorType 71 | from xmanager.xm.resources import Architecture 72 | from xmanager.xm.resources import GpuType 73 | from xmanager.xm.resources import InvalidTpuTopologyError 74 | from xmanager.xm.resources import JobRequirements 75 | from xmanager.xm.resources import ResourceDict 76 | from xmanager.xm.resources import ResourceQuantity 77 | from xmanager.xm.resources import ResourceType 78 | from xmanager.xm.resources import ServiceTier 79 | from xmanager.xm.resources import Topology 80 | from xmanager.xm.resources import TpuType 81 | from xmanager.xm.utils import run_in_asyncio_loop 82 | from xmanager.xm.utils import ShellSafeArg 83 | -------------------------------------------------------------------------------- /xmanager/xm/async_packager_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
from typing import Sequence
import unittest

from xmanager.xm import async_packager
from xmanager.xm import job_blocks
from xmanager.xm import utils


def _package_batch(
    packageables: Sequence[job_blocks.Packageable],
) -> Sequence[job_blocks.Executable]:
  """Fake batch packager: one Executable per packageable, named after its spec."""
  executables = []
  for packageable in packageables:
    spec_name = packageable.executable_spec.name
    executables.append(job_blocks.Executable(name=spec_name))
  return executables


class _TestExecutableSpec(job_blocks.ExecutableSpec):
  """Minimal ExecutableSpec that only carries a name."""

  def __init__(self, name: str) -> None:
    self._name = name

  @property
  def name(self) -> str:
    return self._name


def _make_packageable(name: str) -> job_blocks.Packageable:
  """Builds a Packageable around a `_TestExecutableSpec` called `name`."""
  spec = _TestExecutableSpec(name)
  return job_blocks.Packageable(
      executable_spec=spec,
      executor_spec=job_blocks.ExecutorSpec(),
  )


class AsyncPackagerTest(unittest.TestCase):

  @utils.run_in_asyncio_loop
  async def test_async_packager_end_to_end(self):
    packager = async_packager.AsyncPackager(_package_batch)
    first = packager.add(_make_packageable('1'))
    second = packager.add(_make_packageable('2'))
    packager.package()
    self.assertEqual((await first).name, '1')
    self.assertEqual((await second).name, '2')

  @utils.run_in_asyncio_loop
  async def test_package_with_extra_packageables(self):
    packager = async_packager.AsyncPackager(_package_batch)
    pending = packager.add(_make_packageable('async'))
    # Extra packageables are packaged in the same call and returned directly.
    extra_executables = packager.package((_make_packageable('extra'),))
    [extra_executable] = extra_executables
    self.assertEqual((await pending).name, 'async')
    self.assertEqual(extra_executable.name, 'extra')

  @utils.run_in_asyncio_loop
  async def test_package_is_required(self):
    packager = async_packager.AsyncPackager(_package_batch)
    never_packaged = packager.add(_make_packageable(''))
    with self.assertRaises(async_packager.PackageHasNotBeenCalledError):
      await never_packaged

  def test_awaitable_is_picklable(self):
    packager = async_packager.AsyncPackager(_package_batch)
    executable = packager.add(_make_packageable(''))
    packager.package()
    pickled = pickle.dumps(executable)

    # Await the unpickled executable in a fresh event loop that did not exist
    # when packaging was requested.
    @utils.run_in_asyncio_loop
    async def await_unpickled():
      await pickle.loads(pickled)

    await_unpickled()

  def test_awaitable_is_repeatedly_picklable(self):
    packager = async_packager.AsyncPackager(_package_batch)
    executable = packager.add(_make_packageable(''))
    packager.package()
    pickled_once = pickle.dumps(executable)
    pickled_twice = pickle.dumps(pickle.loads(pickled_once))
    self.assertEqual(pickled_once, pickled_twice)


if __name__ == '__main__':
  unittest.main()
# ------------------------------------------------------------------------------
# /xmanager/xm/compute_units.py:
# ------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines convenience constants for converting various units."""

# pylint: disable=invalid-name
vCPU = 1.0  # Virtual CPU

# Binary (IEC) byte multiples: successive powers of 2**10.
# TODO: Explicit types are to work around a type inference issue.
KiB: int = 2**10  # kibibyte
MiB: int = 2**20  # mebibyte
GiB: int = 2**30  # gibibyte
TiB: int = 2**40  # tebibyte
PiB: int = 2**50  # pebibyte

# Decimal (SI) byte multiples: successive powers of 10**3. Annotated like the
# binary multiples above for consistency.
KB: int = 10**3  # kilobyte
MB: int = 10**6  # megabyte
GB: int = 10**9  # gigabyte
TB: int = 10**12  # terabyte
PB: int = 10**15  # petabyte
# pylint: enable=invalid-name
# ------------------------------------------------------------------------------
# /xmanager/xm/executables_test.py:
# ------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for executables."""

import os
import unittest

from xmanager.xm import executables
from xmanager.xm import utils


class ExecutablesTest(unittest.TestCase):
  """Checks the derived `name` of each executable spec and Dockerfile defaults."""

  def test_python_container_name(self):
    container = executables.PythonContainer(
        entrypoint=executables.ModuleName('module'),
        path='/home/user/project/',
    )
    self.assertEqual(container.name, 'project')

  def test_container_name(self):
    image = executables.Container(image_path='/home/user/project/image.tar')
    self.assertEqual(image.name, 'image_tar')

  def test_binary_name(self):
    binary = executables.Binary(path='./binary')
    self.assertEqual(binary.name, 'binary')

  def test_bazel_container_name(self):
    bazel_container = executables.BazelContainer(label='//container')
    self.assertEqual(bazel_container.name, 'container')

  def test_bazel_binary_name(self):
    bazel_binary = executables.BazelBinary(label=':binary')
    self.assertEqual(bazel_binary.name, '_binary')

  def test_dockerfile_defaults(self):
    launcher_dir = utils.resolve_path_relative_to_launcher('.')
    dockerfile_spec = executables.Dockerfile()

    self.assertEqual(dockerfile_spec.path, launcher_dir)
    expected_dockerfile = os.path.join(launcher_dir, 'Dockerfile')
    self.assertEqual(dockerfile_spec.dockerfile, expected_dockerfile)


if __name__ == '__main__':
  unittest.main()
# ------------------------------------------------------------------------------
# /xmanager/xm/id_predictor.py:
# ------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A utility to predict IDs that would be assigned upon object creation.

This module helps to untie a chicken-and-egg problem. Sometimes to create an
object (for example a WorkUnit) one may need to know its ID beforehand (for
example to generate a checkpoint path). But the ID would only be assigned by a
backend upon object creation. In an ideal world we would rewrite the backend
to allow ID reservation. This module provides a temporary solution which does
the reservation on the client side. The following assumptions are made:
  * IDs are assigned sequentially by the backend, starting from some number.
  * Only one process at a time creates the objects. Any races are resolved only
    within that process.

Usage:
  predictor = Predictor(next_id=1)

  # Obtain the ID and asynchronously construct the object.
  next_id = predictor.reserve_id()
  job = Job(args={'checkpoint_path': f'/tmp/{next_id}'})

  # Wait until all objects with smaller IDs are submitted.
  async with predictor.submit_id(next_id):
    # And submit it to the backend.
    submit(job)

If the submission fails, the sequence is considered broken and subsequent
`submit_id` calls raise `BrokenSequenceError`.
"""

import asyncio
import threading
from typing import AsyncIterable

import async_generator


class BrokenSequenceError(RuntimeError):
  """The ID would never be ready to submit."""


class Predictor:
  """Predicts IDs that would be assigned on object creation.

  `reserve_id` is guarded by a `threading.Lock`, so it may be called from
  non-async code. `submit_id` coordinates through an `asyncio.Condition` and
  must be used from inside an asyncio event loop. The predictor itself must be
  constructed from inside an asyncio event loop.
  """

  def __init__(self, next_id: int) -> None:
    """Initializes the predictor.

    Args:
      next_id: The first available ID that would be assigned to the next object.
    """
    self._next_id = next_id
    # We use threading.Lock to allow calling reserve_id from non async context.
    # Note that no long operations are done under this lock.
    self._next_id_lock = threading.Lock()

    self._is_broken = False
    self._last_created_id = next_id - 1
    self._last_created_id_condition = asyncio.Condition()

  def reserve_id(self) -> int:
    """Returns the next ID and advances the counter."""
    with self._next_id_lock:
      next_id = self._next_id
      self._next_id += 1
    return next_id

  @async_generator.asynccontextmanager
  async def submit_id(self, id_to_submit: int) -> AsyncIterable[None]:
    """Waits until the ID can be submitted and marks it as such.

    A context manager which would wait for all smaller IDs to be submitted on
    entry and marks it as submitted on exit. As a result all submissions are
    serialized in the correct order and receive the right ID from the backend.

    Args:
      id_to_submit: The id to be submitted.

    Yields:
      Yields when it is time to send the request to the backend.

    Raises:
      BrokenSequenceError: If a previous submission failed.
    """
    async with self._last_created_id_condition:
      await self._last_created_id_condition.wait_for(
          lambda: self._is_broken or self._last_created_id == id_to_submit - 1
      )

      try:
        if self._is_broken:
          # Interpolate the requested ID (previously `{id}` rendered the
          # repr of the builtin `id` function).
          raise BrokenSequenceError(
              f'Id {id_to_submit} would never be ready to submit as'
              ' submission of the previous one has failed'
          )
        yield
        self._last_created_id = id_to_submit
      except BaseException:
        # Any failure, including cancellation, breaks the sequence; waiters
        # for later IDs are woken below and fail instead of blocking forever.
        self._is_broken = True
        raise
      finally:
        # notify_all requires holding the condition's lock, hence this sits
        # inside the `async with` above.
        self._last_created_id_condition.notify_all()
# ------------------------------------------------------------------------------
# /xmanager/xm/id_predictor_test.py:
# ------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import unittest

from xmanager.xm import id_predictor
from xmanager.xm import utils


class IdPredictorTest(unittest.TestCase):
  """Exercises ID reservation, ordered submission and broken sequences."""

  @utils.run_in_asyncio_loop
  async def test_first_id_is_correct(self):
    """Simple Predictor usage example."""
    predictor = id_predictor.Predictor(next_id=1)

    reserved = predictor.reserve_id()
    async with predictor.submit_id(reserved):
      self.assertEqual(reserved, 1)

  @utils.run_in_asyncio_loop
  async def test_ids_are_submitted_in_order(self):
    predictor = id_predictor.Predictor(next_id=1)

    reserved = [predictor.reserve_id() for _ in range(3)]
    self.assertEqual(reserved, [1, 2, 3])

    submission_order = []

    async def submit_and_record(some_id):
      async with predictor.submit_id(some_id):
        submission_order.append(some_id)

    # Start submissions in reverse; the predictor must still serialize them.
    await asyncio.gather(*[submit_and_record(i) for i in (3, 2, 1)])

    self.assertEqual(submission_order, [1, 2, 3])

  @utils.run_in_asyncio_loop
  async def test_broken_sequence(self):
    predictor = id_predictor.Predictor(next_id=1)

    reserved = [predictor.reserve_id() for _ in range(2)]
    self.assertEqual(reserved, [1, 2])

    with self.assertRaises(RuntimeError):
      async with predictor.submit_id(1):
        raise RuntimeError('Id was eaten by a giant space ant.')

    # The failure above breaks the sequence for every later ID.
    with self.assertRaises(id_predictor.BrokenSequenceError):
      async with predictor.submit_id(2):
        pass


if __name__ == '__main__':
  unittest.main()
# ------------------------------------------------------------------------------
# /xmanager/xm/job_operators_test.py:
# ------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in
# compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest import mock
import uuid

from xmanager import xm_mock
from xmanager.xm import job_blocks
from xmanager.xm import job_operators


TEST_UUID = uuid.UUID('40c2fc75-9424-4e33-a929-2e0cc631dccf')


def construct_job(name=None):
  """Returns a Job with mock executable/executor and the given name."""
  executable = xm_mock.MockExecutable()
  executor = xm_mock.MockExecutor()
  return job_blocks.Job(name=name, executable=executable, executor=executor)


class JobOperatorsTest(unittest.TestCase):
  """Tests for filtering, flattening and constraint-clique aggregation."""

  def test_collect_jobs_by_filter_gathers_matches(self):
    job_group = job_blocks.JobGroup(
        foo=construct_job('foo'),
        bar=construct_job('bar'),
        baz=construct_job('baz'),
    )

    wanted_names = ['foo', 'baz']
    matches = job_operators.collect_jobs_by_filter(
        job_group,
        predicate=lambda job: job.name in wanted_names,
    )
    expected = [job_group.jobs['foo'], job_group.jobs['baz']]
    self.assertEqual(matches, expected)

  def test_flatten_jobs_traverses_nested_groups(self):
    baz = construct_job('baz')
    foo = construct_job('foo')
    nested = job_blocks.JobGroup(baz=baz)
    job_group = job_blocks.JobGroup(foo=foo, bar=nested)

    flattened = job_operators.flatten_jobs(job_group)
    self.assertEqual(flattened, [foo, baz])

  @mock.patch.object(uuid, 'uuid4', return_value=TEST_UUID)
  def test_aggregate_constraint_cliques(self, _):
    outer_1 = construct_job('outer_1')
    inner_1 = construct_job('inner_1')
    inner_2 = construct_job('inner_2')
    constraint_a = xm_mock.MockConstraint('A')
    constraint_b = xm_mock.MockConstraint('B')
    constraint_c = xm_mock.MockConstraint('C')
    inner_group = job_blocks.JobGroup(
        inner_1=inner_1,
        inner_2=inner_2,
        constraints=[constraint_b, constraint_c],
    )
    job_group = job_blocks.JobGroup(
        outer_1=outer_1,
        outer_2=inner_group,
        constraints=[constraint_a],
    )

    cliques = job_operators.aggregate_constraint_cliques(job_group)

    # uuid4 is patched, so group names are deterministic.
    outer_group_name = f'jobgroup_0_{TEST_UUID.hex}'
    inner_group_name = f'jobgroup_1_{TEST_UUID.hex}'
    clique = job_operators.ConstraintClique
    expected = [
        clique(
            constraint=constraint_a,
            jobs=[outer_1, inner_1, inner_2],
            group_name=outer_group_name,
            size=2,
            parent_group_name=None,
        ),
        clique(
            constraint=constraint_b,
            jobs=[inner_1, inner_2],
            group_name=inner_group_name,
            size=2,
            parent_group_name=outer_group_name,
        ),
        clique(
            constraint=constraint_c,
            jobs=[inner_1, inner_2],
            group_name=inner_group_name,
            size=2,
            parent_group_name=outer_group_name,
        ),
    ]
    self.assertEqual(cliques, expected)


if __name__ == '__main__':
  unittest.main()
# ------------------------------------------------------------------------------
# /xmanager/xm/metadata_context.py:
# ------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """An interface to manipulate and access experiment metadata. 15 | 16 | Metadata is attached to a context and the context may belong to an experiment 17 | or a work unit. 18 | """ 19 | 20 | from typing import Any, Collection, List, Mapping, Set 21 | import attr 22 | 23 | 24 | class ContextAnnotations: 25 | """Interface for managing annotations. 26 | 27 | Annotations are user-supplied attributes of a context, such as title or tags. 28 | Default method implementations are intentionally left blank so that backends 29 | only have to implement the subset they support. 30 | """ 31 | 32 | @property 33 | def title(self) -> str: 34 | """An experiment title. 35 | 36 | To differentiate experiments from each other they can be given a human 37 | readable title. Same title can be reused for multiple experiments. 38 | """ 39 | return '' 40 | 41 | def set_title(self, title: str) -> None: 42 | """Sets the context title.""" 43 | 44 | 45 | @attr.s(auto_attribs=True) 46 | class MetadataContext: 47 | """Interface for managing metadata. 48 | 49 | The metadata context could be attached to an experiment or a work unit. 50 | 51 | 52 | Attributes: 53 | creator: The username of the creator of this context. 54 | annotations: User-modifiable annotations. 55 | """ 56 | 57 | creator: str 58 | annotations: ContextAnnotations 59 | -------------------------------------------------------------------------------- /xmanager/xm/packagables_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for packagables.""" 15 | 16 | import unittest 17 | 18 | from xmanager.xm import executables 19 | from xmanager.xm import job_blocks 20 | from xmanager.xm import packagables 21 | from xmanager.xm_local import executors 22 | 23 | 24 | class PackagablesTest(unittest.TestCase): 25 | 26 | def test_minimal_executable_spec(self): 27 | expected = job_blocks.Packageable( 28 | executable_spec=executables.BazelBinary(label='label'), 29 | executor_spec=executors.Local.Spec(), 30 | args=[], 31 | env_vars={}, 32 | ) 33 | 34 | actual = packagables.bazel_binary(executors.Local.Spec(), label='label') 35 | 36 | self.assertEqual(actual, expected) 37 | 38 | def test_pkg_args_env_vars(self): 39 | expected = job_blocks.Packageable( 40 | executable_spec=executables.BazelBinary(label='label'), 41 | executor_spec=executors.Local.Spec(), 42 | args=['-f'], 43 | env_vars={'KEY': 'value'}, 44 | ) 45 | 46 | actual = packagables.bazel_binary( 47 | executors.Local.Spec(), 48 | label='label', 49 | args=['-f'], 50 | env_vars={'KEY': 'value'}, 51 | ) 52 | 53 | self.assertEqual(actual, expected) 54 | 55 | 56 | if __name__ == '__main__': 57 | unittest.main() 58 | -------------------------------------------------------------------------------- /xmanager/xm/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import enum 16 | import unittest 17 | 18 | from xmanager.xm import utils 19 | 20 | 21 | async def make_me_a_sandwich() -> str: 22 | return 'sandwich' 23 | 24 | 25 | class ResourceType(enum.Enum): 26 | MINERALS = 1 27 | VESPEN = 2 28 | 29 | 30 | class UtilsTest(unittest.TestCase): 31 | 32 | @utils.run_in_asyncio_loop 33 | async def test_run_in_asyncio_loop(self): 34 | self.assertEqual(await make_me_a_sandwich(), 'sandwich') 35 | 36 | def test_run_in_asyncio_loop_returns_value(self): 37 | self.assertEqual( 38 | utils.run_in_asyncio_loop(make_me_a_sandwich)(), 'sandwich' 39 | ) 40 | 41 | def test_arg_escaper(self): 42 | self.assertEqual(utils.ARG_ESCAPER(1.0), '1.0') 43 | self.assertEqual(utils.ARG_ESCAPER('Jonny Droptable'), "'Jonny Droptable'") 44 | self.assertEqual(utils.ARG_ESCAPER(ResourceType.VESPEN), 'VESPEN') 45 | 46 | def test_shell_safe_arg_in_f_string(self): 47 | # ShellSafeArg shouldn't be used in f-strings. 48 | with self.assertRaises(RuntimeError): 49 | f'{utils.ShellSafeArg("42")}' # pylint: disable=expression-not-assigned 50 | 51 | 52 | if __name__ == '__main__': 53 | unittest.main() 54 | -------------------------------------------------------------------------------- /xmanager/xm_local/README.md: -------------------------------------------------------------------------------- 1 | # `xm_local` 2 | 3 | This directory contains the Launch API scheduler components for running 4 | experiments using the Local Scheduler. 
5 | -------------------------------------------------------------------------------- /xmanager/xm_local/executables.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Local backend executables.""" 15 | 16 | from typing import Dict 17 | 18 | import attr 19 | from xmanager import xm 20 | 21 | 22 | @attr.s(auto_attribs=True) 23 | class LoadedContainerImage(xm.Executable): 24 | """A locally loaded container image.""" 25 | 26 | image_id: str 27 | args: xm.SequentialArgs = attr.Factory(xm.SequentialArgs) 28 | env_vars: Dict[str, str] = attr.Factory(dict) 29 | 30 | 31 | @attr.s(auto_attribs=True) 32 | class LocalBinary(xm.Executable): 33 | """A locally located binary.""" 34 | 35 | command: str 36 | args: xm.SequentialArgs = attr.Factory(xm.SequentialArgs) 37 | env_vars: Dict[str, str] = attr.Factory(dict) 38 | 39 | 40 | @attr.s(auto_attribs=True) 41 | class GoogleContainerRegistryImage(xm.Executable): 42 | """An image inside Google Container Registry.""" 43 | 44 | image_path: str 45 | args: xm.SequentialArgs = attr.Factory(xm.SequentialArgs) 46 | env_vars: Dict[str, str] = attr.Factory(dict) 47 | -------------------------------------------------------------------------------- /xmanager/xm_local/executors_test.py: -------------------------------------------------------------------------------- 1 | 
"""Tests for xmanager.xm_local.executors.""" 2 | 3 | import unittest 4 | from unittest import mock 5 | 6 | from xmanager.xm_local import executors 7 | 8 | 9 | class ExecutorsTest(unittest.TestCase): 10 | 11 | @mock.patch.object(executors, 'importlib') 12 | def test_executor_missing_required_module(self, mock_importlib): 13 | mock_importlib.import_module.side_effect = ModuleNotFoundError( 14 | "No module named 'xmanager.cloud.kubernetes'" 15 | ) 16 | with self.assertRaises(ModuleNotFoundError): 17 | executors.Kubernetes() 18 | 19 | 20 | if __name__ == '__main__': 21 | unittest.main() 22 | -------------------------------------------------------------------------------- /xmanager/xm_local/handles.py: -------------------------------------------------------------------------------- 1 | """Local execution handles.""" 2 | 3 | import abc 4 | import asyncio 5 | from concurrent import futures 6 | import logging 7 | 8 | import attr 9 | import docker 10 | from docker.models import containers 11 | from xmanager.xm_local import status 12 | 13 | 14 | _DEFAULT_ENCODING = 'utf-8' 15 | 16 | 17 | def _print_chunk(name: str, line: str) -> None: 18 | print('[{}] {}'.format(name, line.strip())) 19 | 20 | 21 | class ExecutionHandle(abc.ABC): 22 | """An interface for operating on executions.""" 23 | 24 | @abc.abstractmethod 25 | async def wait(self) -> None: 26 | raise NotImplementedError 27 | 28 | @abc.abstractmethod 29 | def get_status(self) -> status.LocalWorkUnitStatus: 30 | """Aggregates the statuses of all jobs in the work unit into one status.""" 31 | raise NotImplementedError 32 | 33 | def save_to_storage(self, experiment_id: int, work_unit_id: int) -> None: 34 | """Saves the handle to the local database.""" 35 | raise NotImplementedError 36 | 37 | def stop(self) -> None: 38 | """Stops execution.""" 39 | raise NotImplementedError 40 | 41 | 42 | class LocalExecutionHandle(ExecutionHandle, abc.ABC): 43 | """An interface for operating on local executions.""" 44 | 45 | 
@abc.abstractmethod 46 | async def monitor(self) -> None: 47 | raise NotImplementedError 48 | 49 | @abc.abstractmethod 50 | def terminate(self) -> None: 51 | raise NotImplementedError 52 | 53 | 54 | @attr.s(auto_attribs=True) 55 | class ContainerHandle(LocalExecutionHandle): 56 | """A handle for referring to the launched container.""" 57 | 58 | name: str 59 | model: containers.Container | None 60 | stream_output: bool 61 | futures_executor: futures.Executor = attr.Factory(futures.ThreadPoolExecutor) 62 | 63 | async def wait(self) -> None: 64 | if self.model is None: 65 | return 66 | 67 | def _wait() -> None: 68 | try: 69 | self.model.wait() 70 | except docker.errors.NotFound: 71 | logging.info( 72 | 'Container %s not found (it may have already been removed).', 73 | self.model.name, 74 | ) 75 | 76 | await asyncio.wrap_future(self.futures_executor.submit(_wait)) 77 | 78 | def get_status(self) -> status.LocalWorkUnitStatus: 79 | raise NotImplementedError 80 | 81 | def terminate(self) -> None: 82 | if self.model is None: 83 | return 84 | 85 | self.model.stop() 86 | self.futures_executor.shutdown(wait=True) 87 | 88 | async def monitor(self) -> None: 89 | if self.model is None: 90 | return 91 | 92 | def _stream_chunks() -> None: 93 | try: 94 | for chunk in self.model.logs(stream=True, follow=True): 95 | _print_chunk(self.name, chunk.decode(_DEFAULT_ENCODING)) 96 | except docker.errors.NotFound: 97 | logging.info( 98 | 'Container %s not found (it may have already been removed).', 99 | self.model.name, 100 | ) 101 | 102 | if self.stream_output: 103 | await asyncio.wrap_future(self.futures_executor.submit(_stream_chunks)) 104 | 105 | 106 | @attr.s(auto_attribs=True) 107 | class BinaryHandle(LocalExecutionHandle): 108 | """A handle referring to the launched binary.""" 109 | 110 | name: str 111 | process: asyncio.subprocess.Process # pytype: disable=module-attr 112 | stream_output: bool 113 | 114 | async def wait(self) -> None: 115 | return_code = await self.process.wait() 
116 | if return_code != 0: 117 | raise RuntimeError( 118 | f'Process {self.process!r} returned non-zero code: {return_code}' 119 | ) 120 | 121 | def get_status(self) -> status.LocalWorkUnitStatus: 122 | raise NotImplementedError 123 | 124 | def terminate(self) -> None: 125 | self.process.terminate() 126 | 127 | async def monitor(self) -> None: 128 | if self.stream_output: 129 | if not self.process.stdout: 130 | raise ValueError( 131 | 'No stdout available from process. Cannot stream output.' 132 | ) 133 | while True: 134 | line = await self.process.stdout.readline() 135 | if not line: 136 | break 137 | _print_chunk(self.name, line.decode(_DEFAULT_ENCODING)) 138 | -------------------------------------------------------------------------------- /xmanager/xm_local/packaging/bazel_tools_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import unittest 16 | 17 | from xmanager.xm_local.packaging import bazel_tools 18 | 19 | 20 | class BazelToolsTest(unittest.TestCase): 21 | 22 | def test_lex_full_label(self): 23 | self.assertEqual( 24 | bazel_tools._lex_label('//project/directory:target'), 25 | (['project', 'directory'], 'target'), 26 | ) 27 | 28 | def test_lex_short_label(self): 29 | self.assertEqual( 30 | bazel_tools._lex_label('//project/package'), 31 | (['project', 'package'], 'package'), 32 | ) 33 | 34 | def test_lex_root_target(self): 35 | self.assertEqual(bazel_tools._lex_label('//:label'), ([], 'label')) 36 | 37 | def test_lex_empty_label(self): 38 | with self.assertRaises(ValueError): 39 | bazel_tools._lex_label('//') 40 | 41 | def test_lex_relative_label(self): 42 | with self.assertRaises(ValueError): 43 | bazel_tools._lex_label('a/b:c') 44 | 45 | def test_assemble_label(self): 46 | self.assertEqual(bazel_tools._assemble_label((['a', 'b'], 'c')), '//a/b:c') 47 | 48 | def test_label_kind_lines_to_dict(self): 49 | self.assertEqual( 50 | bazel_tools._label_kind_lines_to_dict([ 51 | 'py_binary rule //:py_target', 52 | 'cc_binary rule //:cc_target', 53 | ]), 54 | {'//:py_target': 'py_binary rule', '//:cc_target': 'cc_binary rule'}, 55 | ) 56 | 57 | def test_absolute_label_with_extension_dot(self): 58 | self.assertEqual( 59 | bazel_tools._lex_label('//project/directory:image.tar'), 60 | (['project', 'directory'], 'image.tar'), 61 | ) 62 | 63 | def test_label_with_three_dots(self): 64 | with self.assertRaisesRegex(ValueError, 'is not an absolute Bazel label'): 65 | bazel_tools._lex_label('//project/directory/...') 66 | 67 | def test_label_with_star_target(self): 68 | with self.assertRaisesRegex(ValueError, 'is not an absolute Bazel label'): 69 | bazel_tools._lex_label('//project/directory:*') 70 | 71 | def test_label_with_all_target(self): 72 | with self.assertRaisesRegex(ValueError, '`:all` is not a valid target'): 73 | bazel_tools._lex_label('//project/directory:all') 74 | 75 | 76 | 
if __name__ == '__main__': 77 | unittest.main() 78 | -------------------------------------------------------------------------------- /xmanager/xm_local/packaging/router.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Methods for routing packageables to appropriate packagers.""" 15 | 16 | import collections 17 | from typing import Dict, List, Sequence, Tuple 18 | 19 | from xmanager import xm 20 | from xmanager.bazel import client as bazel_client 21 | from xmanager.xm_local import executors 22 | from xmanager.xm_local.packaging import bazel_tools 23 | from xmanager.xm_local.packaging import cloud as cloud_packaging 24 | from xmanager.xm_local.packaging import local as local_packaging 25 | 26 | 27 | def _packaging_router( 28 | built_targets: bazel_tools.TargetOutputs, packageable: xm.Packageable 29 | ) -> xm.Executable: 30 | match packageable.executor_spec: 31 | case executors.VertexSpec(): 32 | return cloud_packaging.package_cloud_executable( 33 | built_targets, 34 | packageable, 35 | packageable.executable_spec, 36 | ) 37 | case executors.LocalSpec(): 38 | return local_packaging.package_for_local_executor( 39 | built_targets, 40 | packageable, 41 | packageable.executable_spec, 42 | ) 43 | case executors.KubernetesSpec(): 44 | return cloud_packaging.package_cloud_executable( 45 | 
built_targets, 46 | packageable, 47 | packageable.executable_spec, 48 | ) 49 | case _: 50 | raise TypeError( 51 | f'Unsupported executor specification: {packageable.executor_spec!r}. ' 52 | f'Packageable: {packageable!r}' 53 | ) 54 | 55 | 56 | _ArgsToTargets = Dict[Tuple[str, ...], List[bazel_client.BazelTarget]] 57 | 58 | 59 | def package(packageables: Sequence[xm.Packageable]) -> List[xm.Executable]: 60 | """Routes a packageable to an appropriate packaging mechanism.""" 61 | built_targets: bazel_tools.TargetOutputs = {} 62 | bazel_targets = bazel_tools.collect_bazel_targets(packageables) 63 | 64 | if bazel_targets: 65 | bazel_service = bazel_tools.local_bazel_service() 66 | 67 | args_to_targets: _ArgsToTargets = collections.defaultdict(list) 68 | for target in bazel_targets: 69 | args_to_targets[target.bazel_args].append(target) 70 | for args, targets in args_to_targets.items(): 71 | outputs = bazel_service.build_targets( 72 | labels=tuple(target.label for target in targets), 73 | bazel_args=args, 74 | ) 75 | for target, output in zip(targets, outputs): 76 | built_targets[target] = output 77 | 78 | return [ 79 | _packaging_router(built_targets, packageable) 80 | for packageable in packageables 81 | ] 82 | -------------------------------------------------------------------------------- /xmanager/xm_local/registry.py: -------------------------------------------------------------------------------- 1 | """Registry of execution logic for each executor type. 2 | 3 | Execution logic is automatically registered when the corresponding Executor 4 | class is instantiated. 5 | 6 | The registry allows users to avoid heavy dependencies for executors that are not 7 | used in the experiment (i.e. heavy Kubernetes deps for the `xm_local.Kubernetes` 8 | executor). 
9 | """ 10 | 11 | from typing import Awaitable, Callable, Generic, Sequence, Type, TypeVar 12 | 13 | import attr 14 | from xmanager import xm 15 | from xmanager.xm_local import handles 16 | 17 | 18 | _Handle = TypeVar('_Handle', bound=handles.ExecutionHandle) 19 | 20 | 21 | @attr.s(auto_attribs=True) 22 | class _ExecutorInfo(Generic[_Handle]): 23 | # Method to launch a job group and get the execution handles. 24 | launch: Callable[..., Awaitable[Sequence[_Handle]]] 25 | # Method to create an execution handle using data from the local database. 26 | create_handle: Callable[..., _Handle] | None = None 27 | 28 | 29 | _REGISTRY: dict[Type[xm.Executor], _ExecutorInfo] = {} 30 | 31 | 32 | def register( 33 | executor_type: Type[xm.Executor], 34 | launch: Callable[..., Awaitable[Sequence[_Handle]]], 35 | create_handle: Callable[..., _Handle] | None = None, 36 | ): 37 | _REGISTRY[executor_type] = _ExecutorInfo( 38 | launch=launch, create_handle=create_handle 39 | ) 40 | 41 | 42 | def is_registered(executor_type: Type[xm.Executor]) -> bool: 43 | return executor_type in _REGISTRY 44 | 45 | 46 | def get_launch_method( 47 | executor_type: Type[xm.Executor], 48 | ) -> Callable[..., Awaitable[Sequence[_Handle]]]: 49 | return _REGISTRY[executor_type].launch 50 | 51 | 52 | def get_create_handle_method( 53 | executor_type: Type[xm.Executor], 54 | ) -> Callable[..., _Handle] | None: 55 | return _REGISTRY[executor_type].create_handle 56 | -------------------------------------------------------------------------------- /xmanager/xm_local/status.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Implementation of local work unit statuses.""" 15 | import enum 16 | 17 | from xmanager import xm 18 | 19 | 20 | class LocalWorkUnitStatusEnum(enum.Enum): 21 | """Status of a local experiment job.""" 22 | 23 | # Work unit was created, but has not terminated yet. 24 | RUNNING = 1 25 | # Work unit terminated and was successful. 26 | COMPLETED = 2 27 | # Work unit terminated and was not successful. 28 | FAILED = 3 29 | # Work unit terminated because it was cancelled by the user. 30 | CANCELLED = 4 31 | 32 | 33 | class LocalWorkUnitStatus(xm.ExperimentUnitStatus): 34 | """Status of a local experiment job.""" 35 | 36 | def __init__( 37 | self, status: LocalWorkUnitStatusEnum, message: str = '' 38 | ) -> None: 39 | super().__init__() 40 | self._status = status 41 | self._message = message 42 | 43 | @property 44 | def is_active(self) -> bool: 45 | return self._status == LocalWorkUnitStatusEnum.RUNNING 46 | 47 | @property 48 | def is_completed(self) -> bool: 49 | return self._status == LocalWorkUnitStatusEnum.COMPLETED 50 | 51 | @property 52 | def is_failed(self) -> bool: 53 | return self._status == LocalWorkUnitStatusEnum.FAILED 54 | 55 | @property 56 | def message(self) -> str: 57 | return self._message 58 | -------------------------------------------------------------------------------- /xmanager/xm_local/storage/alembic.ini: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the
"License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # A generic, single database configuration. 16 | 17 | [alembic] 18 | # path to migration scripts 19 | script_location = SET_AT_RUNTIME_BY_XMANAGER 20 | 21 | # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s 22 | # Uncomment the line below if you want the files to be prepended with date and time 23 | # see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file 24 | # for all available tokens 25 | # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s 26 | 27 | # sys.path path, will be prepended to sys.path if present. 28 | # defaults to the current working directory. 29 | prepend_sys_path = . 30 | 31 | # timezone to use when rendering the date within the migration file 32 | # as well as the filename. 
33 | # If specified, requires the python-dateutil library that can be 34 | # installed by adding `alembic[tz]` to the pip requirements 35 | # string value is passed to dateutil.tz.gettz() 36 | # leave blank for localtime 37 | # timezone = 38 | 39 | # max length of characters to apply to the 40 | # "slug" field 41 | # truncate_slug_length = 40 42 | 43 | # set to 'true' to run the environment during 44 | # the 'revision' command, regardless of autogenerate 45 | # revision_environment = false 46 | 47 | # set to 'true' to allow .pyc and .pyo files without 48 | # a source .py file to be detected as revisions in the 49 | # versions/ directory 50 | # sourceless = false 51 | 52 | # version location specification; This defaults 53 | # to alembic/versions. When using multiple version 54 | # directories, initial revisions must be specified with --version-path. 55 | # The path separator used here should be the separator specified by "version_path_separator" below. 56 | # version_locations = %(here)s/bar:%(here)s/bat:alembic/versions 57 | 58 | # version path separator; As mentioned above, this is the character used to split 59 | # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. 60 | # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. 61 | # Valid values for version_path_separator are: 62 | # 63 | # version_path_separator = : 64 | # version_path_separator = ; 65 | # version_path_separator = space 66 | version_path_separator = os # Use os.pathsep. Default configuration used for new projects. 67 | 68 | # the output encoding used when revision files 69 | # are written from script.py.mako 70 | # output_encoding = utf-8 71 | 72 | sqlalchemy.url = SET_AT_RUNTIME_BY_XMANAGER 73 | 74 | 75 | [post_write_hooks] 76 | # post_write_hooks defines scripts or Python functions that are run 77 | # on newly generated revision scripts. 
See the documentation for further 78 | # detail and examples 79 | 80 | # format using "black" - use the console_scripts runner, against the "black" entrypoint 81 | # hooks = black 82 | # black.type = console_scripts 83 | # black.entrypoint = black 84 | # black.options = -l 79 REVISION_SCRIPT_FILENAME 85 | -------------------------------------------------------------------------------- /xmanager/xm_local/storage/alembic/README.md: -------------------------------------------------------------------------------- 1 | # Database Migration 2 | 3 | ## Alembic 4 | 5 | XManager updates may introduce changes to the schema of the database 6 | used to store experiment metadata, so an existing database must be migrated 7 | to the new schema. This is achieved by using 8 | [`alembic`](https://alembic.sqlalchemy.org/en/latest/) to perform migrations. 9 | 10 | XManager comes with an already configured `alembic` environment in the 11 | `xm_local.storage` package. Migration scripts are provided in the 12 | `alembic/versions` directory present there. 13 | 14 | A new revision can be created using the `alembic revision` command: 15 | 16 | ``` 17 | $ cd xmanager/xm_local/storage 18 | $ alembic revision -m "Create Example table." 19 | Generating ${PATH_TO_XMANAGER}/xm_local/storage/alembic/versions/106c21f078d5_create_example_table.py ... done 20 | ``` 21 | 22 | A new file `106c21f078d5_create_example_table.py` is generated. This file 23 | contains details about the migration, like the revision ID `106c21f078d5` and 24 | the `upgrade` and `downgrade` functions. You must populate these functions 25 | with directives that apply the changes to the database. 26 | 27 | ```py 28 | """Create Example table. 29 | 30 | Revision ID: 106c21f078d5 31 | Revises: 5cbe03fe7ed1 32 | Create Date: 2022-09-20 09:23:48.029756 33 | 34 | """ 35 | from alembic import op 36 | import sqlalchemy as sa 37 | 38 | 39 | # revision identifiers, used by Alembic.
40 | revision = '106c21f078d5' 41 | down_revision = '5cbe03fe7ed1' 42 | branch_labels = None 43 | depends_on = None 44 | 45 | 46 | def upgrade() -> None: 47 | pass 48 | 49 | 50 | def downgrade() -> None: 51 | pass 52 | ``` 53 | 54 | We can add some directives to our script, e.g. adding a new table `Example`: 55 | 56 | ```py 57 | def upgrade() -> None: 58 | op.create_table( 59 | 'Example', 60 | sa.Column('Id', sa.Integer, primary_key=True) 61 | ) 62 | 63 | def downgrade() -> None: 64 | op.drop_table('Example') 65 | ``` 66 | 67 | Launching XManager with the `--xm_upgrade_db` flag will now attempt to upgrade 68 | the database. For example, running 69 | 70 | ```sh 71 | $ xmanager launch examples/cifar10_tensorflow/launcher.py -- --xm_upgrade_db 72 | ``` 73 | will perform the upgrade and print information about it: 74 | 75 | ``` 76 | INFO [alembic.runtime.migration] Context impl SQLiteImpl. 77 | INFO [alembic.runtime.migration] Will assume non-transactional DDL. 78 | INFO [alembic.runtime.migration] Running upgrade 5cbe03fe7ed1 -> 106c21f078d5, Create Example table. 79 | ``` 80 | 81 | For more details on how `alembic` works and how it can be used, check out the 82 | [tutorial](https://alembic.sqlalchemy.org/en/latest/tutorial.html). 83 | 84 | If multiple developers work on migrations, make sure to sync with the latest 85 | version of the code, check for multiple heads or other issues in the 86 | `alembic` history, and create a merge revision if needed. Check 87 | [this tutorial](https://alembic.sqlalchemy.org/en/latest/branches.html#working-with-branches) 88 | for more details on branching. 89 | 90 | NOTE: User-created migrations are not officially supported by XManager. 91 | When XManager is updated, users are responsible for reconciling them with new revisions.
92 | 93 | ## Automatic Migration on Update 94 | 95 | If XManager is updated and a new version of the database schema is available (a new 96 | migration script is provided), the currently used database 97 | (local SQLite or specified through the YAML config file) must be upgraded 98 | before XManager can run. 99 | 100 | When using `--xm_upgrade_db`, XManager attempts to update the database 101 | to the latest version. Taking a backup of the database 102 | before using this flag is recommended, since migrations may fail. 103 | 104 | This does not apply when a new database is used: 105 | the initial tables are created automatically for a new database, without 106 | needing the flag. 107 | 108 | NOTE: Making changes to the database from outside of the XManager client 109 | (e.g. SQL queries or manual `alembic` upgrades/downgrades) is discouraged and unsupported. 110 | All database operations should be performed 111 | automatically by the client. -------------------------------------------------------------------------------- /xmanager/xm_local/storage/alembic/env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | """Alembic env.py.""" 15 | 16 | from alembic import context 17 | from sqlalchemy import engine_from_config 18 | from sqlalchemy import pool 19 | 20 | config = context.config 21 | 22 | target_metadata = None 23 | 24 | 25 | def run_migrations_offline() -> None: 26 | """Run migrations in 'offline' mode. 27 | 28 | This configures the context with just a URL 29 | and not an Engine, though an Engine is acceptable 30 | here as well. By skipping the Engine creation 31 | we don't even need a DBAPI to be available. 32 | """ 33 | url = config.get_main_option('sqlalchemy.url') 34 | context.configure( 35 | url=url, 36 | target_metadata=target_metadata, 37 | literal_binds=True, 38 | dialect_opts={'paramstyle': 'named'}, 39 | ) 40 | 41 | with context.begin_transaction(): 42 | context.run_migrations() 43 | 44 | 45 | def run_migrations_online() -> None: 46 | """Run migrations in 'online' mode. 47 | 48 | In this scenario we need to create an Engine 49 | and associate a connection with the context. 50 | """ 51 | connectable = config.attributes.get('connection', None) 52 | 53 | if connectable is None: 54 | connectable = engine_from_config( 55 | config.get_section(config.config_ini_section), 56 | prefix='sqlalchemy.', 57 | poolclass=pool.NullPool, 58 | ) 59 | 60 | with connectable.connect() as connection: 61 | context.configure(connection=connection, target_metadata=target_metadata) 62 | 63 | with context.begin_transaction(): 64 | context.run_migrations() 65 | else: 66 | context.configure(connection=connectable, target_metadata=target_metadata) 67 | 68 | with context.begin_transaction(): 69 | context.run_migrations() 70 | 71 | 72 | if context.is_offline_mode(): 73 | run_migrations_offline() 74 | else: 75 | run_migrations_online() 76 | -------------------------------------------------------------------------------- /xmanager/xm_local/storage/alembic/script.py.mako: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind 
Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Alembic-generated migration file names don't follow the Python module naming convention 16 | # pylint: disable=invalid-name 17 | """${message} 18 | 19 | Revision ID: ${up_revision} 20 | Revises: ${down_revision | comma,n} 21 | Create Date: ${create_date} 22 | 23 | """ 24 | from alembic import op 25 | import sqlalchemy as sa 26 | ${imports if imports else ""} 27 | 28 | # revision identifiers, used by Alembic. 29 | revision = ${repr(up_revision)} 30 | down_revision = ${repr(down_revision)} 31 | branch_labels = ${repr(branch_labels)} 32 | depends_on = ${repr(depends_on)} 33 | 34 | 35 | def upgrade() -> None: 36 | """Upgrades DB.""" 37 | ${upgrades if upgrades else "pass"} 38 | 39 | 40 | def downgrade() -> None: 41 | """Downgrades DB.""" 42 | ${downgrades if downgrades else "pass"} 43 | -------------------------------------------------------------------------------- /xmanager/xm_local/storage/data.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2021 DeepMind Technologies Limited 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | syntax = "proto3"; 16 | 17 | package xmanager; 18 | 19 | message Job {  // Backend-specific identity of a job; at most one `kind` variant is set. 20 | oneof kind { 21 | LocalJob local = 1; 22 | AIPlatformJob caip = 2; 23 | KubernetesJob kubernetes = 3; 24 | } 25 | } 26 | 27 | message LocalJob { 28 | // The identity of the job as identified by `ps -o pid,cmd,etime`. 29 | // etime is then converted to a Unix timestamp (the `timestamp` field). 30 | string pid = 1;  // `pid` column of `ps`. 31 | string cmd = 2;  // `cmd` column of `ps`. 32 | int64 timestamp = 3;  // Unix timestamp derived from `etime`. 33 | } 34 | 35 | message AIPlatformJob { 36 | // The resource name of the job, in the form: 37 | // projects/{project}/locations/{location}/customJobs/{customJobId} 38 | string resource_name = 1; 39 | } 40 | 41 | message KubernetesJob { 42 | // Namespace of the Kubernetes job. 43 | string namespace = 1; 44 | // Name of the Kubernetes job. 45 | string job_name = 2; 46 | } 47 | --------------------------------------------------------------------------------