├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── __init__.py
├── colab_codelab.ipynb
├── compile_pb.sh
├── docs
├── executable_specs.md
├── executors.md
├── metadata_storage.md
├── oss_xm_scope.md
├── parameter_controller.md
├── tensorboard.md
└── xm_launch_api_principles.md
├── examples
├── cifar10_tensorflow
│ ├── cifar10.py
│ ├── launcher.py
│ └── requirements.txt
├── cifar10_tensorflow_k8s_multiworker
│ ├── cifar10.py
│ ├── launcher.py
│ └── requirements.txt
├── cifar10_tensorflow_k8s_ps
│ ├── cifar10.py
│ ├── launcher.py
│ └── requirements.txt
├── cifar10_tensorflow_k8s_tensorboard
│ ├── cifar10.py
│ ├── launcher.py
│ └── requirements.txt
├── cifar10_tensorflow_tpu
│ ├── cifar10.py
│ ├── launcher.py
│ └── requirements.txt
├── cifar10_torch
│ ├── cifar10.py
│ ├── launcher.py
│ └── requirements.txt
├── cifar10_torch_xla
│ ├── cifar10.py
│ ├── launcher.py
│ └── requirements.txt
├── dockerfile
│ ├── Dockerfile
│ ├── arg_printer.cc
│ └── launcher.py
├── dopamine
│ ├── launcher.py
│ └── train.py
├── local_arg_printer
│ ├── BUILD
│ ├── WORKSPACE
│ ├── arg_printer.cc
│ └── launcher.py
├── local_container_gpu
│ ├── cifar10.py
│ ├── launcher.py
│ └── requirements.txt
├── local_container_links
│ ├── BUILD.bazel
│ ├── README.md
│ ├── WORKSPACE
│ ├── launcher.py
│ ├── requirements.txt
│ └── server.py
├── parameter_controller
│ ├── db_config.yaml
│ ├── inner_job
│ │ ├── requirements.txt
│ │ └── wait_job.py
│ ├── launcher.py
│ └── requirements.txt
└── vizier
│ ├── README.md
│ ├── launcher.py
│ ├── polynomial.py
│ └── requirements.txt
├── jupyter_codelab.ipynb
├── setup.py
├── setup_scripts
├── install_bazel.sh
├── install_docker.sh
├── install_gcloud.sh
├── install_python.sh
├── install_xmanager.sh
├── setup_all.sh
└── setup_gcp.sh
└── xmanager
├── bazel
├── client.py
└── file_utils.py
├── cli
├── README.md
└── cli.py
├── cloud
├── README.md
├── __init__.py
├── auth.py
├── auth_test.py
├── build_image.py
├── build_image_test.py
├── cloud_build.py
├── cloud_build_test.py
├── data
│ └── wrapped_entrypoint.sh
├── docker_lib.py
├── kubernetes.py
├── kubernetes_test.py
├── utils.py
├── utils_test.py
├── vertex.py
└── vertex_test.py
├── contrib
├── __init__.py
├── addressing.py
├── addressing_test.py
├── copybara.py
├── executor_selector.py
├── flow.py
├── framework_defaults.py
├── framework_defaults_test.py
├── gcs.py
├── gcs_test.py
├── parameter_controller.py
├── process_entry.py
├── tensorboard.py
├── tensorboard_test.py
├── tpu.py
├── xm_tensorflow.py
└── xm_tensorflow_test.py
├── docker
└── docker_adapter.py
├── generated
├── README.md
├── build_event_stream_pb2.py
├── command_line_pb2.py
├── data_pb2.py
├── failure_details_pb2.py
├── invocation_policy_pb2.py
└── option_filters_pb2.py
├── module_lazy_loader
├── lazy_loader_module_attrs_test.py
└── module_lazy_loader.py
├── vizier
└── vizier_cloud
│ ├── __init__.py
│ ├── study_factory.py
│ ├── vizier_controller.py
│ ├── vizier_exploration.py
│ └── vizier_worker.py
├── xm
├── README.md
├── __init__.py
├── async_packager.py
├── async_packager_test.py
├── compute_units.py
├── core.py
├── core_test.py
├── executables.py
├── executables_test.py
├── id_predictor.py
├── id_predictor_test.py
├── job_blocks.py
├── job_blocks_test.py
├── job_operators.py
├── job_operators_test.py
├── metadata_context.py
├── packagables.py
├── packagables_generator.py
├── packagables_test.py
├── resources.py
├── resources_test.py
├── utils.py
└── utils_test.py
├── xm_flags.py
├── xm_local
├── README.md
├── __init__.py
├── executables.py
├── execution.py
├── execution_test.py
├── executors.py
├── executors_test.py
├── experiment.py
├── handles.py
├── multiplexer.py
├── packaging
│ ├── bazel_tools.py
│ ├── bazel_tools_test.py
│ ├── cloud.py
│ ├── local.py
│ └── router.py
├── registry.py
├── status.py
└── storage
│ ├── alembic.ini
│ ├── alembic
│ ├── README.md
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ │ └── f45829405692_migrate_or_create.py
│ ├── data.proto
│ └── database.py
└── xm_mock
└── __init__.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 |
3 | The project currently does not accept pull requests, but we'd be more than happy
4 | to get your feedback via [issues](https://github.com/deepmind/xmanager/issues)!
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Community Guidelines
20 |
21 | This project follows
22 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
23 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """XManager."""
15 |
--------------------------------------------------------------------------------
/compile_pb.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Copyright 2021 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
# Abort on the first failing command (e.g. a failed clone or protoc run)
# instead of going on to copy stale or missing files into the generated dir.
set -euo pipefail

# Check out Bazel into a fresh temporary directory so that re-running the
# script never fails on, or silently reuses, a previous checkout. The checkout
# is removed when the script exits.
BAZEL_DIR="$(mktemp -d "${TMPDIR:-/tmp}/bazel.XXXXXX")"
readonly BAZEL_DIR
trap 'rm -rf "${BAZEL_DIR}"' EXIT

# Assign before marking readonly so a failing command substitution is not
# masked by `readonly`'s own exit status; quote so paths with spaces work.
SOURCE_ROOT_DIR="$(realpath "$(dirname "$0")")"
readonly SOURCE_ROOT_DIR
readonly GENERATED_DIR="${SOURCE_ROOT_DIR}/generated"

mkdir -p "${GENERATED_DIR}"

# Only the .proto sources at HEAD are needed, so skip the repository history.
git clone --depth 1 https://github.com/bazelbuild/bazel.git "${BAZEL_DIR}"

protoc --proto_path="${BAZEL_DIR}" --python_out="${BAZEL_DIR}" \
  "${BAZEL_DIR}/src/main/java/com/google/devtools/build/lib/buildeventstream/proto/build_event_stream.proto" \
  "${BAZEL_DIR}/src/main/protobuf/command_line.proto" \
  "${BAZEL_DIR}/src/main/protobuf/failure_details.proto" \
  "${BAZEL_DIR}/src/main/protobuf/invocation_policy.proto" \
  "${BAZEL_DIR}/src/main/protobuf/option_filters.proto"

cp "${BAZEL_DIR}/src/main/java/com/google/devtools/build/lib/buildeventstream/proto/build_event_stream_pb2.py" "${GENERATED_DIR}/"
cp "${BAZEL_DIR}/src/main/protobuf/command_line_pb2.py" "${GENERATED_DIR}/"
cp "${BAZEL_DIR}/src/main/protobuf/failure_details_pb2.py" "${GENERATED_DIR}/"
cp "${BAZEL_DIR}/src/main/protobuf/invocation_policy_pb2.py" "${GENERATED_DIR}/"
cp "${BAZEL_DIR}/src/main/protobuf/option_filters_pb2.py" "${GENERATED_DIR}/"

protoc --proto_path="${SOURCE_ROOT_DIR}" --python_out="${SOURCE_ROOT_DIR}" \
  "${SOURCE_ROOT_DIR}/xm_local/storage/data.proto"
mv "${SOURCE_ROOT_DIR}/xm_local/storage/data_pb2.py" "${GENERATED_DIR}/"

# Make generated imports local to xmanager.generated.
find "${GENERATED_DIR}/" -name '*pb2*' -exec \
  sed -i 's/^from src\.main.* import/from . import/' {} \;
# Add NOLINT line.
find "${GENERATED_DIR}/" -name '*pb2*' -exec \
  sed -i 's/DO NOT EDIT!/DO NOT EDIT!\n# pylint: skip-file/' {} \;
46 |
--------------------------------------------------------------------------------
/docs/executors.md:
--------------------------------------------------------------------------------
1 | # Executors
2 |
3 | ## Local
4 |
5 | The local executor declares that an executable will be run on the same machine
6 | from which the launch script is invoked.
7 |
8 | ```python
9 | xm_local.Local(
10 | docker_options=xm_local.DockerOptions(...),
11 | )
12 | ```
13 |
14 | Making GPUs available to local containers is possible by requesting
15 | the `local_gpu` resource through a requirements object.
16 |
17 | ```python
18 | xm_local.Local(
19 | xm.JobRequirements(local_gpu=1)
20 | )
21 | ```
22 |
23 | Note: Only local NVIDIA GPUs can be requested for the `Local` executor through
24 | the resource requirements object. Any other requirements will be ignored.
25 |
26 | ## Vertex AI (Cloud AI Platform)
27 |
28 | The `Vertex` executor declares that an executable will be run on the Vertex AI
29 | platform.
30 |
31 | The Vertex executor takes in a resource requirements object.
32 |
33 | ```python
34 | xm_local.Vertex(
35 | xm.JobRequirements(
36 | cpu=1, # Measured in vCPUs.
37 | ram=4 * xm.GiB,
38 | T4=1, # NVIDIA Tesla T4.
39 | ),
40 | )
41 | ```
42 |
43 | ```python
44 | xm_local.Vertex(
45 | xm.JobRequirements(
46 | cpu=1, # Measured in vCPUs.
47 | ram=4 * xm.GiB,
48 | TPU_V2=8, # TPU v2.
49 | ),
50 | )
51 | ```
52 |
53 | As of June 2021, the currently supported accelerator types are:
54 |
55 | * `P100`
56 | * `V100`
57 | * `P4`
58 | * `T4`
59 | * `A100`
60 | * `TPU_V2`
61 | * `TPU_V3`
62 |
63 | ### Vertex AI Specification
64 |
65 | The Vertex AI executor allows you to specify a remote image repository to push to.
66 |
67 | ```python
68 | xm_local.Vertex.Spec(
69 | push_image_tag='gcr.io//:',
70 | )
71 | ```
72 |
73 | ## Kubernetes (experimental)
74 |
75 | The Kubernetes executor declares that an executable will be run on a Kubernetes
76 | cluster. As of October 2021, Kubernetes is not fully supported.
77 |
78 | The Kubernetes executor pulls from your local `kubeconfig`. The XManager
79 | command-line has helpers to set up a Google Kubernetes Engine (GKE) cluster.
80 |
81 | ```bash
82 | pip install caliban==0.4.1
83 | xmanager cluster create
84 |
85 | # cleanup
86 | xmanager cluster delete
87 | ```
88 |
89 | You can store the GKE credentials in your `kubeconfig`:
90 |
91 | ```bash
92 | gcloud container clusters get-credentials
93 | ```
94 |
95 | ### Kubernetes Specification
96 |
97 | The Kubernetes executor allows you to specify a remote image repository to push to.
98 |
99 | ```python
100 | xm_local.Kubernetes.Spec(
101 | push_image_tag='gcr.io/<project>/<image>:<tag>',
102 | )
103 | ```
104 |
--------------------------------------------------------------------------------
/docs/metadata_storage.md:
--------------------------------------------------------------------------------
1 | # Metadata Storage
2 |
3 | Experiment metadata is stored in a SQL database. By default, the database used
4 | is the SQLite one at `~/.xmanager/experiments.sqlite3`. If that does not
5 | suffice, the XManager client can also connect to a generic database based on a
6 | YAML configuration. An example of such a configuration is given below:
7 |
8 | ```yaml
9 | # sample_db_config.yaml
10 |
11 | # Connector used - one of ['cloudsql', 'sqlite', 'generic']
12 | sql_connector: 'generic'
13 |
14 | sql_connection_settings:
15 | # Backend used, e.g. 'mysql', 'postgresql'
16 | backend: 'mysql'
17 | # Driver used, e.g. 'pymysql', 'pg8000'
18 | driver: 'pymysql'
19 | # Username
20 | username: 'root'
21 | # Password
22 | password: 'metadata'
23 | # Host (or instance connection name for CloudSql)
24 | host: '127.0.0.1'
25 | # Port
26 | port: 3309
27 | # Database name
28 | db_name: 'metadata'
29 | ```
30 |
31 | The `sql_connector` field specifies the connector to be used. It can be one of
32 | `sqlite`, `cloudsql`, `generic`. Generally, it's recommended to use the `sqlite`
33 | connector for a local DB at a different location than the default one,
34 | `cloudsql` for connecting to a CloudSQL database, and `generic` for any other
35 | database.
36 |
37 | The `cloudsql` connector runs the
38 | [CloudSQL Auth Proxy](https://cloud.google.com/sql/docs/mysql/sql-proxy)
39 | automatically, so it requires the host that runs the XManager client to have the
40 | required permissions (Cloud SQL Editor IAM role) and project APIs enabled (Cloud
41 | SQL Admin API). When using the `cloudsql` connector, one should use an instance
42 | connection name of the format
43 | `${PROJECT_ID}:${INSTANCE_LOCATION}:${INSTANCE_NAME}` in the `host` field of the
44 | YAML configuration.
45 |
46 | When using the `generic` connector, the fields in the `sql_connection_settings`
47 | follow the format of the `sqlalchemy` connection URL. Therefore, each
48 | combination of the fields above that can form a valid `sqlalchemy` connection
49 | URL can be used.
50 |
51 | This YAML config can be passed to XManager by using the
52 | `--xm_db_yaml_config_path` flag. Note that the path is relative to the launch
53 | script. If the flag is not set, the default SQLite database is used.
54 |
55 | The XManager client attempts to always keep the database to the latest version.
56 | For details, check
57 | [database migration docs](https://github.com/deepmind/xmanager/tree/main/xmanager/xm_local/storage/alembic/README.md).
58 |
59 | ## Slides
60 |
61 | Remote execution and metadata storage are covered in
62 | [these slides](https://storage.googleapis.com/gresearch/xmanager/remote_execution_slides.pdf).
63 |
--------------------------------------------------------------------------------
/docs/oss_xm_scope.md:
--------------------------------------------------------------------------------
1 | # Scope of OSS XManager
2 |
3 | ## What's included
4 |
5 | Internally within Alphabet, XManager is a fully-featured ecosystem aimed at
6 | facilitating the experimentation process and deeply integrated with our
7 | internal infrastructure. The XManager repository that you are seeing only
8 | contains a small subset of our ecosystem. Due to the amount of work needed and
9 | coupling with internal tools, we have decided to only open source the launch
10 | API: the means to describe what an experiment consists of and how to launch it.
11 | As a result, users can launch these experiments locally, on Google Cloud
12 | Platform or on Kubernetes. The API can be extended to support other Cloud
13 | platforms.
14 |
15 | ## The scope
16 |
17 | The project’s primary goal is to facilitate collaboration between internal
18 | researchers and the scientific community. The Open Sourced XManager API provides
19 | a way to share our internally-developed code, enabling everybody to reproduce
20 | results, build new knowledge on top of them, or put the code to work for the
21 | benefit of society. It also enables closer collaborations between Alphabet and
22 | external researchers.
23 |
24 | As we are focused on delivering our internal roadmap, we may not have bandwidth
25 | for external feature requests, even if they come with an implementation.
26 | XManager was designed to be highly flexible and adaptable and has a
27 | lot of potential for further use cases. Therefore we don’t exclude the
28 | possibility of increasing its scope going forward.
29 |
30 | ## Can I use XManager for research unrelated to Google or DeepMind?
31 |
32 | Yes. XManager API is distributed under Apache 2.0 licence which gives you a lot
33 | of flexibility. We understand that our
34 | [contribution policy](https://github.com/deepmind/xmanager/blob/main/CONTRIBUTING.md)
35 | doesn't sound very reassuring, but we can suggest the following to counter the
36 | risks:
37 |
38 | * The licence allows you to have and maintain a fork with all the changes you
39 | need.
40 | * The API has been designed to be modular. Similar to how `xmanager.xm_local`
41 | extends `xmanager.xm` and provides support for orchestrating experiments
42 | from a local machine, you may have your own module tailoring the API to your
43 | needs. In fact this is how our internal version adds support for
44 | Alphabet-specific infrastructure.
45 | * The core `xmanager.xm` composes the basic building blocks of a research
46 | experiment that we think every implementation should be based on. And one of
47 | the things we've learned over the years is that research is a fast-moving field.
48 | It was in our best interest to design it to be generic, flexible, and extensible.
49 |
50 | ## What is not included
51 |
52 | The following was intentionally left outside of the XManager scope to allow it
53 | to be shared in a reasonable time frame. Note that while these features are
54 | important parts of the research infrastructure we believe that many of them are
55 | better to be provided as a collection of well-integrated libraries / services
56 | rather than an all-in-one tool.
57 |
58 | * Web user interface. Vertex AI and TensorBoard provide good alternatives if
59 | you need a UI.
60 | * Metrics storage. XManager API doesn't cover storing, retrieving, or
61 | conducting analysis of experiment metrics.
62 | * Computational resource sharing. `xm_local` supports running many experiments
63 | in the Cloud in parallel. But it doesn't enforce any policy for sharing
64 | resources between many researchers.
65 | * Continuous status tracking. The base XManager implementation does not track
66 | the running status of trials or jobs. It does not contain any mechanism for
67 | push notifications or for notifying a user of success or failure.
68 |
--------------------------------------------------------------------------------
/examples/cifar10_tensorflow/cifar10.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The Tensorflow Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Code based on https://www.tensorflow.org/tutorials/images/cnn."""
15 | import os
16 |
17 | from absl import app
18 | from absl import flags
19 | import tensorflow as tf
20 | from tensorflow.keras import datasets
21 | from tensorflow.keras import layers
22 | from tensorflow.keras import models
23 |
# When running with Vertex AI TensorBoard, the log directory is provided through
# this environment variable; otherwise it defaults to an empty string, which
# disables the TensorBoard callback in main().
LOG_DIR = os.getenv('AIP_TENSORBOARD_LOG_DIR', '')

FLAGS = flags.FLAGS
flags.DEFINE_integer(name='epochs', default=5, help='epochs')
flags.DEFINE_float(name='learning_rate', default=0.001, help='learning rate')
31 |
32 |
def main(_):
  """Trains a small CNN on CIFAR-10, mirrored across the local devices.

  Hyperparameters come from the --epochs and --learning_rate flags. When
  LOG_DIR is non-empty, TensorBoard summaries (including weight histograms
  every epoch) are written there.
  """
  (x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

  # Scale pixel intensities from [0, 255] into [0, 1].
  x_train = x_train / 255.0
  x_test = x_test / 255.0

  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    # Variables must be created inside the strategy scope to be mirrored.
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(10),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate),
        # The final Dense layer has no activation, so it emits raw logits.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )

  callbacks = (
      [tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR, histogram_freq=1)]
      if LOG_DIR
      else []
  )

  model.fit(
      x_train,
      y_train,
      epochs=FLAGS.epochs,
      validation_data=(x_test, y_test),
      callbacks=callbacks,
      verbose=2,
  )


if __name__ == '__main__':
  app.run(main)
84 |
--------------------------------------------------------------------------------
/examples/cifar10_tensorflow/launcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""XManager launcher for CIFAR10.
15 |
16 | Usage:
17 |
18 | xmanager launch examples/cifar10_tensorflow/launcher.py -- \
19 | --xm_wrap_late_bindings [--image_path=gcr.io/path/to/image/tag]
20 | """
21 | import asyncio
22 | import itertools
23 | import os
24 |
25 | from absl import app
26 | from absl import flags
27 | from xmanager import xm
28 | from xmanager import xm_local
29 | from xmanager.cloud import vertex
30 |
FLAGS = flags.FLAGS
# When unset, the launcher gets or creates a Vertex AI TensorBoard named
# 'cifar10'.
flags.DEFINE_string(
    name='tensorboard', default=None, help='Tensorboard instance.'
)
flags.DEFINE_integer(
    name='gpus_per_node', default=2, help='Number of GPUs per node.'
)
34 |
35 |
def main(_):
  """Launches one Vertex AI job per learning rate, logging to TensorBoard.

  Each trial runs the `cifar10` module with --gpus_per_node T4 GPUs. If the
  GOOGLE_CLOUD_BUCKET_NAME environment variable is set, each trial's
  TensorBoard output goes under `<bucket>/<experiment_id>/<trial_index>`.
  """
  with xm_local.create_experiment(experiment_title='cifar10') as experiment:
    # Package the directory containing this script as a Python container.
    container_spec = xm.PythonContainer(
        path='.',
        base_image='gcr.io/deeplearning-platform-release/tf2-gpu.2-6',
        entrypoint=xm.ModuleName('cifar10'),
    )
    packageable = xm.Packageable(
        executable_spec=container_spec,
        executor_spec=xm_local.Vertex.Spec(),
        args={},
    )
    [executable] = experiment.package([packageable])

    learning_rates = [0.1, 0.001]
    trials = [
        {'learning_rate': lr} for (lr,) in itertools.product(learning_rates)
    ]

    tensorboard = FLAGS.tensorboard
    if not tensorboard:
      # The client call returns an awaitable; drive it to completion on the
      # current event loop to obtain the TensorBoard instance name.
      pending = vertex.get_default_client().get_or_create_tensorboard(
          'cifar10'
      )
      tensorboard = asyncio.get_event_loop().run_until_complete(pending)

    bucket = os.environ.get('GOOGLE_CLOUD_BUCKET_NAME', None)
    for trial_index, hyperparameters in enumerate(trials):
      # Without a bucket, the (falsy) environment value is passed through
      # unchanged as the base output directory.
      output_dir = (
          os.path.join(bucket, str(experiment.experiment_id), str(trial_index))
          if bucket
          else bucket
      )
      executor = xm_local.Vertex(
          tensorboard=xm_local.TensorboardCapability(
              name=tensorboard, base_output_directory=output_dir
          ),
          requirements=xm.JobRequirements(t4=FLAGS.gpus_per_node),
      )
      experiment.add(
          xm.Job(executable=executable, executor=executor, args=hyperparameters)
      )


if __name__ == '__main__':
  app.run(main)
91 |
--------------------------------------------------------------------------------
/examples/cifar10_tensorflow/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | tensorflow
3 | tensorflow-datasets
4 |
--------------------------------------------------------------------------------
/examples/cifar10_tensorflow_k8s_multiworker/cifar10.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The Tensorflow Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Code based on https://www.tensorflow.org/tutorials/images/cnn."""
15 | import os
16 |
17 | from absl import app
18 | from absl import flags
19 | import tensorflow as tf
20 | from tensorflow.keras import datasets
21 | from tensorflow.keras import layers
22 | from tensorflow.keras import models
23 |
# When running with Vertex AI TensorBoard, the log directory is provided through
# this environment variable; otherwise it defaults to an empty string, which
# disables the TensorBoard callback in main().
LOG_DIR = os.getenv('AIP_TENSORBOARD_LOG_DIR', '')

FLAGS = flags.FLAGS
flags.DEFINE_integer(name='epochs', default=5, help='epochs')
flags.DEFINE_float(name='learning_rate', default=0.001, help='learning rate')
31 |
32 |
def main(_):
  """Trains a small CNN on CIFAR-10 with MultiWorkerMirroredStrategy.

  The strategy is constructed without an explicit cluster resolver, so TF
  uses its default (TF_CONFIG-based) resolver to discover the workers.
  Hyperparameters come from the --epochs and --learning_rate flags. When
  LOG_DIR is non-empty, TensorBoard summaries are written there.
  """
  strategy = tf.distribute.MultiWorkerMirroredStrategy()
  with strategy.scope():
    # Variables must be created inside the strategy scope to be replicated.
    model = models.Sequential()
    model.add(
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3))
    )
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))

    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10))

    optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
    model.compile(
        optimizer=optimizer,
        # The final Dense layer has no activation, so it emits raw logits.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )

  callbacks = []
  if LOG_DIR:
    callbacks = [
        tf.keras.callbacks.TensorBoard(
            log_dir=LOG_DIR,
            histogram_freq=1,
        ),
    ]

  (train_images, train_labels), (test_images, test_labels) = (
      datasets.cifar10.load_data()
  )
  # Normalize pixel values to be between 0 and 1
  train_images, test_images = train_images / 255.0, test_images / 255.0

  # The datasets are built from in-memory tensors (there are no input files to
  # split), so shard by element across workers.
  options = tf.data.Options()
  options.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.DATA
  )

  # `with_options` returns a new dataset rather than modifying the receiver,
  # so its result must be kept for the sharding policy to take effect.
  train_dataset = (
      tf.data.Dataset.from_tensor_slices((train_images, train_labels))
      .batch(32)
      .with_options(options)
  )
  validation_dataset = (
      tf.data.Dataset.from_tensor_slices((test_images, test_labels))
      .batch(32)
      .with_options(options)
  )

  model.fit(
      train_dataset,
      epochs=FLAGS.epochs,
      validation_data=validation_dataset,
      callbacks=callbacks,
      verbose=2,
  )


if __name__ == '__main__':
  app.run(main)
100 |
--------------------------------------------------------------------------------
/examples/cifar10_tensorflow_k8s_multiworker/launcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""XManager launcher for CIFAR10 using TF's MultiWorkerMirroredStrategy using the Kubernetes back-end.
15 |
16 | Usage:
17 |
18 | xmanager launch examples/cifar10_tensorflow_k8s_multiworker/launcher.py -- \
19 | --xm_wrap_late_bindings [--image_path=gcr.io/path/to/image/tag]
20 | """
21 | import itertools
22 |
23 | from absl import app
24 | from xmanager import xm
25 | from xmanager import xm_local
26 | from xmanager.contrib import xm_tensorflow
27 |
28 |
def main(_):
  """Adds one 3-worker MultiWorkerMirroredStrategy job group per trial.

  Workers run on the Kubernetes back-end, each requesting a single T4 GPU,
  and receive the trial's learning rate as an argument.
  """
  with xm_local.create_experiment(
      experiment_title='kubernetes_multiworker'
  ) as experiment:
    # Package the directory containing this script as a Python container.
    worker_spec = xm.PythonContainer(
        path='.',
        base_image='gcr.io/deeplearning-platform-release/tf2-gpu.2-6',
        entrypoint=xm.ModuleName('cifar10'),
    )
    packageable = xm.Packageable(
        executable_spec=worker_spec,
        executor_spec=xm_local.Kubernetes.Spec(),
        args={},
    )
    [executable] = experiment.package([packageable])

    learning_rates = [0.001]
    trials = [
        {'learning_rate': lr} for (lr,) in itertools.product(learning_rates)
    ]

    worker_executor = xm_local.Kubernetes(
        requirements=xm.JobRequirements(t4=1)
    )
    builder = xm_tensorflow.MultiWorkerMirroredStrategyBuilder(
        experiment=experiment,
        worker_executable=executable,
        worker_executor=worker_executor,
        worker_name='worker',
        num_workers=3,
    )

    for trial_args in trials:
      experiment.add(builder.gen_job_group(), args=trial_args)


if __name__ == '__main__':
  app.run(main)
72 |
--------------------------------------------------------------------------------
/examples/cifar10_tensorflow_k8s_multiworker/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | tensorflow
3 | tensorflow-datasets
4 |
--------------------------------------------------------------------------------
/examples/cifar10_tensorflow_k8s_ps/cifar10.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The Tensorflow Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Code based on https://www.tensorflow.org/tutorials/images/cnn."""
15 | import os
16 |
17 | from absl import app
18 | from absl import flags
19 | import tensorflow as tf
20 | from tensorflow.keras import datasets
21 | from tensorflow.keras import layers
22 | from tensorflow.keras import models
23 |
FLAGS = flags.FLAGS
# Training hyperparameters consumed by main().
flags.DEFINE_integer(name='epochs', default=5, help='epochs')
flags.DEFINE_float(name='learning_rate', default=0.001, help='learning rate')
27 |
28 |
def main(_):
  """Runs one task of a parameter-server CIFAR-10 training cluster.

  The cluster layout is read via TFConfigClusterResolver. Tasks of type
  'worker' or 'ps' start a tf.distribute.Server and block in `join()`; any
  other task builds the model and drives training with ParameterServerStrategy.
  """
  cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

  if cluster_resolver.task_type in ('worker', 'ps'):
    # Setting taken from the TF parameter-server training tutorial so that
    # worker/PS failures are reported back to the coordinator.
    os.environ['GRPC_FAIL_FAST'] = 'use_caller'

    server = tf.distribute.Server(
        cluster_resolver.cluster_spec(),
        job_name=cluster_resolver.task_type,
        task_index=cluster_resolver.task_id,
        protocol=cluster_resolver.rpc_layer or 'grpc',
        start=True,
    )
    server.join()

  (train_images, train_labels), _ = datasets.cifar10.load_data()

  def normalize(images, labels):
    # Scale pixel values into [0, 1], matching the other CIFAR-10 examples.
    # Done inside the pipeline so the captured image array stays uint8 rather
    # than being materialized as a much larger float copy up front.
    return tf.cast(images, tf.float32) / 255.0, labels

  def dataset_fn(input_context):
    dataset = tf.data.Dataset.from_tensor_slices(
        (train_images, train_labels)
    ).repeat()
    dataset = dataset.shard(
        input_context.num_input_pipelines, input_context.input_pipeline_id
    )
    dataset = dataset.batch(64)
    # Normalize whole batches at once (vectorized) rather than per element.
    dataset = dataset.map(normalize)
    dataset = dataset.prefetch(2)

    return dataset

  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver
  )
  with strategy.scope():
    model = models.Sequential()
    model.add(
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3))
    )
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))

    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10))

    optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
    model.compile(
        optimizer=optimizer,
        # The final Dense layer has no activation, so it emits raw logits.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )

  # The dataset repeats indefinitely, so an explicit epoch length is required.
  model.fit(
      tf.keras.utils.experimental.DatasetCreator(dataset_fn),
      steps_per_epoch=1500,
      epochs=FLAGS.epochs,
  )


if __name__ == '__main__':
  app.run(main)
91 |
--------------------------------------------------------------------------------
/examples/cifar10_tensorflow_k8s_ps/launcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""XManager launcher for CIFAR10 using TF's ParameterServerStrategy.
15 |
16 | Usage:
17 |
18 | xmanager launch examples/cifar10_tensorflow_k8s_ps/launcher.py -- \
19 | --xm_wrap_late_bindings [--image_path=gcr.io/path/to/image/tag]
20 | """
21 | import itertools
22 |
23 | from absl import app
24 | from xmanager import xm
25 | from xmanager import xm_local
26 | from xmanager.contrib import xm_tensorflow
27 |
28 |
def main(_):
  """Launches CIFAR10 ParameterServerStrategy trials on Kubernetes.

  Each trial is one job group from ParameterServerStrategyBuilder: a chief and
  2 workers (one T4 GPU each) plus 1 parameter server, all running the same
  packaged cifar10 image with the trial's hyperparameters as arguments.
  """
  with xm_local.create_experiment(
      experiment_title='kubernetes_ps'
  ) as experiment:
    spec = xm.PythonContainer(
        # Package the current directory that this script is in.
        path='.',
        base_image='gcr.io/deeplearning-platform-release/tf2-gpu.2-6',
        entrypoint=xm.ModuleName('cifar10'),
    )

    [executable] = experiment.package([
        xm.Packageable(
            executable_spec=spec,
            executor_spec=xm_local.Kubernetes.Spec(),
            args={},
        )
    ])

    learning_rates = [0.001]
    trials = [
        {'learning_rate': lr} for (lr,) in itertools.product(learning_rates)
    ]

    # NOTE(review): worker_name/ps_name presumably become the TF_CONFIG task
    # types; cifar10.py only starts a server for 'worker' and 'ps'.
    builder = xm_tensorflow.ParameterServerStrategyBuilder(
        experiment=experiment,
        chief_executable=executable,
        chief_executor=xm_local.Kubernetes(
            requirements=xm.JobRequirements(t4=1)
        ),
        worker_executable=executable,
        worker_executor=xm_local.Kubernetes(
            requirements=xm.JobRequirements(t4=1)
        ),
        worker_name='worker',
        ps_executable=executable,
        ps_executor=xm_local.Kubernetes(),
        ps_name='ps',
        num_workers=2,
        num_ps=1,
    )

    for hyperparameters in trials:
      experiment.add(builder.gen_job_group(), args=hyperparameters)


if __name__ == '__main__':
  app.run(main)
80 |
--------------------------------------------------------------------------------
/examples/cifar10_tensorflow_k8s_ps/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | tensorflow
3 | tensorflow-datasets
4 |
--------------------------------------------------------------------------------
/examples/cifar10_tensorflow_k8s_tensorboard/cifar10.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The Tensorflow Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Code based on https://www.tensorflow.org/tutorials/images/cnn."""
15 | from absl import app
16 | from absl import flags
17 | import tensorflow as tf
18 | from tensorflow.keras import datasets
19 | from tensorflow.keras import layers
20 | from tensorflow.keras import models
21 |
_TENSORBOARD_LOG_DIR = flags.DEFINE_string(
    'tensorboard_log_dir',
    None,
    'Directory to write Tensorboard summaries to. If unset, no Tensorboard '
    'callback is attached.',
)

FLAGS = flags.FLAGS
flags.DEFINE_integer('epochs', 5, 'epochs')
flags.DEFINE_float('learning_rate', 0.001, 'learning rate')


def main(_):
  """Trains the CIFAR10 tutorial CNN with MirroredStrategy.

  Pixel values are scaled to [0, 1]. When --tensorboard_log_dir is set, a Keras
  TensorBoard callback writes summaries there, including weight histograms
  every epoch.
  """
  (train_images, train_labels), (test_images, test_labels) = (
      datasets.cifar10.load_data()
  )

  # Normalize pixel values to be between 0 and 1
  train_images, test_images = train_images / 255.0, test_images / 255.0

  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    model = models.Sequential()
    model.add(
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3))
    )
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))

    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10))

    optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )

  callbacks = []
  if _TENSORBOARD_LOG_DIR.value:
    callbacks = [
        tf.keras.callbacks.TensorBoard(
            log_dir=_TENSORBOARD_LOG_DIR.value,
            histogram_freq=1,
        ),
    ]

  model.fit(
      train_images,
      train_labels,
      epochs=FLAGS.epochs,
      validation_data=(test_images, test_labels),
      callbacks=callbacks,
      verbose=2,
  )


if __name__ == '__main__':
  app.run(main)
80 |
--------------------------------------------------------------------------------
/examples/cifar10_tensorflow_k8s_tensorboard/launcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""XManager launcher for running CIFAR10 on Kubernetes with Tensorboard auxiliary job.
15 |
16 | Usage:
17 |
18 | xmanager launch examples/cifar10_tensorflow_k8s_tensorboard/launcher.py -- \
19 | --tensorboard_log_dir=TENSORBOARD_LOG_DIR \
20 | [--tensorboard_timeout_secs=TIMEOUT_SECS]
21 | """
22 | import itertools
23 | import os
24 |
25 | from absl import app
26 | from absl import flags
27 | from xmanager import xm
28 | from xmanager import xm_local
29 | from xmanager.contrib import tensorboard
30 |
_TENSORBOARD_LOG_DIR = flags.DEFINE_string(
    'tensorboard_log_dir',
    None,
    'Log directory to be used by workers and Tensorboard.',
)

_TENSORBOARD_TIMEOUT_SECS = flags.DEFINE_integer(
    'tensorboard_timeout_secs',
    60 * 60,
    'The amount of time the Tensorboard job should run for.',
)


def main(_):
  """Launches CIFAR10 trials on Kubernetes, plus an optional Tensorboard job.

  When --tensorboard_log_dir is set, a Tensorboard auxiliary job watches
  `<tensorboard_log_dir>/<experiment_id>/logs`, and trial `i` writes its
  summaries to the subdirectory `i` of that path. Without the flag, the trials
  run with no Tensorboard logging instead of crashing on a None path.
  """
  with xm_local.create_experiment(experiment_title='cifar10') as experiment:
    spec = xm.PythonContainer(
        # Package the current directory that this script is in.
        path='.',
        base_image='gcr.io/deeplearning-platform-release/tf2-gpu.2-6',
        entrypoint=xm.ModuleName('cifar10'),
    )

    [executable] = experiment.package([
        xm.Packageable(
            executable_spec=spec,
            executor_spec=xm_local.Kubernetes.Spec(),
        ),
    ])

    learning_rates = [0.1, 0.001]
    trials = [
        {'learning_rate': lr} for (lr,) in itertools.product(learning_rates)
    ]

    log_dir = None
    if _TENSORBOARD_LOG_DIR.value:
      log_dir = f'{_TENSORBOARD_LOG_DIR.value}/{experiment.experiment_id}/logs'
      tensorboard.add_tensorboard(
          experiment,
          log_dir,
          executor=xm_local.Kubernetes(),
          timeout_secs=_TENSORBOARD_TIMEOUT_SECS.value,
      )

    for i, hyperparameters in enumerate(trials):
      if log_dir:
        # One subdirectory per trial so Tensorboard shows them as separate
        # runs.
        args = {
            'tensorboard_log_dir': os.path.join(log_dir, str(i)),
            **hyperparameters,
        }
      else:
        args = dict(hyperparameters)
      experiment.add(
          xm.Job(
              executable=executable,
              executor=xm_local.Kubernetes(),
              args=args,
          )
      )


if __name__ == '__main__':
  app.run(main)
95 |
--------------------------------------------------------------------------------
/examples/cifar10_tensorflow_k8s_tensorboard/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | tensorflow
3 | tensorflow-datasets
4 |
--------------------------------------------------------------------------------
/examples/cifar10_tensorflow_tpu/cifar10.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The Tensorflow Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Code based on https://www.tensorflow.org/tutorials/images/cnn."""
15 | import os
16 |
17 | from absl import app
18 | from absl import flags
19 | import tensorflow as tf
20 | from tensorflow.keras import datasets
21 | from tensorflow.keras import layers
22 | from tensorflow.keras import models
23 |
# Vertex AI passes the Tensorboard log directory through this environment
# variable when a Tensorboard instance is attached to the job. When it is
# unset, the value is empty and no Tensorboard callback is used.
LOG_DIR = os.environ.get('AIP_TENSORBOARD_LOG_DIR', '')

FLAGS = flags.FLAGS
flags.DEFINE_integer('epochs', 5, 'epochs')
flags.DEFINE_float('learning_rate', 0.001, 'learning rate')


def main(_):
  """Trains the CIFAR10 tutorial CNN on a TPU with TPUStrategy.

  Pixel values are scaled to [0, 1] before training. If LOG_DIR is non-empty,
  a Keras TensorBoard callback writes summaries there, with weight histograms
  every epoch.
  """
  (x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

  # Scale pixel values into the [0, 1] range.
  x_train = x_train / 255.0
  x_test = x_test / 255.0

  # tpu='local' targets the TPU attached to this host.
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.TPUStrategy(resolver)

  # Build and compile under the strategy so the model's variables are created
  # by TPUStrategy.
  with strategy.scope():
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(10),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )

  callbacks = (
      [tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR, histogram_freq=1)]
      if LOG_DIR
      else []
  )

  model.fit(
      x_train,
      y_train,
      epochs=FLAGS.epochs,
      validation_data=(x_test, y_test),
      callbacks=callbacks,
      verbose=2,
  )


if __name__ == '__main__':
  app.run(main)
87 |
--------------------------------------------------------------------------------
/examples/cifar10_tensorflow_tpu/launcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""XManager launcher for CIFAR10.
15 |
16 | Usage:
17 |
18 | xmanager launch examples/cifar10_tensorflow_tpu/launcher.py
19 | """
20 | import asyncio
21 | import itertools
22 | import os
23 |
24 | from absl import app
25 | from absl import flags
26 | from xmanager import xm
27 | from xmanager import xm_local
28 | from xmanager.cloud import build_image
29 | from xmanager.cloud import vertex
30 | from xmanager.contrib import tpu
31 |
FLAGS = flags.FLAGS
flags.DEFINE_string('tensorboard', None, 'Tensorboard instance.')


def main(_):
  """Launches CIFAR10 TPU trials on Vertex AI with Vertex Tensorboard.

  Builds an ubuntu:20.04 image with the TPU VM dependencies, then adds one job
  per learning rate, each requesting tpu_v2=8. Summaries go to the Tensorboard
  instance named by --tensorboard. If that flag is unset, a 'cifar10' instance
  is fetched or created.
  """
  with xm_local.create_experiment(experiment_title='cifar10') as experiment:
    # Name of the directory that contains this launcher.
    script_dir = os.path.basename(os.path.dirname(os.path.realpath(__file__)))
    # pyformat: disable
    docker_instructions = [
        'RUN apt-get update --allow-releaseinfo-change && apt-get install -y python-is-python3 python3-pip wget',  # pylint: disable=line-too-long
        *build_image.default_steps(script_dir, use_deep_module=False),
        *tpu.tpuvm_docker_instructions(),
    ]
    # pyformat: enable
    spec = xm.PythonContainer(
        # Package the current directory that this script is in.
        path='.',
        # The TPU VM needs Python 3.8 and GLIBC_2.29, so the base image must be
        # at least debian:11 or ubuntu:20.04.
        base_image='ubuntu:20.04',
        docker_instructions=docker_instructions,
        entrypoint=xm.ModuleName('cifar10'),
    )

    [executable] = experiment.package([
        xm.Packageable(
            executable_spec=spec,
            executor_spec=xm_local.Vertex.Spec(),
            args={},
        ),
    ])

    trials = [
        {'learning_rate': lr} for (lr,) in itertools.product([0.1, 0.001])
    ]

    tensorboard = FLAGS.tensorboard
    if not tensorboard:
      pending = vertex.get_default_client().get_or_create_tensorboard('cifar10')
      tensorboard = asyncio.get_event_loop().run_until_complete(pending)

    # The bucket does not change between trials, so read it only once.
    bucket = os.environ.get('GOOGLE_CLOUD_BUCKET_NAME', None)
    for index, trial in enumerate(trials):
      output_dir = (
          os.path.join(bucket, str(experiment.experiment_id), str(index))
          if bucket
          else bucket
      )
      capability = xm_local.TensorboardCapability(
          name=tensorboard, base_output_directory=output_dir
      )
      experiment.add(
          xm.Job(
              executable=executable,
              executor=xm_local.Vertex(
                  requirements=xm.JobRequirements(tpu_v2=8),
                  tensorboard=capability,
              ),
              args=trial,
          )
      )


if __name__ == '__main__':
  app.run(main)
100 |
--------------------------------------------------------------------------------
/examples/cifar10_tensorflow_tpu/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | # tensorflow # installed as part of tpuvm
3 | tensorflow-datasets
4 |
--------------------------------------------------------------------------------
/examples/cifar10_torch/launcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""XManager launcher for CIFAR10.
15 |
16 | Usage:
17 |
18 | xmanager launch examples/cifar10_torch/launcher.py -- \
19 | --xm_wrap_late_bindings [--image_path=gcr.io/path/to/image/tag]
20 | """
21 |
22 | import itertools
23 |
24 | from absl import app
25 | from absl import flags
26 | from xmanager import xm
27 | from xmanager import xm_local
28 | from xmanager.cloud import utils
29 |
FLAGS = flags.FLAGS
flags.DEFINE_string('image_path', None, 'Image path.')

flags.DEFINE_integer('nodes', 1, 'Number of nodes.')
flags.DEFINE_integer('gpus_per_node', 2, 'Number of GPUs per node.')


@xm.run_in_asyncio_loop
async def main(_):
  """Launches multi-node PyTorch CIFAR10 trials on Vertex AI and waits for them.

  Each (batch_size, learning_rate) combination becomes one work unit: a job
  group with one job per node (--nodes of them). Every job receives the shared
  `world_size` and its own `rank`, and requests --gpus_per_node T4 GPUs. After
  all work units are added, the launcher waits for each one to complete.
  """
  async with xm_local.create_experiment(
      experiment_title='cifar10'
  ) as experiment:
    spec = (
        xm.Container(image_path=FLAGS.image_path)
        if FLAGS.image_path
        else xm.PythonContainer(
            # Package the current directory that this script is in.
            path='.',
            base_image='gcr.io/deeplearning-platform-release/pytorch-gpu.1-12',
            entrypoint=xm.ModuleName('cifar10'),
        )
    )

    packageable = xm.Packageable(
        executable_spec=spec,
        executor_spec=xm_local.Vertex.Spec(),
        args={
            # TODO: replace workerpool0 with the actual name of the job when
            # Vertex AI supports custom name worker pools.
            'master_addr_port': xm.ShellSafeArg(
                utils.get_workerpool_address('workerpool0')
            ),
        },
    )
    [executable] = experiment.package([packageable])

    grid = itertools.product([64, 1024], [0.1, 0.001])
    trials = [{'batch_size': bs, 'learning_rate': lr} for bs, lr in grid]

    work_units = []
    for trial in trials:
      job_group = xm.JobGroup()
      for rank in range(FLAGS.nodes):
        node_args = {**trial, 'world_size': FLAGS.nodes, 'rank': rank}
        job_group.jobs[f'node_{rank}'] = xm.Job(
            executable=executable,
            executor=xm_local.Vertex(
                xm.JobRequirements(t4=FLAGS.gpus_per_node)
            ),
            args=node_args,
        )
      work_units.append(await experiment.add(job_group))
    print('Waiting for async launches to return values...')
  for work_unit in work_units:
    await work_unit.wait_until_complete()
  print('Experiment completed.')


if __name__ == '__main__':
  app.run(main)
99 |
--------------------------------------------------------------------------------
/examples/cifar10_torch/requirements.txt:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | absl-py
16 | numpy
17 | torch
18 | torchvision
19 |
--------------------------------------------------------------------------------
/examples/cifar10_torch_xla/launcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""XManager launcher for CIFAR10.
15 |
16 | Usage:
17 |
18 | xmanager launch examples/cifar10_torch_xla/launcher.py -- \
19 | --xm_wrap_late_bindings \
20 | [--image_path=gcr.io/path/to/image/tag] \
21 | [--platform=gpu]
22 | """
23 | import itertools
24 |
25 | from absl import app
26 | from absl import flags
27 | from xmanager import xm
28 | from xmanager import xm_local
29 |
FLAGS = flags.FLAGS
flags.DEFINE_string('image_path', None, 'Image path.')
flags.DEFINE_string('platform', 'cpu', 'cpu/gpu/tpu.')
flags.DEFINE_integer('cores', 1, 'Number of cores. Use 8 if platform==tpu.')


def main(_):
  """Launches one PyTorch/XLA CIFAR10 training job on Vertex AI.

  The accelerator request depends on --platform:
    - 'gpu': --cores T4 GPUs.
    - 'tpu': tpu_v3=8.
    - anything else: default JobRequirements().

  Only the first (batch_size, learning_rate) combination of the grid is
  launched.
  """
  with xm_local.create_experiment(experiment_title='cifar10') as experiment:
    spec = (
        xm.Container(image_path=FLAGS.image_path)
        if FLAGS.image_path
        # Package the current directory that this script is in.
        else xm.PythonContainer(
            path='.',
            # This base_image is experimental and works with cpu/gpu/tpu.
            # https://cloud.google.com/ai-platform/deep-learning-containers/docs/choosing-container
            base_image='gcr.io/deeplearning-platform-release/pytorch-xla.1-8',
            entrypoint=xm.ModuleName('cifar10'),
        )
    )

    [executable] = experiment.package([
        xm.Packageable(
            executable_spec=spec,
            executor_spec=xm_local.Vertex.Spec(),
            args={'platform': FLAGS.platform},
        ),
    ])

    grid = itertools.product([64, 1024], [0.1, 0.001])
    trials = [{'batch_size': bs, 'learning_rate': lr} for bs, lr in grid]

    if FLAGS.platform == 'gpu':
      requirements = xm.JobRequirements(t4=FLAGS.cores)
    elif FLAGS.platform == 'tpu':
      requirements = xm.JobRequirements(tpu_v3=8)
    else:
      requirements = xm.JobRequirements()

    # NOTE(review): only the first trial is launched, presumably to limit the
    # accelerator cost of this example. Confirm before launching the full grid.
    first_trial = trials[0]
    experiment.add(
        xm.JobGroup(
            coordinator=xm.Job(
                executable=executable,
                executor=xm_local.Vertex(requirements),
                args=first_trial,
            )
        )
    )
    print('Waiting for async launches to return values...')
  print('Launch completed and successful.')


if __name__ == '__main__':
  app.run(main)
87 |
--------------------------------------------------------------------------------
/examples/cifar10_torch_xla/requirements.txt:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | absl-py
16 | numpy
17 | tensorflow
18 | torch
19 | torchvision
20 |
--------------------------------------------------------------------------------
/examples/dockerfile/Dockerfile:
--------------------------------------------------------------------------------
# Build image for the arg_printer example: compiles arg_printer.cc with the
# gcc image and runs the resulting binary.
# NOTE(review): `gcc:latest` is unpinned, so rebuilds may pick up a different
# compiler. Consider pinning a version tag.
FROM gcc:latest

COPY arg_printer.cc arg_printer.cc
RUN g++ -o arg_printer arg_printer.cc

# Exec-form entrypoint: any container arguments are appended to this command.
ENTRYPOINT ["./arg_printer"]
7 |
--------------------------------------------------------------------------------
/examples/dockerfile/arg_printer.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2021 DeepMind Technologies Limited
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
#include <cstdlib>
#include <iostream>
17 |
// Prints a greeting, the value of the FOO environment variable, and every
// command-line argument (argv[0] included), one per line.
int main(int argc, char* argv[]) {
  std::cout << "Hello World!" << std::endl;

  // std::getenv returns nullptr when FOO is unset. Streaming a null char* is
  // undefined behavior, so print an empty line in that case instead.
  const char* foo = std::getenv("FOO");
  std::cout << (foo != nullptr ? foo : "") << std::endl;
  for (int i = 0; i < argc; ++i) {
    std::cout << argv[i] << std::endl;
  }
  return 0;
}
27 |
--------------------------------------------------------------------------------
/examples/dockerfile/launcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """XManager launcher that runs an image built from a Dockerfile."""
15 |
16 | from typing import Sequence
17 |
18 | from absl import app
19 | from xmanager import xm
20 | from xmanager import xm_local
21 |
22 |
def main(argv: Sequence[str]) -> None:
  """Packages an image via xm.Dockerfile() and runs it once on Vertex AI.

  The job gets FOO=bar in its environment and four fixed command-line
  arguments.
  """
  del argv  # Unused.

  title = 'Example using Dockerfile()'
  with xm_local.create_experiment(experiment_title=title) as experiment:
    packageable = xm.Packageable(
        executable_spec=xm.Dockerfile(),
        executor_spec=xm_local.Vertex.Spec(),
    )
    (executable,) = experiment.package([packageable])

    job = xm.Job(
        executable=executable,
        executor=xm_local.Vertex(),
        env_vars={'FOO': 'bar'},
        args=['--a=1', '--b=2', '--c=3', '--d=4'],
    )
    experiment.add(job)


if __name__ == '__main__':
  app.run(main)
50 |
--------------------------------------------------------------------------------
/examples/dopamine/launcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""Launcher for Dopamine.
15 |
16 | Usage:
17 | xmanager launch examples/dopamine/launcher.py -- \
18 | --gin_file=https://raw.githubusercontent.com/google/dopamine/master/dopamine/agents/dqn/configs/dqn_mountaincar.gin
19 | """
20 | import asyncio
21 | import os
22 |
23 | from absl import app
24 | from absl import flags
25 | from xmanager import xm
26 | from xmanager import xm_local
27 | from xmanager.cloud import vertex
28 |
FLAGS = flags.FLAGS
flags.DEFINE_string(
    'gin_file',
    'https://raw.githubusercontent.com/google/dopamine/master/dopamine/agents/dqn/configs/dqn_mountaincar.gin',
    'Gin file pulled from https://github.com/google/dopamine.',
)
flags.DEFINE_string('tensorboard', None, 'Tensorboard instance.')


def main(_):
  """Builds a Dopamine training image and runs it on Vertex AI with one T4.

  The gin config named by --gin_file is baked into the image. Paths starting
  with 'http' are downloaded with wget; other paths use a Docker ADD
  instruction. The job reports to the Tensorboard instance named by
  --tensorboard, or to a 'dopamine' instance fetched or created on demand.
  Output goes under $GOOGLE_CLOUD_BUCKET_NAME/<experiment_id>.

  Raises:
    ValueError: if GOOGLE_CLOUD_BUCKET_NAME is unset or empty.
  """
  # Check the output location before the (slow) image build, rather than
  # failing with a bare KeyError after packaging.
  bucket = os.environ.get('GOOGLE_CLOUD_BUCKET_NAME')
  if not bucket:
    raise ValueError(
        'The GOOGLE_CLOUD_BUCKET_NAME environment variable must be set to the '
        'output location for Tensorboard logs.'
    )

  with xm_local.create_experiment(experiment_title='dopamine') as experiment:
    gin_file = os.path.basename(FLAGS.gin_file)
    if FLAGS.gin_file.startswith('http'):
      add_instruction = f'RUN wget -O ./{gin_file} {FLAGS.gin_file}'
    else:
      add_instruction = f'ADD {FLAGS.gin_file} {gin_file}'
    spec = xm.PythonContainer(
        docker_instructions=[
            'RUN apt update && apt install -y python3-opencv',
            'RUN pip install dopamine-rl',
            'COPY dopamine/ workdir',
            'WORKDIR workdir',
            add_instruction,
        ],
        entrypoint=xm.ModuleName('train'),
    )

    [executable] = experiment.package([
        xm.Packageable(
            executable_spec=spec,
            executor_spec=xm_local.Vertex.Spec(),
            args={
                'gin_files': gin_file,
            },
        ),
    ])

    tensorboard = FLAGS.tensorboard
    if not tensorboard:
      tensorboard = vertex.get_default_client().get_or_create_tensorboard(
          'dopamine'
      )
      tensorboard = asyncio.get_event_loop().run_until_complete(tensorboard)
    output_dir = os.path.join(bucket, str(experiment.experiment_id))
    tensorboard_capability = xm_local.TensorboardCapability(
        name=tensorboard, base_output_directory=output_dir
    )
    experiment.add(
        xm.Job(
            executable=executable,
            executor=xm_local.Vertex(
                xm.JobRequirements(t4=1), tensorboard=tensorboard_capability
            ),
        )
    )


if __name__ == '__main__':
  app.run(main)
90 |
--------------------------------------------------------------------------------
/examples/dopamine/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """The entry point for running a Dopamine agent."""
15 | import os
16 |
17 | from absl import app
18 | from absl import flags
19 | from absl import logging
20 |
21 | from dopamine.discrete_domains import run_experiment
22 | import tensorflow as tf
23 |
FLAGS = flags.FLAGS
flags.DEFINE_multi_string(
    'gin_files',
    [],
    (
        # Keep the trailing space: adjacent literals are concatenated, and
        # without it the help text read 'e.g."dopamine/...'.
        'List of paths to gin configuration files (e.g. '
        '"dopamine/agents/dqn/dqn.gin").'
    ),
)

# Base directory handed to Dopamine's runner. When the job runs with a Vertex
# TensorBoard attached, its log directory is provided through the
# AIP_TENSORBOARD_LOG_DIR environment variable; otherwise a local directory
# under /tmp is used.
BASE_DIR = os.environ.get('AIP_TENSORBOARD_LOG_DIR', '/tmp/dopamine_runs')
37 |
38 |
def main(unused_argv):
  """Runs the Dopamine experiment described by the --gin_files configs.

  Logging is raised to INFO, TensorFlow is switched to its TF1-compatible
  behaviour, and the runner writes its output under BASE_DIR.
  """
  del unused_argv  # Unused.
  logging.set_verbosity(logging.INFO)
  # Must happen before any TensorFlow graph is built by the runner.
  tf.compat.v1.disable_v2_behavior()

  extra_gin_bindings = []
  run_experiment.load_gin_configs(FLAGS.gin_files, extra_gin_bindings)

  dopamine_runner = run_experiment.create_runner(BASE_DIR)
  dopamine_runner.run_experiment()


if __name__ == '__main__':
  app.run(main)
49 |
--------------------------------------------------------------------------------
/examples/local_arg_printer/BUILD:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
licenses(["notice"])

load("@rules_cc//cc:defs.bzl", "cc_binary")

# Stand-alone C++ program that records its command line to the file named by
# $OUTPUT_PATH. launcher.py builds and runs it via the label //:arg_printer.
cc_binary(
    name = "arg_printer",
    srcs = ["arg_printer.cc"],
)
23 |
--------------------------------------------------------------------------------
/examples/local_arg_printer/WORKSPACE:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/examples/local_arg_printer/arg_printer.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2021 DeepMind Technologies Limited
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
#include <chrono>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <thread>

// Writes argc followed by every argv entry (one per line) to the file named by
// the OUTPUT_PATH environment variable, then sleeps for 5 seconds so the
// launcher can demonstrate waiting on a local job.
//
// Returns 1 if OUTPUT_PATH is unset or the file cannot be opened, else 0.
int main(int argc, char* argv[]) {
  const char* output_path = std::getenv("OUTPUT_PATH");
  // std::getenv returns nullptr for an unset variable; opening a stream with a
  // null filename is undefined behavior, so fail with a clear message.
  if (output_path == nullptr) {
    std::cerr << "OUTPUT_PATH environment variable is not set." << std::endl;
    return 1;
  }
  std::ofstream ouf(output_path);
  if (!ouf) {
    std::cerr << "Failed to open " << output_path << " for writing."
              << std::endl;
    return 1;
  }
  ouf << argc << std::endl;
  for (int i = 0; i < argc; ++i) {
    ouf << argv[i] << std::endl;
  }
  // Sleep to demonstrate local jobs waiting.
  std::this_thread::sleep_for(std::chrono::seconds(5));
  return 0;
}
30 |
--------------------------------------------------------------------------------
/examples/local_arg_printer/launcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """XManager launcher that runs locally a binary built with Bazel.
15 |
16 | One must `cd` into xmanager/examples/local_arg_printer/ in order to run this
17 | example because Bazel needs to locate the WORKSPACE file.
18 | """
19 |
20 | from typing import Sequence
21 |
22 | from absl import app
23 | from xmanager import xm
24 | from xmanager import xm_local
25 |
26 |
def main(argv: Sequence[str]) -> None:
  """Builds //:arg_printer with Bazel and runs it once as a local job."""
  del argv  # Unused.

  with xm_local.create_experiment(
      experiment_title='local_arg_printer'
  ) as experiment:
    bazel_target = xm.BazelBinary(label='//:arg_printer')
    (executable,) = experiment.package([
        xm.Packageable(
            executable_spec=bazel_target,
            executor_spec=xm_local.Local.Spec(),
        ),
    ])
    # arg_printer writes its arguments to the file named by OUTPUT_PATH.
    printer_job = xm.Job(
        executable=executable,
        executor=xm_local.Local(),
        env_vars={'OUTPUT_PATH': '/tmp/local_arg_printer.txt'},
    )
    experiment.add(printer_job)


if __name__ == '__main__':
  app.run(main)
54 |
--------------------------------------------------------------------------------
/examples/local_container_gpu/cifar10.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The Tensorflow Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Code based on https://www.tensorflow.org/tutorials/images/cnn."""
15 | import os
16 |
17 | from absl import app
18 | from absl import flags
19 | import tensorflow as tf
20 | from tensorflow.keras import datasets
21 | from tensorflow.keras import layers
22 | from tensorflow.keras import models
23 |
# When the job runs with a Vertex TensorBoard attached, the directory to log to
# is provided in the AIP_TENSORBOARD_LOG_DIR environment variable. An empty
# value means no TensorBoard callback is installed by main().
LOG_DIR = os.environ.get('AIP_TENSORBOARD_LOG_DIR', '')

FLAGS = flags.FLAGS
# Number of epochs passed to model.fit().
flags.DEFINE_integer('epochs', 5, 'epochs')
# Learning rate of the Adam optimizer built in main().
flags.DEFINE_float('learning_rate', 0.001, 'learning rate')
31 |
32 |
def main(_):
  """Trains a small CNN classifier on CIFAR-10 under MirroredStrategy.

  Uses --epochs and --learning_rate, validates on the CIFAR-10 test split, and
  logs to TensorBoard only when LOG_DIR is non-empty.
  """
  (x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

  # Normalize pixel values to be between 0 and 1.
  x_train = x_train / 255.0
  x_test = x_test / 255.0

  mirrored = tf.distribute.MirroredStrategy()
  with mirrored.scope():
    # Build and compile the model inside the distribution strategy's scope.
    net = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(10),
    ])
    net.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )

  # Only log to TensorBoard when a log directory was provided.
  tb_callbacks = (
      [tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR, histogram_freq=1)]
      if LOG_DIR
      else []
  )

  net.fit(
      x_train,
      y_train,
      epochs=FLAGS.epochs,
      validation_data=(x_test, y_test),
      callbacks=tb_callbacks,
      verbose=2,
  )


if __name__ == '__main__':
  app.run(main)
84 |
--------------------------------------------------------------------------------
/examples/local_container_gpu/launcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""XManager local launcher for CIFAR10 using GPUs.
15 |
16 | Usage:
17 |
18 | xmanager launch examples/local_container_gpu/launcher.py -- \
19 | --xm_wrap_late_bindings
20 | """
21 |
22 | from absl import app
23 | from absl import flags
24 | from xmanager import xm
25 | from xmanager import xm_local
26 |
# Experiment title passed to xm_local.create_experiment (short form: -n).
_EXP_NAME = flags.DEFINE_string(
    'exp_name', 'local-cifar10-gpu', 'Name of the experiment.', short_name='n'
)
# Forwarded as xm_local.DockerOptions(interactive=...) for the local container.
_INTERACTIVE = flags.DEFINE_bool(
    'interactive',
    False,
    'Launch the container and allow interactive access to it.',
)
35 |
36 |
def main(argv) -> None:
  """Packages the CIFAR-10 trainer as a GPU container and runs it locally.

  Raises:
    app.UsageError: if extra positional command-line arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  with xm_local.create_experiment(
      experiment_title=_EXP_NAME.value
  ) as experiment:
    docker_opts = xm_local.DockerOptions(interactive=_INTERACTIVE.value)
    # Local executor with two GPUs; output is streamed so the job's progress
    # can be followed from the launcher.
    gpu_executor = xm_local.Local(
        xm.JobRequirements(local_gpu=2),
        experimental_stream_output=True,
        docker_options=docker_opts,
    )

    container = xm.python_container(
        executor_spec=gpu_executor.Spec(),
        args={},  # Nothing is passed to the training job.
        path='.',  # Package the directory this launcher lives in.
        base_image='gcr.io/deeplearning-platform-release/tf2-gpu.2-6',
        entrypoint=xm.ModuleName('local_container_gpu.cifar10'),
        use_deep_module=True,
    )
    [executable] = experiment.package([container])

    experiment.add(xm.Job(executable, gpu_executor))


if __name__ == '__main__':
  app.run(main)
73 |
--------------------------------------------------------------------------------
/examples/local_container_gpu/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | tensorflow
3 | tensorflow-datasets
4 |
--------------------------------------------------------------------------------
/examples/local_container_links/BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
load("@io_bazel_rules_docker//container:container.bzl", "container_image")
load("@python_deps//:requirements.bzl", "requirement")
load("@subpar//:subpar.bzl", "par_binary")

licenses(["notice"])

# Builds server.py, with its pip dependencies resolved through the
# @python_deps repository declared in WORKSPACE, as server.par.
par_binary(
    name = "server",
    srcs = ["server.py"],
    deps = [
        requirement("absl-py"),
        requirement("bottle"),
        requirement("bottle-redis"),
    ],
)

# Docker image that adds server.par on top of the pinned Python base image and
# runs /server.par as its entrypoint. launcher.py packages it through the
# //:server_image.tar target.
container_image(
    name = "server_image",
    base = "@io_docker_index_library_python//image",
    entrypoint = ["/server.par"],
    files = [":server.par"],
)
37 |
--------------------------------------------------------------------------------
/examples/local_container_links/README.md:
--------------------------------------------------------------------------------
1 | ## Description
2 |
3 | launcher.py packages and executes (locally) two Docker containers:
4 | [Redis](https://hub.docker.com/_/redis), a key-value database, and a
5 | [Bottle](https://bottlepy.org/)-based HTTP server that listens on port 8080 and
6 | responds to requests at `/increment` by running
7 | [`INCR counter`](https://redis.io/commands/INCR) on Redis and sending back the
8 | result.
9 |
10 | ## Instructions
11 |
12 | 1. Clone the repository and install XManager via `pip install ./xmanager`.
13 | 2. Install Bazelisk via `npm install -g @bazel/bazelisk`.
3. Change into this directory and run `xmanager launch launcher.py -- --xm_bazel_command=bazelisk`.
4. Once the launcher has started both containers and is idling, open `http://localhost:8080/increment` in your browser.
16 |
--------------------------------------------------------------------------------
/examples/local_container_links/WORKSPACE:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# Bazel rules for container images; provide container_pull (below) and
# container_image (used in BUILD.bazel).
http_archive(
    name = "io_bazel_rules_docker",
    sha256 = "59d5b42ac315e7eadffa944e86e90c2990110a1c8075f1cd145f487e999d22b3",
    strip_prefix = "rules_docker-0.17.0",
    urls = ["https://github.com/bazelbuild/rules_docker/releases/download/v0.17.0/rules_docker-v0.17.0.tar.gz"],
)

# Python rules; provide pip_install used below.
http_archive(
    name = "rules_python",
    sha256 = "934c9ceb552e84577b0faf1e5a2f0450314985b4d8712b2b70717dc679fdc01b",
    url = "https://github.com/bazelbuild/rules_python/releases/download/0.3.0/rules_python-0.3.0.tar.gz",
)

# Provides par_binary, used in BUILD.bazel to build server.par.
# NOTE(review): fetched by tag rather than commit/sha256, so this fetch is not
# pinned as tightly as the http_archive rules above.
git_repository(
    name = "subpar",
    remote = "https://github.com/google/subpar",
    tag = "2.0.0",
)

load("@rules_python//python:pip.bzl", "pip_install")

# Exposes the packages listed in requirements.txt as the @python_deps
# repository, consumed via requirement(...) in BUILD.bazel.
pip_install(
    name = "python_deps",
    requirements = "//:requirements.txt",
)

load("@io_bazel_rules_docker//repositories:repositories.bzl", container_repositories = "repositories")

container_repositories()

load("@io_bazel_rules_docker//repositories:deps.bzl", container_deps = "deps")

container_deps()

load("@io_bazel_rules_docker//container:container.bzl", "container_pull")

# Base image for :server_image, pinned by digest so builds are reproducible.
container_pull(
    name = "io_docker_index_library_python",
    digest = "sha256:11f3ccfbb8e809246f5993be99693966ffc99e1c7b632251fde27c0ce45b35f2",
    registry = "index.docker.io",
    repository = "library/python",
    tag = "3.9.6",
)
61 |
--------------------------------------------------------------------------------
/examples/local_container_links/launcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """A launcher for server.py and Redis.
15 |
16 | See README.md for details.
17 | """
18 |
19 | from typing import Sequence
20 |
21 | from absl import app
22 | from xmanager import xm
23 | from xmanager import xm_local
24 |
25 |
def main(argv: Sequence[str]) -> None:
  """Packages Redis and the Bottle server and runs them as one local job group."""
  del argv  # Unused.

  with xm_local.create_experiment(
      experiment_title='local_container_links'
  ) as experiment:
    redis, server = experiment.package([
        xm.Packageable(
            executable_spec=xm.Container(image_path='redis'),
            executor_spec=xm_local.Local.Spec(),
        ),
        xm.Packageable(
            executable_spec=xm.BazelContainer(label='//:server_image.tar'),
            executor_spec=xm_local.Local.Spec(),
        ),
    ])

    async def generator(work_unit):
      # The server addresses Redis by the full name of the 'redis' job in
      # this work unit, and publishes port 8080 to the host.
      server_job = xm.Job(
          executable=server,
          executor=xm_local.Local(
              docker_options=xm_local.DockerOptions(ports={8080: 8080})
          ),
          args={'redis_host': work_unit.get_full_job_name('redis')},
      )
      # Mounts /tmp/redis as a volume at /data (presumably host -> container;
      # confirm against DockerOptions).
      redis_job = xm.Job(
          name='redis',
          executable=redis,
          executor=xm_local.Local(
              docker_options=xm_local.DockerOptions(
                  volumes={'/tmp/redis': '/data'}
              )
          ),
      )
      work_unit.add(xm.JobGroup(server=server_job, redis=redis_job))

    experiment.add(generator)


if __name__ == '__main__':
  app.run(main)
70 |
--------------------------------------------------------------------------------
/examples/local_container_links/requirements.txt:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | absl-py
16 | bottle
17 | bottle-redis
18 |
--------------------------------------------------------------------------------
/examples/local_container_links/server.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """An HTTP server incrementing a value in Redis."""
15 |
16 | from typing import Sequence
17 |
18 | from absl import app
19 | from absl import flags
20 | import bottle
21 | import bottle.ext.redis
22 |
# Host of the Redis server. launcher.py passes the full job name of its
# 'redis' job here.
redis_host = flags.DEFINE_string('redis_host', None, "Redis' host.")

# The Bottle application. Routes are registered on it with @server.route;
# main() installs the Redis plugin and runs it.
server = bottle.Bottle()
26 |
27 |
@server.route('/increment')
def increment(rdb):
  """Handles requests to /increment by incrementing the Redis key 'counter'.

  Args:
    rdb: Redis connection supplied by the RedisPlugin that main() installs.
      NOTE(review): the plugin presumably injects it by this parameter name,
      so renaming `rdb` would likely break the route. Verify before changing.

  Returns:
    The result of the increment, converted to a string for the response body.
  """
  return str(rdb.incr('counter'))
31 |
32 |
def main(argv: Sequence[str]) -> None:
  """Connects the app to Redis and serves it on 0.0.0.0:8080."""
  del argv  # Unused.

  redis_plugin = bottle.ext.redis.RedisPlugin(host=redis_host.value)
  server.install(redis_plugin)
  bottle.run(server, host='0.0.0.0', port=8080, debug=True)


if __name__ == '__main__':
  app.run(main)
42 |
--------------------------------------------------------------------------------
/examples/parameter_controller/db_config.yaml:
--------------------------------------------------------------------------------
1 | # Connector used - one of ['cloudsql', 'sqlite', 'generic']
2 | sql_connector: ''
3 |
4 | sql_connection_settings:
5 | # Backend used, e.g. 'mysql', 'postgresql'
6 | backend: ''
7 | # Username
8 | username: ''
9 | # Password
10 | password: ''
11 | # Host (or instance connection name for CloudSql)
12 | host: ''
13 | # Port
14 | port: 0
15 | # Database name
16 | db_name: ''
--------------------------------------------------------------------------------
/examples/parameter_controller/inner_job/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
--------------------------------------------------------------------------------
/examples/parameter_controller/inner_job/wait_job.py:
--------------------------------------------------------------------------------
1 | """Job waiting."""
2 | import time
3 |
4 | from absl import app
5 | from absl import flags
6 |
_TIME_TO_SLEEP = flags.DEFINE_integer('time_to_sleep', 10, 'Time to sleep.')


def main(_):
  """Announces itself, sleeps for --time_to_sleep seconds, then says done."""
  duration = _TIME_TO_SLEEP.value
  print(f'Hello, waiting for {duration}...')
  time.sleep(duration)
  print('Done!')


if __name__ == '__main__':
  app.run(main)
18 |
--------------------------------------------------------------------------------
/examples/parameter_controller/launcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""XManager launcher for a parameter controller example.
15 |
The given program launches a dummy job on Kubernetes, waits for its completion
and then launches another job on Vertex AI. This kind of workflow can be used
to define pipelines.
19 |
20 | Usage:
21 |
22 |
xmanager launch examples/parameter_controller/launcher.py -- \
  --xm_db_yaml_config=db_config.yaml \
  [--xm_k8s_service_account_name=...] \
  [--xm_gcp_service_account_name=...]
27 |
28 | The content of `db_config.yaml` must be updated to match the connection details
29 | to the DB used.
30 | """
31 | import os
32 |
33 | from absl import app
34 | from absl import flags
35 | from xmanager import xm
36 | from xmanager import xm_local
37 | from xmanager.contrib import parameter_controller
38 |
39 |
def main(_):
  """Launches a parameter controller that runs two jobs one after the other.

  The controller runs on Kubernetes with this directory packaged inside it.
  It packages `inner_job/` for both Kubernetes (time_to_sleep=20) and Vertex AI
  (time_to_sleep=10), runs the Kubernetes job to completion, then runs the
  Vertex AI job to completion.

  Raises:
    KeyError: if the GOOGLE_CLOUD_BUCKET_NAME environment variable is not set.
  """
  # Title the experiment after this example rather than the copy-pasted
  # 'cifar10'.
  with xm_local.create_experiment(
      experiment_title='parameter_controller'
  ) as experiment:

    @parameter_controller.controller(
        executor=xm_local.Kubernetes(),
        # Forward the launcher's service-account flags to the controller job.
        controller_args={
            'xm_k8s_service_account_name': (
                flags.FLAGS.xm_k8s_service_account_name
            ),
            'xm_gcp_service_account_name': (
                flags.FLAGS.xm_gcp_service_account_name
            ),
        },
        controller_env_vars={
            'GOOGLE_CLOUD_BUCKET_NAME': os.environ['GOOGLE_CLOUD_BUCKET_NAME'],
        },
        # Package contents of this directory inside parameter controller job
        package_path='.',
    )
    async def parameter_controller_example(experiment: xm.Experiment):
      spec = xm.PythonContainer(
          # Package contents of job to be launched
          path='inner_job',
          base_image='python:3.9',
          entrypoint=xm.ModuleName('wait_job'),
      )
      [vertex_executable, k8s_executable] = experiment.package([
          xm.Packageable(
              executable_spec=spec,
              executor_spec=xm_local.Vertex.Spec(),
              args={'time_to_sleep': 10},
          ),
          xm.Packageable(
              executable_spec=spec,
              executor_spec=xm_local.Kubernetes.Spec(),
              args={'time_to_sleep': 20},
          ),
      ])

      # Each work unit is awaited until completion before the next is added,
      # so the jobs run strictly in sequence: Kubernetes, then Vertex AI.
      wu1 = await experiment.add(xm.Job(k8s_executable, xm_local.Kubernetes()))
      await wu1.wait_until_complete()

      wu2 = await experiment.add(xm.Job(vertex_executable, xm_local.Vertex()))
      await wu2.wait_until_complete()

    experiment.add(parameter_controller_example())  # pylint: disable=no-value-for-parameter


if __name__ == '__main__':
  app.run(main)
90 |
--------------------------------------------------------------------------------
/examples/parameter_controller/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | six
3 | sqlparse
4 | dm-launchpad[tensorflow]
5 | numpy
6 | cloudpickle
7 | cloud-sql-python-connector[pymysql]
--------------------------------------------------------------------------------
/examples/vizier/README.md:
--------------------------------------------------------------------------------
This example shows how to use
[Vertex AI Vizier](https://cloud.google.com/vertex-ai/docs/vizier) from an
XManager launcher.

It is a toy project that demonstrates hyperparameter optimization.
7 |
The objective is the bivariate quadratic polynomial

$$ax^2 + by^2 + cxy + dx + ey + f$$

with fixed coefficients a, b, c, d, e, f and hyperparameters x and y, which
Vizier tunes to minimize the polynomial's value.
15 |
--------------------------------------------------------------------------------
/examples/vizier/launcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""XManager launcher for Polynomial.
15 |
16 | Usage:
17 |
18 | xmanager launch examples/vizier/launcher.py -- \
19 | --xm_wrap_late_bindings
20 | """
21 |
22 | from absl import app
23 |
24 | from google.cloud import aiplatform_v1beta1 as aip
25 |
26 | from xmanager import xm
27 | from xmanager import xm_local
28 | from xmanager.vizier import vizier_cloud
29 |
30 |
def get_study_spec() -> aip.StudySpec:
  """Builds the Vizier study configuration for the polynomial example.

  Both hyperparameters, `x` and `y`, are doubles searched over [-2.0, 2.0]
  with random search; the single metric, `loss`, is minimized.
  """
  parameter_spec = aip.StudySpec.ParameterSpec
  searched = [
      parameter_spec(
          parameter_id=name,
          double_value_spec=parameter_spec.DoubleValueSpec(
              min_value=-2.0, max_value=2.0
          ),
      )
      for name in ('x', 'y')
  ]
  objective = aip.StudySpec.MetricSpec(
      metric_id='loss', goal=aip.StudySpec.MetricSpec.GoalType.MINIMIZE
  )
  return aip.StudySpec(
      algorithm=aip.StudySpec.Algorithm.RANDOM_SEARCH,
      parameters=searched,
      metrics=[objective],
  )
54 |
55 |
def main(_):
  """Packages polynomial.py for Vertex AI and explores it with Vizier.

  Runs 3 trials in total, at most 2 at a time, using get_study_spec().
  """
  with xm_local.create_experiment(experiment_title='polynomial') as experiment:
    [executable] = experiment.package([
        xm.Packageable(
            executable_spec=xm.PythonContainer(
                # Package the current directory that this script is in.
                path='.',
                base_image='gcr.io/deeplearning-platform-release/base-cpu',
                entrypoint=xm.ModuleName('polynomial'),
            ),
            executor_spec=xm_local.Vertex.Spec(),
        ),
    ])

    exploration = vizier_cloud.VizierExploration(
        experiment=experiment,
        job=xm.Job(executable=executable, executor=xm_local.Vertex()),
        study_factory=vizier_cloud.NewStudy(study_config=get_study_spec()),
        num_trials_total=3,
        num_parallel_trial_runs=2,
    )
    exploration.launch()


if __name__ == '__main__':
  app.run(main)
88 |
--------------------------------------------------------------------------------
/examples/vizier/polynomial.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Compute the value of a bivariate quadratic polynomial.
15 |
16 | An example of finding the min values of x and y with auto hyperparam tuning.
17 | Set a, b, c, d, e, and f to constant args. Set x and y to be hyperparameters.
18 | """
19 |
20 | from absl import app
21 | from absl import flags
22 |
23 | from vizier_worker import VizierWorker
24 |
FLAGS = flags.FLAGS
# Fixed coefficients of the objective ax^2 + by^2 + cxy + dx + ey + f.
flags.DEFINE_integer('a', 1, 'a in ax^2 + by^2 + cxy + dx + ey + f')
flags.DEFINE_integer('b', 1, 'b in ax^2 + by^2 + cxy + dx + ey + f')
flags.DEFINE_integer('c', 0, 'c in ax^2 + by^2 + cxy + dx + ey + f')
flags.DEFINE_integer('d', 1, 'd in ax^2 + by^2 + cxy + dx + ey + f')
flags.DEFINE_integer('e', 1, 'e in ax^2 + by^2 + cxy + dx + ey + f')
flags.DEFINE_integer('f', 1, 'f in ax^2 + by^2 + cxy + dx + ey + f')

# Hyperparameters searched by the Vizier study (launcher.py defines both over
# [-2.0, 2.0]).
flags.DEFINE_float('x', 0, 'The hyperparameter variable X.')
flags.DEFINE_float('y', 0, 'The hyperparameter variable Y.')

# Full resource name of the Vizier trial that main() reports its measurement to.
flags.DEFINE_string(
    'trial_name',
    None,
    (
        'Identifying the current job trial that measurements '
        'will be submitted to with `add_trial_measurement`. Format: '
        'projects/{project}/locations/{location}/studies/{study}/trials/{trial}'
    ),
)
45 |
46 |
def main(_):
  """Runs a single trial: evaluates the polynomial and reports it as the loss.

  The "training loop" is one evaluation of a*x^2 + b*y^2 + c*x*y + d*x + e*y + f
  at the trial's (x, y). The result is submitted as the `loss` metric at step 1.
  """
  worker = VizierWorker(FLAGS.trial_name)

  x, y = FLAGS.x, FLAGS.y
  loss = float(
      FLAGS.a * x * x
      + FLAGS.b * y * y
      + FLAGS.c * x * y
      + FLAGS.d * x
      + FLAGS.e * y
      + FLAGS.f
  )

  # A single measurement stands in for one "epoch" of training.
  worker.add_trial_measurement(1, {'loss': loss})


if __name__ == '__main__':
  app.run(main)
65 |
--------------------------------------------------------------------------------
/examples/vizier/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | google-cloud-aiplatform
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Setup configuration specifying XManager dependencies."""
15 |
16 | from setuptools import find_namespace_packages
17 | from setuptools import setup
18 |
with open('README.md', 'r', encoding='utf-8') as fh:
  long_description = fh.read()

setup(
    name='xmanager',
    version='0.7.1',
    description='A framework for managing machine learning experiments',
    long_description=long_description,
    long_description_content_type='text/markdown',
    author='DeepMind Technologies Limited',
    # Exclude the top-level `examples` namespace package as well as its
    # subpackages: the glob 'examples.*' on its own does not match 'examples'.
    packages=find_namespace_packages(exclude=['examples', 'examples.*']),
    # Non-Python data files (scripts, SQL, config, migration templates) that
    # must be installed alongside the Python modules.
    package_data={'': ['*.sh', '*.sql', '*.ini', '*.mako']},
    python_requires='>=3.10',
    install_requires=[
        'absl-py',
        'alembic==1.4.3',
        'async_generator',
        'attrs',
        'cloud-sql-python-connector',
        'docker',
        'etils[epath]',
        'google-api-core',
        'google-api-python-client',
        'google-auth',
        'google-cloud-aiplatform',
        'google-cloud-storage',
        'humanize',
        'immutabledict',
        'kubernetes',
        'pyyaml',
        'sqlalchemy==1.2.19',
        'sqlparse',
        'termcolor',
    ],
    entry_points={
        'console_scripts': [
            'xmanager = xmanager.cli.cli:entrypoint',
        ],
    },
    # https://github.com/pypa/warehouse/blob/de4a2e5e2ec26d01bf7813da427ebc4725dccde9/warehouse/templates/packaging/detail.html#L20-L60
    project_urls={
        'Homepage': 'https://github.com/deepmind/xmanager',
        'Issue tracker': 'https://github.com/deepmind/xmanager/issues',
    },
)
64 |
--------------------------------------------------------------------------------
/setup_scripts/install_bazel.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | #
3 | # Installs the latest version of Bazel.
4 |
5 | # Copyright 2021 DeepMind Technologies Limited
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
sudo apt-get install -y npm
# Some distributions ship the Node.js binary only as `nodejs`. Create the
# `node` alias only when it is missing, so that re-running the script does not
# fail on an existing link.
if [ ! -e /usr/bin/node ]; then
  sudo ln -s /usr/bin/nodejs /usr/bin/node
fi
sudo npm install -g @bazel/bazelisk
# Globally installed npm binaries live in `$(npm prefix -g)/bin`. (`npm bin`
# reports the *local* ./node_modules/.bin and was removed in npm 9.)
export PATH="${PATH}:$(npm prefix -g)/bin"
bazel --version
--------------------------------------------------------------------------------
/setup_scripts/install_docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | #
3 | # Installs the latest version of Docker.
4 |
5 | # Copyright 2021 DeepMind Technologies Limited
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
# Use apt-get rather than apt in scripts: apt's CLI is not guaranteed stable.
sudo apt-get update
sudo apt-get install -y curl
curl -fsSL https://get.docker.com -o get-docker.sh
sudo sh get-docker.sh
# The convenience script is only needed once; do not leave it behind.
rm -f get-docker.sh
--------------------------------------------------------------------------------
/setup_scripts/install_gcloud.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | #
3 | # Installs the Google Cloud SDK.
4 |
5 | # Copyright 2021 DeepMind Technologies Limited
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
sudo apt-get update
# Overwrite (rather than append to) the source list so that re-running the
# script does not add duplicate repository entries.
echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" \
  | sudo tee /etc/apt/sources.list.d/google-cloud-sdk.list
sudo apt-get install -y apt-transport-https ca-certificates gnupg curl
# `apt-key` is deprecated. Store the dearmored key directly in the keyring
# referenced by `signed-by` above; --yes allows overwriting it on re-runs.
curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg \
  | sudo gpg --dearmor --yes -o /usr/share/keyrings/cloud.google.gpg
sudo apt-get update
sudo apt-get -y install google-cloud-sdk
--------------------------------------------------------------------------------
/setup_scripts/install_python.sh:
--------------------------------------------------------------------------------
#!/bin/sh
#
# Installs Python 3.10, the minimum version XManager supports
# (python_requires='>=3.10' in setup.py).

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

sudo apt-get update
sudo apt-get install -y software-properties-common
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get update
# Python 3.9 is too old: setup.py declares python_requires='>=3.10'.
sudo apt-get install -y python3.10
sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
sudo apt-get install -y python3-pip python3-distutils
python3 -m pip install --user --upgrade pip
--------------------------------------------------------------------------------
/setup_scripts/install_xmanager.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | #
3 | # Installs XManager.
4 |
5 | # Copyright 2021 DeepMind Technologies Limited
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
sudo apt-get update
sudo apt-get install -y git

python3 -m pip install --user --upgrade setuptools
python3 -m pip install --user git+https://github.com/deepmind/xmanager

if [ -n "${BASH_VERSION}" ]; then
  # Single quotes keep ${PATH} and ${HOME} unexpanded. ~/.bashrc then extends
  # the PATH of each new shell instead of freezing this shell's current PATH.
  _xm_path_line='export PATH="${PATH}:${HOME}/.local/bin"'
  # Append only once, so that re-running the installer does not duplicate it.
  if ! grep -qxF "${_xm_path_line}" ~/.bashrc 2>/dev/null; then
    echo "Adding ~/.local/bin to PATH in ~/.bashrc..."
    echo "${_xm_path_line}" >> ~/.bashrc
  fi
  # This script is meant to be sourced; do not leak the helper variable.
  unset _xm_path_line
  # ${HOME} instead of ~: tilde is not expanded inside double quotes.
  export PATH="${PATH}:${HOME}/.local/bin"
else
  echo "Add ~/.local/bin to your PATH to directly launch xmanager from the shell."
fi
--------------------------------------------------------------------------------
/setup_scripts/setup_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | #
3 | # Installs XManager and all required dependencies.
4 | # Creates basic GCP configuration based on user preferences.
5 |
6 | # Copyright 2021 DeepMind Technologies Limited
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License
19 |
# Run from this script's directory so that the relative ./install_*.sh paths
# resolve even when invoked from elsewhere (e.g. `sh setup_scripts/setup_all.sh`).
cd "$(dirname "$0")" || exit 1

# ./ prefix keeps chmod from treating an odd file name as an option.
chmod +x ./*.sh

# install_xmanager.sh and setup_gcp.sh are sourced so that their environment
# changes apply to this shell.
./install_bazel.sh &&
./install_docker.sh &&
./install_gcloud.sh &&
./install_python.sh &&
. ./install_xmanager.sh &&
. ./setup_gcp.sh
--------------------------------------------------------------------------------
/xmanager/bazel/client.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """A module for communicating with the Bazel server."""
15 |
16 | import abc
17 | from typing import List, Sequence, Tuple
18 |
19 | import attr
20 |
21 |
class BazelService(abc.ABC):
  """Abstract interface for the Bazel operations XManager relies on.

  Every method is abstract; concrete subclasses decide how Bazel is actually
  invoked.
  """

  @abc.abstractmethod
  def fetch_kinds(self, labels: Sequence[str]) -> List[str]:
    """Looks up the rule kinds of the given targets.

    Kinds follow Bazel's `label_kind` query output format, see
    https://docs.bazel.build/versions/main/query.html#output-label_kind.

    Args:
      labels: Bazel labels of the targets to look up.

    Returns:
      A list of kind strings, e.g. `['py_binary rule']`.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def build_targets(
      self, labels: Sequence[str], tail_args: Sequence[str]
  ) -> List[List[str]]:
    """Builds the given targets and reports their important outputs.

    Args:
      labels: Bazel labels of the targets to build.
      tail_args: Extra arguments appended to the end of the Bazel command.

    Returns:
      One entry per label: the list of paths to that target's important
      outputs.
    """
    raise NotImplementedError
53 |
54 |
55 | def _to_tuple(sequence: Sequence[str]) -> Tuple[str, ...]:
56 | """A standalone function to satisfy PyType."""
57 | return tuple(sequence)
58 |
59 |
@attr.s(auto_attribs=True, frozen=True)
class BazelTarget:
  """A Bazel target to be built, together with its build arguments.

  Instances are immutable (`frozen=True`).
  """

  # Bazel label of the target to build.
  label: str
  # Arguments used when building this target. Any sequence passed in is
  # normalized to a tuple by `_to_tuple`; presumably this keeps the frozen
  # instance hashable, since a list would not be.
  bazel_args: Tuple[str, ...] = attr.ib(converter=_to_tuple)
66 |
--------------------------------------------------------------------------------
/xmanager/bazel/file_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """A collection of tools for files operations."""
15 |
16 | import os
17 | import tempfile
18 |
19 |
class TemporaryFilePath:
  """A context manager providing a temporary file path.

  Unlike NamedTemporaryFile, TemporaryFilePath closes the file when one enters
  the context, so only the path is handed out. The file is deleted when the
  context exits, unless it is already gone.
  """

  _path: str

  def __enter__(self) -> str:
    # mkstemp creates the file; the descriptor is closed right away because
    # callers only need the path.
    fd, path = tempfile.mkstemp()
    os.close(fd)
    self._path = path
    return path

  def __exit__(self, error_type, error_value, traceback) -> None:
    # Whatever used the path may already have deleted or moved the file.
    # Tolerate that, so cleanup neither fails nor masks an exception raised
    # inside the `with` block. Returning None lets such exceptions propagate.
    try:
      os.remove(self._path)
    except FileNotFoundError:
      pass
37 |
--------------------------------------------------------------------------------
/xmanager/cli/README.md:
--------------------------------------------------------------------------------
1 | # XManager CLI
2 |
3 | This directory contains the command-line interface for XManager.
4 |
5 | ## `launch`
6 |
7 | Runs a given launch script.
8 |
9 | ```
10 | xmanager launch path/to/launch/script.py
11 | ```
12 |
13 | ## GKE
14 |
15 | In order to use the Kubernetes executor, you must host a Kubernetes cluster or
16 | use a Kubernetes cluster hosted by a cloud provider. An easy way to quickly
17 | create a cluster is by using [caliban](https://caliban.readthedocs.io/).
18 |
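The `xmanager` CLI can create clusters by calling caliban. Caliban is not
installed together with XManager, so install it separately first (for example,
`pip install caliban`). To create a GKE auto-scaling cluster named
`xmanager-via-caliban` in the `us-west1-b` zone, to be used with the Kubernetes
executor, run: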
21 |
22 | ```
23 | xmanager cluster create
24 | ```
25 |
26 | This command is equivalent to running `caliban cluster create`.
27 |
28 | To delete this cluster:
29 |
30 | ```
31 | xmanager cluster delete
32 | ```
33 |
34 | This command is equivalent to running `caliban cluster delete`.
35 |
36 |
37 |
--------------------------------------------------------------------------------
/xmanager/cli/cli.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Xmanager command-line interface."""
15 |
16 | import errno
17 | import importlib
18 | import os
19 | import sys
20 |
21 | from absl import app
22 |
23 | _DEFAULT_ZONE = 'us-west1-b'
24 | _DEFAULT_CLUSTER_NAME = 'xmanager-via-caliban'
25 |
26 |
def _launch(launch_script, extra_args):
  """Imports `launch_script` as a module and runs its `main` under absl."""
  if not os.path.exists(launch_script):
    raise OSError(errno.ENOENT, f'File not found: {launch_script}')
  script_dir = os.path.abspath(os.path.dirname(launch_script))
  launch_module, _ = os.path.splitext(os.path.basename(launch_script))
  # Make the script's directory importable only while importing it. Remove
  # exactly the entry added here, even if the import raises, and without
  # disturbing any entry the imported module itself put at sys.path[0].
  sys.path.insert(0, script_dir)
  try:
    m = importlib.import_module(launch_module)
  finally:
    try:
      sys.path.remove(script_dir)
    except ValueError:
      pass  # The module already removed it.
  argv = [
      launch_script,
      f'--xm_launch_script={launch_script}',
  ] + list(extra_args)
  app.run(m.main, argv=argv)


def _cluster(subcmd):
  """Creates or deletes the default GKE cluster via caliban."""
  # caliban is not an install requirement, so it is imported only on demand.
  caliban_gke = importlib.import_module('caliban.platform.gke.cli')
  caliban_gke_types = importlib.import_module('caliban.platform.gke.types')
  args = {
      'dry_run': False,
      'cluster_name': _DEFAULT_CLUSTER_NAME,
      'zone': _DEFAULT_ZONE,
      'release_channel': caliban_gke_types.ReleaseChannel.REGULAR,
      'single_zone': True,
  }
  if subcmd == 'create':
    caliban_gke._cluster_create(args)  # pylint: disable=protected-access
  elif subcmd == 'delete':
    caliban_gke._cluster_delete(args)  # pylint: disable=protected-access
  else:
    raise app.UsageError(
        f'Subcommand `cluster {subcmd}` is not a supported subcommand'
    )


def main(argv):
  """Dispatches `xmanager <command> ...` to the matching handler.

  Supported commands:
    launch <script> [args...]: runs the `main` function of a launch script.
    cluster create|delete: creates or deletes the default GKE cluster.

  Args:
    argv: Command-line arguments, with the program name at argv[0].

  Raises:
    app.UsageError: If fewer than two arguments are given, or the command or
      subcommand is not supported.
    OSError: If the launch script does not exist.
  """
  if len(argv) < 3:
    raise app.UsageError('There must be at least 2 command-line arguments')
  cmd = argv[1]
  if cmd == 'launch':
    _launch(argv[2], argv[3:])
  elif cmd == 'cluster':
    _cluster(argv[2])
  else:
    raise app.UsageError(f'Command `{cmd}` is not a supported command')
65 |
66 |
def entrypoint():
  """Entry point of the `xmanager` console script declared in setup.py."""
  app.run(main)


if __name__ == '__main__':
  entrypoint()
73 |
--------------------------------------------------------------------------------
/xmanager/cloud/README.md:
--------------------------------------------------------------------------------
1 | # `cloud`
2 |
3 | This directory contains tools and clients for various Google Cloud products.
4 |
--------------------------------------------------------------------------------
/xmanager/cloud/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/xmanager/cloud/build_image_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from absl.testing import absltest
15 | from xmanager import xm
16 | from xmanager.cloud import build_image
17 |
18 |
class BuildImageTest(absltest.TestCase):
  """Checks that generated entrypoint commands forward container arguments."""

  def create_container(self, entrypoint) -> xm.PythonContainer:
    """Returns a PythonContainer customized only by `entrypoint`."""
    return xm.PythonContainer(entrypoint=entrypoint)

  def _commands_for(self, entrypoint):
    """Returns the entrypoint commands generated for `entrypoint`."""
    container = self.create_container(entrypoint)
    return build_image._get_entrypoint_commands(container)

  def test_get_entrypoint_commands_module_adds_suffix(self):
    result = self._commands_for(xm.ModuleName('some.python.module'))
    self.assertEndsWith(result, ' "$@"')

  def test_get_entrypoint_commands_adds_suffix(self):
    result = self._commands_for(xm.CommandList(['echo "aaa"']))
    self.assertEndsWith(result, ' "$@"')

  def test_get_entrypoint_commands_no_dup_plain_suffix(self):
    # The command already ends in an unquoted $@.
    result = self._commands_for(xm.CommandList(['echo "aaa" $@']))
    self.assertEndsWith(result, ' $@')

  def test_get_entrypoint_commands_no_dup_quoted_suffix(self):
    result = self._commands_for(xm.CommandList(['echo "aaa" "$@"']))
    self.assertEndsWith(result, ' "$@"')
    self.assertNotEndsWith(result, ' "$@" "$@"')

  def test_get_entrypoint_commands_dup_single_quoted_suffix(self):
    # A single-quoted '$@' is literal text, so the suffix is still appended.
    result = self._commands_for(xm.CommandList(['echo "aaa" \'$@\'']))
    self.assertEndsWith(result, ' \'$@\' "$@"')


if __name__ == '__main__':
  absltest.main()
57 |
--------------------------------------------------------------------------------
/xmanager/cloud/cloud_build_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for cloud_build."""
15 | from unittest import mock
16 | from absl.testing import absltest
17 |
18 | from xmanager.cloud import cloud_build
19 |
20 |
class CloudBuildTest(absltest.TestCase):
  """Checks the Cloud Build request bodies produced by `cloud_build.Client`."""

  def _request_body(self, *, use_kaniko, use_cloud_build_cache):
    """Returns the request body for a fixed project, bucket, path and image."""
    client = cloud_build.Client(
        'my-project',
        'my-bucket',
        mock.Mock(),
        use_kaniko=use_kaniko,
        use_cloud_build_cache=use_cloud_build_cache,
    )
    return client._build_request_body('path/to/project', 'my-image', 'live')

  def test_build_request_body(self):
    body = self._request_body(use_kaniko=False, use_cloud_build_cache=False)
    expected = {
        'images': ['my-image'],
        'options': {'machineType': 'E2_HIGHCPU_32'},
        'source': {
            'storageSource': {
                'bucket': 'my-bucket',
                'object': 'path/to/project',
            },
        },
        'steps': [{
            'name': 'gcr.io/cloud-builders/docker',
            'args': [
                'build',
                '-t',
                'my-image:live',
                '-t',
                'my-image:latest',
                '.',
            ],
        }],
        'timeout': '1200s',
    }
    self.assertEqual(body, expected)

  def test_build_request_body_use_kaniko(self):
    body = self._request_body(use_kaniko=True, use_cloud_build_cache=False)
    expected = {
        'source': {
            'storageSource': {
                'bucket': 'my-bucket',
                'object': 'path/to/project',
            },
        },
        'steps': [{
            'name': 'gcr.io/kaniko-project/executor:latest',
            'args': [
                '--destination=my-image:live',
                '--destination=my-image:latest',
                '--cache=true',
                '--cache-ttl=336h',
            ],
        }],
        'timeout': '1200s',
    }
    self.assertEqual(body, expected)

  def test_build_request_body_use_build_cache(self):
    body = self._request_body(use_kaniko=False, use_cloud_build_cache=True)
    expected = {
        'images': ['my-image'],
        'options': {'machineType': 'E2_HIGHCPU_32'},
        'source': {
            'storageSource': {
                'bucket': 'my-bucket',
                'object': 'path/to/project',
            },
        },
        'steps': [{
            'name': 'gcr.io/cloud-builders/docker',
            'args': [
                'build',
                '-t',
                'my-image:live',
                '-t',
                'my-image:latest',
                '--cache-from',
                'my-image:latest',
                '.',
            ],
        }],
        'timeout': '1200s',
    }
    self.assertEqual(body, expected)


if __name__ == '__main__':
  absltest.main()
129 |
--------------------------------------------------------------------------------
/xmanager/cloud/data/wrapped_entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Copyright 2021 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
# Write a script that exports the resolved worker pool addresses, then load it.
python3 -c "import vertex_utils; vertex_utils.create_workerpool_address_env_vars_script('./map_xm_env_vars')"
source ./map_xm_env_vars
# Pass "$@" quoted so that each original argument reaches Python intact
# instead of being re-split on whitespace and glob-expanded.
# NOTE(review): `tr -d '[],'` also strips brackets/commas inside arguments;
# confirm print_workerpool_address_args output never contains them.
ARGS=($(python3 -c "import vertex_utils; import sys; vertex_utils.print_workerpool_address_args(sys.argv)" "$@" | tr -d '[],'))
# Quote the array expansion to avoid a second round of splitting/globbing.
./entrypoint.sh "${ARGS[@]}"
21 |
--------------------------------------------------------------------------------
/xmanager/cloud/utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for xmanager.cloud.utils."""
15 |
16 | import os
17 | import tempfile
18 | import unittest
19 |
20 | from xmanager.cloud import utils
21 |
# Sample CLUSTER_SPEC value with three worker pools, seen from the task at
# index 1 of workerpool1 (presumably mirroring the format Vertex AI sets).
# Newlines are replaced with spaces to yield a single-line JSON string.
_CLUSTER_SPEC = """{
"cluster": {
"workerpool0": ["cmle-training-workerpool0-ab-0:2222"],
"workerpool1": ["cmle-training-workerpool1-ab-0:2222", "cmle-training-workerpool1-ab-1:2222"],
"workerpool2": ["cmle-training-workerpool2-ab-0:2222", "cmle-training-workerpool2-ab-1:2222"]
},
"environment": "cloud",
"task": {
"type": "workerpool1",
"index": 1,
"trial": ""
}
}""".replace(
    '\n', ' '
)
37 |
38 |
class UtilsTest(unittest.TestCase):
  """Tests for the CLUSTER_SPEC helpers in xmanager.cloud.utils."""

  def setUp(self):
    super().setUp()
    # Snapshot the environment so that tearDown undoes every variable a test
    # sets (CLUSTER_SPEC, MY_WORKER, ...), not only CLUSTER_SPEC.
    self._saved_environ = dict(os.environ)
    # Start every test from an empty CLUSTER_SPEC so that the *_default tests
    # depend neither on test order nor on the environment running the tests.
    os.environ['CLUSTER_SPEC'] = ''

  def tearDown(self):
    os.environ.clear()
    os.environ.update(self._saved_environ)
    super().tearDown()

  def test_get_master_address_port(self):
    os.environ['CLUSTER_SPEC'] = _CLUSTER_SPEC
    address, port = utils.get_master_address_port()
    self.assertEqual(address, 'cmle-training-workerpool0-ab-0')
    self.assertEqual(port, '2222')

  def test_get_master_address_port_default(self):
    address, port = utils.get_master_address_port()
    self.assertEqual(address, '127.0.0.1')
    self.assertEqual(port, '29500')

  def test_get_world_size_rank(self):
    os.environ['CLUSTER_SPEC'] = _CLUSTER_SPEC
    world_size, rank = utils.get_world_size_rank()
    self.assertEqual(world_size, 5)
    self.assertEqual(rank, 2)

  def test_get_world_size_rank_default(self):
    world_size, rank = utils.get_world_size_rank()
    self.assertEqual(world_size, 1)
    self.assertEqual(rank, 0)

  def test_wrap_and_unwrap_addresses(self):
    arg = '--master=' + utils.get_workerpool_address('workerpool0')
    self.assertEqual(arg, '--master=%objectname(workerpool0)%')
    os.environ['CLUSTER_SPEC'] = _CLUSTER_SPEC
    self.assertEqual(
        utils.map_workerpool_address_args([arg]),
        ['--master=cmle-training-workerpool0-ab-0:2222'],
    )

  def test_create_workerpool_address_env_vars_script(self):
    os.environ['MY_WORKER'] = utils.get_workerpool_address('workerpool0')
    os.environ['CLUSTER_SPEC'] = _CLUSTER_SPEC
    expected = (
        '#!/bin/bash\n\nexport MY_WORKER=cmle-training-workerpool0-ab-0:2222'
    )
    # A context manager guarantees the temporary file is closed and deleted,
    # even when the assertion fails.
    with tempfile.NamedTemporaryFile() as script:
      utils.create_workerpool_address_env_vars_script(script.name)
      with open(script.name) as f:
        self.assertEqual(f.read(), expected)


if __name__ == '__main__':
  unittest.main()
92 |
--------------------------------------------------------------------------------
/xmanager/contrib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/xmanager/contrib/addressing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Module for getting the address of XManager jobs.
15 |
16 | Addresses in XManager can be statically evaluated because the experiment ID is
17 | known. Addressing should not involve tokens or late-bindings.
18 | """
19 |
20 |
def k8s_pod_domain(
    job_name: str,
    experiment_id: int,
    work_unit_id: int,
    service: str = 'experiments',
    namespace: str = 'default',
    port: int = 2222,
) -> str:
  """Returns the Kubernetes pod address of a job.

  The address has the form
  `{experiment_id}-{work_unit_id}-{job_name}.{service}.{namespace}
  .svc.cluster.local:{port}` (without the line break).

  Args:
    job_name: Job name.
    experiment_id: Experiment ID.
    work_unit_id: Work unit ID.
    service: Name of the service for the job. Defaults to 'experiments'.
    namespace: Namespace of the job. Defaults to 'default'.
    port: Port appended to the pod domain. Defaults to 2222, the port this
      function always used before it became configurable.

  Returns:
    The pod address as a `host:port` string.
  """
  return (
      f'{experiment_id}-{work_unit_id}-{job_name}'
      f'.{service}.{namespace}.svc.cluster.local:{port}'
  )
41 |
--------------------------------------------------------------------------------
/xmanager/contrib/addressing_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for addressing."""
15 |
16 | from absl.testing import absltest
17 | from xmanager.contrib import addressing
18 |
19 |
class AddressingTest(absltest.TestCase):
  """Tests for `addressing.k8s_pod_domain`."""

  def test_k8s_pod_domain(self):
    address = addressing.k8s_pod_domain(
        job_name='cifar10',
        experiment_id=123,
        work_unit_id=4,
        service='best_service',
        namespace='best_namespace',
    )

    self.assertEqual(
        address,
        '123-4-cifar10.best_service.best_namespace.svc.cluster.local:2222',
    )

  def test_k8s_pod_domain_defaults(self):
    # Omitting `service` and `namespace` must fall back to their defaults,
    # 'experiments' and 'default'.
    address = addressing.k8s_pod_domain(
        job_name='cifar10',
        experiment_id=123,
        work_unit_id=4,
    )

    self.assertEqual(
        address,
        '123-4-cifar10.experiments.default.svc.cluster.local:2222',
    )


if __name__ == '__main__':
  absltest.main()
39 |
--------------------------------------------------------------------------------
/xmanager/contrib/copybara.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Module for transforming source code using copybara.
15 |
16 | XManager primarily uses Copybara to run folder-to-folder workflows in the form:
17 |
18 | core.workflow(
19 | name = "folder_to_folder",
20 | origin = folder.origin(),
21 | destination = folder.destination(),
22 | ...
23 | )
24 |
25 | Copybara allows for iterative local development on multiple platforms without
26 | needing to add/remove platform-specific code modifications. This allows you to
27 | preprocess source code so that it can be run on different platforms with
28 | different executors. e.g.
29 |
30 | local_version = run_workflow(config, 'local', path)
31 | vertex_version = run_workflow(config, 'vertex', path)
32 |
33 | local_spec = xm.PythonContainer(path=local_version, **kwargs)
34 | vertex_spec = xm.PythonContainer(path=vertex_version, **kwargs)
35 |
36 | [local_executable, vertex_executable] = experiment.package([
37 | xm.Packageable(
38 | executable_spec=local_spec,
39 | executor_spec=xm_local.Local.Spec()),
40 | xm.Packageable(
41 | executable_spec=vertex_spec,
42 | executor_spec=xm_local.Vertex.Spec())])
43 |
44 | Copybara has no release process, so you must compile copybara yourself:
45 | https://github.com/google/copybara
46 | """
47 | import os
48 | import subprocess
49 | import tempfile
50 | from typing import Optional
51 |
# Command used to invoke Copybara. `run_workflow` places this value first in
# the argument list it hands to `subprocess.run`, so it must be directly
# executable: a binary name on PATH (the default) or the path to the compiled
# Copybara, e.g.
# COPYBARA_BIN = 'bazel-bin/java/com/google/copybara/copybara_deploy.jar'
# NOTE(review): a `_deploy.jar` is normally launched via `java -jar`; confirm
# the example path above is executable as-is.
COPYBARA_BIN = 'copybara'
55 |
56 |
def run_workflow(
    config: str,
    workflow: str,
    origin_folder: str,
    destination_folder: Optional[str] = None,
    config_root: Optional[str] = None,
    copybara_bin: Optional[str] = None,
) -> str:
  """Runs a workflow in a Copybara config to transform origin to destination.

  Args:
    config: Path to the Copybara config.
    workflow: Name of a workflow in the Copybara config.
    origin_folder: The origin folder to use as input. It is converted to an
      absolute path and passed to Copybara via the source_ref argument.
    destination_folder: The destination folder to output to. If None or empty,
      a fresh temporary directory is created and used instead.
    config_root: Configuration root path to be used for resolving absolute
      config labels like '//foo/bar'.
    copybara_bin: Command used to invoke Copybara. Defaults to the module-level
      `COPYBARA_BIN`, read at call time.

  Returns:
    The output destination folder.

  Raises:
    subprocess.CalledProcessError: If Copybara exits with a non-zero status.
    FileNotFoundError: If the Copybara command cannot be found.
  """
  # Resolve the binary at call time so callers that reassign the module-level
  # constant keep working.
  if copybara_bin is None:
    copybara_bin = COPYBARA_BIN
  origin_folder = os.path.abspath(origin_folder)
  if not destination_folder:
    destination_folder = tempfile.mkdtemp()
  command = [
      copybara_bin,
      config,
      workflow,
      # Do not fail when the workflow produces no changes.
      '--ignore-noop',
      origin_folder,
      '--folder-dir=' + destination_folder,
  ]
  if config_root:
    command.append('--config-root=' + config_root)
  print('Copybara command: ', command)
  subprocess.run(command, check=True)
  return destination_folder
94 |
--------------------------------------------------------------------------------
/xmanager/contrib/framework_defaults_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from absl.testing import absltest
16 | from absl.testing import parameterized
17 | from xmanager import xm
18 | from xmanager.contrib import framework_defaults
19 |
MLFramework = framework_defaults.MLFramework

# (test-name suffix, accelerator) pairs shared by every base-image test.
_ACCELERATORS = (
    ('cpu', None),
    ('gpu', xm.ResourceType.V100),
    ('tpu', xm.ResourceType.TPU_V3),
)
_DLC_REGISTRY = 'gcr.io/deeplearning-platform-release/'


class FrameworkDefaultsTest(parameterized.TestCase):
  """Tests framework-name detection and default base-image selection."""

  def test_known_frameworks(self):
    expectations = (
        ('torch', MLFramework.PYTORCH),
        ('pytorch', MLFramework.PYTORCH),
        ('tf', MLFramework.TF2),
        ('tf1', MLFramework.TF1),
        ('tf2', MLFramework.TF2),
        ('tensorflow 2.x', MLFramework.TF2),
        ('jax', MLFramework.JAX),
        ('flax', MLFramework.JAX),
    )
    for alias, framework in expectations:
      self.assertEqual(framework_defaults._get_framework(alias), framework)

  def test_unknown_frameworks(self):
    for alias in ('huggingface', 'objax', 'not a framework name'):
      self.assertEqual(
          framework_defaults._get_framework(alias), MLFramework.UNKNOWN
      )

  @parameterized.named_parameters(*_ACCELERATORS)
  def test_jax_base_image(self, accelerator):
    image = framework_defaults.base_image(MLFramework.JAX, accelerator)
    self.assertStartsWith(image, _DLC_REGISTRY)
    # Jax uses CUDA images.
    self.assertIn('cu', image)

  @parameterized.named_parameters(*_ACCELERATORS)
  def test_tf2_base_image(self, accelerator):
    image = framework_defaults.base_image(MLFramework.TF2, accelerator)
    self.assertStartsWith(image, _DLC_REGISTRY)
    self.assertIn('tf2-', image)

  @parameterized.named_parameters(*_ACCELERATORS)
  def test_torch_base_image(self, accelerator):
    image = framework_defaults.base_image(MLFramework.PYTORCH, accelerator)
    self.assertStartsWith(image, 'gcr.io/')
    self.assertIn('pytorch', image)
    if accelerator in xm.TpuType:
      self.assertIn('tpu', image)

  @parameterized.named_parameters(*_ACCELERATORS)
  def test_unsupported_tf1_base_image(self, accelerator):
    image = framework_defaults.base_image(MLFramework.TF1, accelerator)
    self.assertStartsWith(image, _DLC_REGISTRY)
    self.assertIn('tf', image)

  @parameterized.named_parameters(*_ACCELERATORS)
  def test_unknown_base_image(self, accelerator):
    image = framework_defaults.base_image(MLFramework.UNKNOWN, accelerator)
    self.assertStartsWith(image, _DLC_REGISTRY)
    self.assertIn('base', image)


if __name__ == '__main__':
  absltest.main()
109 |
--------------------------------------------------------------------------------
/xmanager/contrib/process_entry.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Entry point of a user-defined Python function."""
16 |
17 |
18 | import contextlib
19 | import functools
20 | import os
21 | import sys
22 |
23 | from absl import app
24 | from absl import flags
25 | from absl import logging
26 | import cloudpickle
27 |
28 |
FLAGS = flags.FLAGS
# Path of the cloudpickle file that `main` loads and indexes as a sequence of
# callables; only the first one is invoked.
flags.DEFINE_string(
    'data_file', '', 'Pickle file location with entry points for all nodes'
)
# Optional path of a cloudpickle file holding a callable that `main` invokes
# before it loads `data_file`.
flags.DEFINE_string(
    'init_file',
    '',
    'Pickle file location containing initialization module '
    'executed for each node prior to an entry point',
)
# Not read anywhere in this module; its help text marks it obsolete.
# NOTE(review): presumably still defined so launchers that pass it do not
# break flag parsing — confirm before removing.
flags.DEFINE_string('flags_to_populate', '{}', 'obsolete')
40 |
41 |
def _parse_process_entry_flags(all_argv: list[str]) -> list[str]:
  """Parses and consumes all flags for the entry script; returns the rest.

  Only flags already defined at this point (by this module and what it has
  imported) are consumed. Unknown flags are left in place so they can be
  parsed later, after the pickled user code has registered its own flags.

  Args:
    all_argv: The full command line, with the program name at index 0.

  Returns:
    The unconsumed arguments. They still include `all_argv[0]`, which absl
    flag parsing treats as the program name and ignores.
  """
  unconsumed_argv = FLAGS(all_argv, known_only=True)

  # JAX doesn't use absl flags and so we need to forward absl flags to JAX
  # explicitly. Here's a heuristic to detect JAX flags and forward them.
  # Inspect the argv this function was given rather than the global
  # `sys.argv`, so the function honours its argument.
  # NOTE(review): `jax.config.parse_flags_with_absl()` may read `sys.argv`
  # itself; the only caller passes `sys.argv`, so the two agree today.
  if any(arg.startswith('--jax_') for arg in all_argv):
    try:
      # pytype:disable=import-error
      # pylint:disable=g-import-not-at-top
      import jax
      # pytype:enable=import-error
      # pylint:enable=g-import-not-at-top
      jax.config.parse_flags_with_absl()
    except ImportError:
      # JAX is optional. Without it there is nothing to forward the flags to,
      # and `main` later warns about any arguments nothing consumed.
      pass

  return unconsumed_argv
62 |
63 |
def main(argv: list[str], process_argv: list[str]):
  """Loads the pickled entry point(s) and runs the first one.

  Args:
    argv: Arguments forwarded by `app.run()`. `parse_flags_and_run()` only
      passes the program name, so this must contain exactly one element.
    process_argv: The program name followed by every argument that
      `_parse_process_entry_flags()` did not consume. These are parsed again
      once the pickled code, which may define more flags, has been loaded.

  Raises:
    app.UsageError: If `--data_file` is not set.
  """
  # See `parse_flags_and_run()` for why arguments are passed in `process_argv`
  # instead.
  assert len(argv) == 1
  del argv

  # Allow for importing modules from the current directory.
  sys.path.append(os.getcwd())
  data_file = FLAGS.data_file
  init_file = FLAGS.init_file
  if not data_file:
    # Fail with a clear message instead of `open('')` raising
    # FileNotFoundError.
    raise app.UsageError('--data_file must be set.')

  if os.environ.get('TF_CONFIG', None):
    # For GCP runtime log to STDOUT so that logs are not reported as errors.
    logging.get_absl_handler().python_handler.stream = sys.stdout

  # NOTE(review): cloudpickle.load executes code embedded in the file.
  # Presumably these files are written by the launcher; confirm they can never
  # come from an untrusted source.
  # `with` closes the file handles deterministically instead of leaking them
  # until garbage collection.
  if init_file:
    with open(init_file, 'rb') as f:
      init_function = cloudpickle.load(f)
    init_function()
  with open(data_file, 'rb') as f:
    functions = cloudpickle.load(f)

  # Now that the code that we intend to run has been unpickled, that should
  # have caused the registration of any remaining flags that the program needs.
  [unused_program_name, *unconsumed_argv] = FLAGS(process_argv, known_only=True)
  if unconsumed_argv:
    logging.warning(
        'The following command-line arguments were passed to the '
        'program but are not used by anything that it imports: %s',
        unconsumed_argv,
    )

  with contextlib.nullcontext():  # no-op context manager
    # Currently only one function is supported.
    functions[0]()
97 |
98 |
def parse_flags_and_run():
  """Parses this module's flags, then runs `main` under `app.run()`.

  Flags not known yet are kept away from `app.run()`, so that it doesn't try
  to parse them before the pickled code has been loaded. They reach `main()`
  through its `process_argv` keyword, bound with `functools.partial`.
  """
  program_name, *leftover_args = _parse_process_entry_flags(sys.argv)
  entry_point = functools.partial(
      main, process_argv=[program_name, *leftover_args]
  )
  # `app.run()` only sees the program name; the real arguments travel
  # through `process_argv`.
  app.run(entry_point, argv=[program_name])


if __name__ == '__main__':
  parse_flags_and_run()
112 |
--------------------------------------------------------------------------------
/xmanager/contrib/tensorboard.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Helper methods for running Tensorboard from the client."""
15 |
16 | from typing import Any, Mapping, Optional
17 |
18 | from xmanager import xm
19 |
20 |
class TensorboardProvider:
  """A class to generate package and job/args to Tensorboard jobs."""

  DEFAULT_TENSORBOARD_PORT = 6006
  # Image used when `get_tensorboard_packageable` gets no `base_image`.
  # NOTE(review): Docker resolves an untagged image to `:latest`; pass a pinned
  # tag for reproducible runs.
  DEFAULT_BASE_IMAGE = 'tensorflow/tensorflow'

  @staticmethod
  def get_tensorboard_packageable(
      timeout_secs: int,
      base_image: str = DEFAULT_BASE_IMAGE,
  ) -> xm.PythonContainer:
    """Creates container spec running TensorBoard server.

    Args:
      timeout_secs: Seconds to run the server for. Note that a value of 0
        disables the associated timeout.
      base_image: Docker image the container is built on. Defaults to
        `DEFAULT_BASE_IMAGE`.

    Raises:
      RuntimeError: `timeout_secs` is negative.

    Returns:
      Spec of container running TensorBoard server for a specified
      period of time.
    """
    if timeout_secs < 0:
      # RuntimeError (rather than ValueError) is kept for existing callers.
      raise RuntimeError('`timeout_secs` must be a nonnegative number')

    return xm.PythonContainer(
        base_image=base_image,
        # `timeout` stops the server once `timeout_secs` seconds have elapsed.
        entrypoint=xm.CommandList([f'timeout {timeout_secs}s tensorboard']),
    )

  @staticmethod
  def get_tensorboard_job_args(
      log_dir: str,
      port: Optional[int] = None,
      additional_args: Optional[Mapping[str, Any]] = None,
  ) -> Mapping[str, Any]:
    """Get arguments to start a Tensorboard job.

    Args:
      log_dir: Directory passed to TensorBoard as `logdir`.
      port: Port to serve on. A falsy value (None or 0) selects
        `DEFAULT_TENSORBOARD_PORT`.
      additional_args: Extra arguments; they override the defaults below.

    Returns:
      Mapping of TensorBoard argument names to values.
    """
    args = {
        'logdir': log_dir,
        'port': port or TensorboardProvider.DEFAULT_TENSORBOARD_PORT,
        # Allows accessing visualisations from Docker container running locally.
        'host': '0.0.0.0',
        # This is set to true by default when running Tensorboard.
        # Since it doesn't seem to be working well with GKE Workload Identity,
        # we set it to false for now. Can be overridden through the
        # `additional_args` param.
        #
        # https://github.com/tensorflow/tensorboard/issues/4784#issuecomment-868945650
        'load_fast': 'false',
    }
    if additional_args:
      args.update(additional_args)

    return args
73 |
74 |
def add_tensorboard(
    experiment: xm.Experiment,
    logdir: str,
    executor: xm.Executor,
    timeout_secs: int = 60 * 60 * 24,
    args: Optional[Mapping[str, Any]] = None,
) -> None:
  """Packages a TensorBoard server and adds it to `experiment` as an aux job.

  Args:
    experiment: Experiment that packages the server and receives the job.
    logdir: Directory TensorBoard reads from (passed as `logdir`).
    executor: Executor that runs the job; `executor.Spec()` is used when
      packaging.
    timeout_secs: Seconds the server runs for; 0 disables the timeout.
      Defaults to one day.
    args: Extra TensorBoard arguments that override the default ones.

  Raises:
    RuntimeError: If `timeout_secs` is negative.
  """
  container = TensorboardProvider.get_tensorboard_packageable(
      timeout_secs=timeout_secs
  )
  [executable] = experiment.package(
      [xm.Packageable(container, executor.Spec())]
  )

  job_args = TensorboardProvider.get_tensorboard_job_args(
      logdir, additional_args=args
  )
  tensorboard_job = xm.Job(executable, executor, args=job_args, name='tensorboard')

  # TODO: Add support for `termination_delay_secs`.
  experiment.add(xm.AuxiliaryUnitJob(tensorboard_job, termination_delay_secs=0))
102 |
--------------------------------------------------------------------------------
/xmanager/contrib/tpu.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Helper module for using TPUs."""
15 | from typing import List
16 |
17 |
# pylint: disable=line-too-long
def tpuvm_docker_instructions() -> List[str]:
  """Returns Dockerfile `RUN` lines that set up a TPU VM image.

  The lines download `libtpu.so` into `/lib`, restrict its permissions, and
  install (then delete) a pinned `tf_nightly` wheel.
  """
  artifacts_url = 'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts'
  release = '20210525'
  wheel = 'tf_nightly-2.6.0-cp38-cp38-linux_x86_64.whl'
  return [
      f'RUN wget {artifacts_url}/libtpu/{release}/libtpu.so -O /lib/libtpu.so',
      'RUN chmod 700 /lib/libtpu.so',
      f'RUN wget {artifacts_url}/tensorflow/{release}/{wheel}',
      f'RUN pip3 install {wheel}',
      f'RUN rm {wheel}',
  ]
34 |
--------------------------------------------------------------------------------
/xmanager/generated/README.md:
--------------------------------------------------------------------------------
1 | # `generated`
2 |
3 | This directory contains generated source files, e.g. compiled protobufs.
4 |
--------------------------------------------------------------------------------
/xmanager/generated/command_line_pb2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The Bazel Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pyformat: disable
16 | # Generated by the protocol buffer compiler. DO NOT EDIT!
17 | # pylint: skip-file
18 | # source: src/main/protobuf/command_line.proto
19 | """Generated protocol buffer code."""
20 | from google.protobuf.internal import builder as _builder
21 | from google.protobuf import descriptor as _descriptor
22 | from google.protobuf import descriptor_pool as _descriptor_pool
23 | from google.protobuf import symbol_database as _symbol_database
24 | # @@protoc_insertion_point(imports)
25 |
_sym_db = _symbol_database.Default()


from . import option_filters_pb2 as src_dot_main_dot_protobuf_dot_option__filters__pb2


# NOTE(review): protoc output (see the file header) — regenerate from
# src/main/protobuf/command_line.proto rather than editing by hand.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/main/protobuf/command_line.proto\x12\x0c\x63ommand_line\x1a&src/main/protobuf/option_filters.proto\"]\n\x0b\x43ommandLine\x12\x1a\n\x12\x63ommand_line_label\x18\x01 \x01(\t\x12\x32\n\x08sections\x18\x02 \x03(\x0b\x32 .command_line.CommandLineSection\"\x9b\x01\n\x12\x43ommandLineSection\x12\x15\n\rsection_label\x18\x01 \x01(\t\x12-\n\nchunk_list\x18\x02 \x01(\x0b\x32\x17.command_line.ChunkListH\x00\x12/\n\x0boption_list\x18\x03 \x01(\x0b\x32\x18.command_line.OptionListH\x00\x42\x0e\n\x0csection_type\"\x1a\n\tChunkList\x12\r\n\x05\x63hunk\x18\x01 \x03(\t\"2\n\nOptionList\x12$\n\x06option\x18\x01 \x03(\x0b\x32\x14.command_line.Option\"\xac\x01\n\x06Option\x12\x15\n\rcombined_form\x18\x01 \x01(\t\x12\x13\n\x0boption_name\x18\x02 \x01(\t\x12\x14\n\x0coption_value\x18\x03 \x01(\t\x12-\n\x0b\x65\x66\x66\x65\x63t_tags\x18\x04 \x03(\x0e\x32\x18.options.OptionEffectTag\x12\x31\n\rmetadata_tags\x18\x05 \x03(\x0e\x32\x1a.options.OptionMetadataTagB-\n+com.google.devtools.build.lib.runtime.protob\x06proto3')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'src.main.protobuf.command_line_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:

  DESCRIPTOR._options = None
  DESCRIPTOR._serialized_options = b'\n+com.google.devtools.build.lib.runtime.proto'
  _COMMANDLINE._serialized_start=94
  _COMMANDLINE._serialized_end=187
  _COMMANDLINESECTION._serialized_start=190
  _COMMANDLINESECTION._serialized_end=345
  _CHUNKLIST._serialized_start=347
  _CHUNKLIST._serialized_end=373
  _OPTIONLIST._serialized_start=375
  _OPTIONLIST._serialized_end=425
  _OPTION._serialized_start=428
  _OPTION._serialized_end=600
50 | # @@protoc_insertion_point(module_scope)
51 |
--------------------------------------------------------------------------------
/xmanager/generated/data_pb2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pyformat: disable
16 | # Generated by the protocol buffer compiler. DO NOT EDIT!
17 | # pylint: skip-file
18 | # source: xm_local/storage/data.proto
19 | """Generated protocol buffer code."""
20 | from google.protobuf.internal import builder as _builder
21 | from google.protobuf import descriptor as _descriptor
22 | from google.protobuf import descriptor_pool as _descriptor_pool
23 | from google.protobuf import symbol_database as _symbol_database
24 | # @@protoc_insertion_point(imports)
25 |
_sym_db = _symbol_database.Default()



# NOTE(review): protoc output (see the file header) — regenerate from
# xm_local/storage/data.proto rather than editing by hand.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1bxm_local/storage/data.proto\x12\x08xmanager\"\x8a\x01\n\x03Job\x12#\n\x05local\x18\x01 \x01(\x0b\x32\x12.xmanager.LocalJobH\x00\x12\'\n\x04\x63\x61ip\x18\x02 \x01(\x0b\x32\x17.xmanager.AIPlatformJobH\x00\x12-\n\nkubernetes\x18\x03 \x01(\x0b\x32\x17.xmanager.KubernetesJobH\x00\x42\x06\n\x04kind\"7\n\x08LocalJob\x12\x0b\n\x03pid\x18\x01 \x01(\t\x12\x0b\n\x03\x63md\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\"&\n\rAIPlatformJob\x12\x15\n\rresource_name\x18\x01 \x01(\t\"4\n\rKubernetesJob\x12\x11\n\tnamespace\x18\x01 \x01(\t\x12\x10\n\x08job_name\x18\x02 \x01(\tb\x06proto3')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'xm_local.storage.data_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:

  DESCRIPTOR._options = None
  _JOB._serialized_start=42
  _JOB._serialized_end=180
  _LOCALJOB._serialized_start=182
  _LOCALJOB._serialized_end=237
  _AIPLATFORMJOB._serialized_start=239
  _AIPLATFORMJOB._serialized_end=277
  _KUBERNETESJOB._serialized_start=279
  _KUBERNETESJOB._serialized_end=331
46 | # @@protoc_insertion_point(module_scope)
47 |
--------------------------------------------------------------------------------
/xmanager/generated/invocation_policy_pb2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The Bazel Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pyformat: disable
16 | # Generated by the protocol buffer compiler. DO NOT EDIT!
17 | # pylint: skip-file
18 | # source: src/main/protobuf/invocation_policy.proto
19 | """Generated protocol buffer code."""
20 | from google.protobuf.internal import builder as _builder
21 | from google.protobuf import descriptor as _descriptor
22 | from google.protobuf import descriptor_pool as _descriptor_pool
23 | from google.protobuf import symbol_database as _symbol_database
24 | # @@protoc_insertion_point(imports)
25 |
_sym_db = _symbol_database.Default()



# NOTE(review): protoc output (see the file header) — regenerate from
# src/main/protobuf/invocation_policy.proto rather than editing by hand.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)src/main/protobuf/invocation_policy.proto\x12\x17\x62laze.invocation_policy\"N\n\x10InvocationPolicy\x12:\n\rflag_policies\x18\x01 \x03(\x0b\x32#.blaze.invocation_policy.FlagPolicy\"\xb4\x02\n\nFlagPolicy\x12\x11\n\tflag_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ommands\x18\x02 \x03(\t\x12\x36\n\tset_value\x18\x03 \x01(\x0b\x32!.blaze.invocation_policy.SetValueH\x00\x12:\n\x0buse_default\x18\x04 \x01(\x0b\x32#.blaze.invocation_policy.UseDefaultH\x00\x12\x42\n\x0f\x64isallow_values\x18\x05 \x01(\x0b\x32\'.blaze.invocation_policy.DisallowValuesH\x00\x12<\n\x0c\x61llow_values\x18\x06 \x01(\x0b\x32$.blaze.invocation_policy.AllowValuesH\x00\x42\x0b\n\toperation\"\xc6\x01\n\x08SetValue\x12\x12\n\nflag_value\x18\x01 \x03(\t\x12<\n\x08\x62\x65havior\x18\x04 \x01(\x0e\x32*.blaze.invocation_policy.SetValue.Behavior\"\\\n\x08\x42\x65havior\x12\r\n\tUNDEFINED\x10\x00\x12\x13\n\x0f\x41LLOW_OVERRIDES\x10\x01\x12\n\n\x06\x41PPEND\x10\x02\x12 \n\x1c\x46INAL_VALUE_IGNORE_OVERRIDES\x10\x03J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04\"\x0c\n\nUseDefault\"\x97\x01\n\x0e\x44isallowValues\x12\x19\n\x11\x64isallowed_values\x18\x01 \x03(\t\x12\x13\n\tnew_value\x18\x03 \x01(\tH\x00\x12:\n\x0buse_default\x18\x04 \x01(\x0b\x32#.blaze.invocation_policy.UseDefaultH\x00\x42\x13\n\x11replacement_valueJ\x04\x08\x02\x10\x03\"\x91\x01\n\x0b\x41llowValues\x12\x16\n\x0e\x61llowed_values\x18\x01 \x03(\t\x12\x13\n\tnew_value\x18\x03 \x01(\tH\x00\x12:\n\x0buse_default\x18\x04 \x01(\x0b\x32#.blaze.invocation_policy.UseDefaultH\x00\x42\x13\n\x11replacement_valueJ\x04\x08\x02\x10\x03\x42-\n+com.google.devtools.build.lib.runtime.proto')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'src.main.protobuf.invocation_policy_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:

  DESCRIPTOR._options = None
  DESCRIPTOR._serialized_options = b'\n+com.google.devtools.build.lib.runtime.proto'
  _INVOCATIONPOLICY._serialized_start=70
  _INVOCATIONPOLICY._serialized_end=148
  _FLAGPOLICY._serialized_start=151
  _FLAGPOLICY._serialized_end=459
  _SETVALUE._serialized_start=462
  _SETVALUE._serialized_end=660
  _SETVALUE_BEHAVIOR._serialized_start=556
  _SETVALUE_BEHAVIOR._serialized_end=648
  _USEDEFAULT._serialized_start=662
  _USEDEFAULT._serialized_end=674
  _DISALLOWVALUES._serialized_start=677
  _DISALLOWVALUES._serialized_end=828
  _ALLOWVALUES._serialized_start=831
  _ALLOWVALUES._serialized_end=976
53 | # @@protoc_insertion_point(module_scope)
54 |
--------------------------------------------------------------------------------
/xmanager/generated/option_filters_pb2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The Bazel Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pyformat: disable
16 | # Generated by the protocol buffer compiler. DO NOT EDIT!
17 | # pylint: skip-file
18 | # source: src/main/protobuf/option_filters.proto
19 | """Generated protocol buffer code."""
20 | from google.protobuf.internal import builder as _builder
21 | from google.protobuf import descriptor as _descriptor
22 | from google.protobuf import descriptor_pool as _descriptor_pool
23 | from google.protobuf import symbol_database as _symbol_database
24 | # @@protoc_insertion_point(imports)
25 |
_sym_db = _symbol_database.Default()



# NOTE(review): protoc output (see the file header) — regenerate from
# src/main/protobuf/option_filters.proto rather than editing by hand.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&src/main/protobuf/option_filters.proto\x12\x07options*\xea\x02\n\x0fOptionEffectTag\x12\x0b\n\x07UNKNOWN\x10\x00\x12\t\n\x05NO_OP\x10\x01\x12\x1b\n\x17LOSES_INCREMENTAL_STATE\x10\x02\x12\x12\n\x0e\x43HANGES_INPUTS\x10\x03\x12\x13\n\x0f\x41\x46\x46\x45\x43TS_OUTPUTS\x10\x04\x12\x18\n\x14\x42UILD_FILE_SEMANTICS\x10\x05\x12 \n\x1c\x42\x41ZEL_INTERNAL_CONFIGURATION\x10\x06\x12\x18\n\x14LOADING_AND_ANALYSIS\x10\x07\x12\r\n\tEXECUTION\x10\x08\x12\'\n#HOST_MACHINE_RESOURCE_OPTIMIZATIONS\x10\t\x12\x15\n\x11\x45\x41GERNESS_TO_EXIT\x10\n\x12\x14\n\x10\x42\x41ZEL_MONITORING\x10\x0b\x12\x13\n\x0fTERMINAL_OUTPUT\x10\x0c\x12\x18\n\x14\x41\x43TION_COMMAND_LINES\x10\r\x12\x0f\n\x0bTEST_RUNNER\x10\x0e*\xb2\x01\n\x11OptionMetadataTag\x12\x10\n\x0c\x45XPERIMENTAL\x10\x00\x12\x17\n\x13INCOMPATIBLE_CHANGE\x10\x01\x12\x0e\n\nDEPRECATED\x10\x02\x12\n\n\x06HIDDEN\x10\x03\x12\x0c\n\x08INTERNAL\x10\x04\x12\x1b\n\x17\x45XPLICIT_IN_OUTPUT_PATH\x10\x06\"\x04\x08\x05\x10\x05*%TRIGGERED_BY_ALL_INCOMPATIBLE_CHANGESB*\n(com.google.devtools.common.options.protob\x06proto3')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'src.main.protobuf.option_filters_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:

  DESCRIPTOR._options = None
  DESCRIPTOR._serialized_options = b'\n(com.google.devtools.common.options.proto'
  _OPTIONEFFECTTAG._serialized_start=52
  _OPTIONEFFECTTAG._serialized_end=414
  _OPTIONMETADATATAG._serialized_start=417
  _OPTIONMETADATATAG._serialized_end=595
44 |
--------------------------------------------------------------------------------
/xmanager/module_lazy_loader/lazy_loader_module_attrs_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Test for custom lazy loader definition for xmanager sub-package __init__.py files."""
15 |
16 | import unittest
17 | from xmanager.module_lazy_loader import module_lazy_loader
18 |
19 |
class LazyLoaderModuleAttrsTest(unittest.TestCase):
  """Tests the `__all__`, `__dir__` and `__getattr__` built by the lazy loader."""

  # All four APIs below point at the same module; they differ only in whether
  # a symbol and/or an alias is supplied.
  _MODULE = "xmanager.module_lazy_loader.module_lazy_loader"

  # Name each API is expected to be exposed under: the alias when given,
  # otherwise the symbol, otherwise (presumably) the module's last component.
  _EXPECTED_NAMES = ("boo", "XManagerAPI", "baz", "module_lazy_loader")

  test_lazy_loader = module_lazy_loader.XManagerLazyLoader(
      __name__,
      apis=[
          module_lazy_loader.XManagerAPI(
              module=_MODULE, symbol="XManagerAPI", alias="boo"
          ),
          module_lazy_loader.XManagerAPI(module=_MODULE, symbol="XManagerAPI"),
          module_lazy_loader.XManagerAPI(module=_MODULE, alias="baz"),
          module_lazy_loader.XManagerAPI(module=_MODULE),
      ],
  )

  def test_all(self):
    exported = self.test_lazy_loader.get_module_all()
    self.assertCountEqual(exported, self._EXPECTED_NAMES)

  def test_dir(self):
    module_dir = self.test_lazy_loader.get_module_dir()
    self.assertCountEqual(module_dir(), self._EXPECTED_NAMES)

  def test_getattr(self):
    local_getattr = self.test_lazy_loader.get_module_getattr()
    expected_values = {
        "boo": module_lazy_loader.XManagerAPI,
        "XManagerAPI": module_lazy_loader.XManagerAPI,
        "baz": module_lazy_loader,
        "module_lazy_loader": module_lazy_loader,
    }
    for attr_name, expected in expected_values.items():
      self.assertEqual(local_getattr(attr_name), expected)
    # Unknown names must surface as AttributeError, like a normal module.
    with self.assertRaises(AttributeError):
      local_getattr("this_attr_does_not_exist")


if __name__ == "__main__":
  unittest.main()
68 |
--------------------------------------------------------------------------------
/xmanager/vizier/vizier_cloud/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Vizier API for launching experiments explored by Vertex Vizier (OSS)."""
15 |
16 | from xmanager.vizier.vizier_cloud.study_factory import NewStudy
17 | from xmanager.vizier.vizier_cloud.vizier_exploration import VizierExploration
18 |
--------------------------------------------------------------------------------
/xmanager/vizier/vizier_cloud/study_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Factory classes for generating study of cloud Vertex Vizier."""
15 |
16 | import abc
17 | from typing import Optional
18 |
19 | from google.cloud import aiplatform_v1beta1 as aip
20 |
21 | from xmanager.cloud import auth
22 |
# Vertex AI region used by `NewStudy` when the caller passes no location.
_DEFAULT_LOCATION = 'us-central1'


class StudyFactory(abc.ABC):
  """Base class for objects that provide a Vertex Vizier study.

  Attributes:
    vz_client: Vizier service client bound to the regional endpoint of the
      `location` given at construction.
    study_config: The `aip.StudySpec` describing the study.
    num_trials_total: Total number of trials requested.
    display_name: Human-readable study name; may be empty.
  """

  vz_client: aip.VizierServiceClient
  study_config: aip.StudySpec
  num_trials_total: int
  display_name: str

  # TODO: Once vertex pyvizier is available, we should replace
  # aip.StudySpec with it.
  # display_name and num_trials_total are supposed to be set into the study
  # config, which is not supported by aip.StudySpec currently. But should be
  # settable when pyvizier.StudyConfig is available.
  def __init__(
      self,
      study_config: aip.StudySpec,
      num_trials_total: int,
      display_name: str,
      location: str,
  ) -> None:
    """Stores the study settings and creates the regional Vizier client.

    Args:
      study_config: The `aip.StudySpec` of the study.
      num_trials_total: Total number of trials requested.
      display_name: Human-readable study name.
      location: Vertex AI region; selects `{location}-aiplatform.googleapis.com`
        as the API endpoint.
    """
    super().__init__()
    self.study_config = study_config
    self.num_trials_total = num_trials_total
    self.display_name = display_name

    endpoint = f'{location}-aiplatform.googleapis.com'
    self.vz_client = aip.VizierServiceClient(
        client_options={'api_endpoint': endpoint}
    )

  @abc.abstractmethod
  def study(self) -> str:
    """Returns the resource name of the study to explore.

    Subclasses decide whether this creates a new study or loads an existing
    one.
    """
    raise NotImplementedError
59 |
60 |
class NewStudy(StudyFactory):
  """Study factory that creates a fresh Vertex Vizier study from a StudySpec.

  Attributes:
    project: GCP project that owns the study.
    location: Vertex AI region in which the study is created.
  """

  project: str
  location: str

  # `num_trials_total` is a required field. Default it to 0 to unbreak the
  # soon-to-deprecate VizierExploration users.
  # `display_name` is optional for user to customize, if not set, XM will
  # set it with experiment information
  def __init__(
      self,
      study_config: aip.StudySpec,
      num_trials_total: int = 0,
      display_name: Optional[str] = None,
      project: Optional[str] = None,
      location: Optional[str] = None,
  ) -> None:
    """Resolves project/location defaults and initializes the factory.

    Args:
      study_config: The `aip.StudySpec` for the new study.
      num_trials_total: Total number of trials requested.
      display_name: Study display name; an empty string is stored when falsy.
      project: GCP project; falls back to `auth.get_project_name()` when falsy.
      location: Vertex AI region; falls back to `_DEFAULT_LOCATION` when falsy.
    """
    # Resolve defaults first: the base class needs the final location to pick
    # the regional API endpoint.
    self.project = project if project else auth.get_project_name()
    self.location = location if location else _DEFAULT_LOCATION
    super().__init__(
        study_config=study_config,
        num_trials_total=num_trials_total,
        display_name=display_name if display_name else '',
        location=self.location,
    )

  def study(self) -> str:
    """Creates the study in Vertex Vizier and returns its resource name.

    Every call issues a new `create_study` request, i.e. creates a new study.
    """
    parent = f'projects/{self.project}/locations/{self.location}'
    new_study = aip.Study(
        display_name=self.display_name, study_spec=self.study_config
    )
    created = self.vz_client.create_study(parent=parent, study=new_study)
    return created.name
93 |
--------------------------------------------------------------------------------
/xmanager/vizier/vizier_cloud/vizier_exploration.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Interface for launching Vizier Explorations using Vertex Vizier."""
15 |
16 | from typing import Any, Dict
17 |
18 | from xmanager import xm
19 | from xmanager.vizier.vizier_cloud import study_factory as sf
20 | from xmanager.vizier.vizier_cloud import vizier_controller
21 |
# NOTE(review): `_DEFAULT_LOCATION` is not referenced anywhere in this module;
# it looks like a leftover and is a candidate for removal.
_DEFAULT_LOCATION = 'us-central1'


# TODO: Add vizier_controller as auxiliary Job generator.
class VizierExploration:
  """Runs an experiment as an exploration driven by a Vertex Vizier study.

  A `VizierController` is given a work-unit generator; for each call it adds
  `job` to the work unit, passing the Vizier parameters as the job's `args`.
  """

  def __init__(
      self,
      experiment: xm.Experiment,
      job: xm.JobType,
      study_factory: sf.StudyFactory,
      num_trials_total: int,
      num_parallel_trial_runs: int,
  ) -> None:
    """Creates a VizierExploration.

    Note that this calls `study_factory.study()` immediately, so for a
    `NewStudy` factory the Vizier study is created during construction.

    Args:
      experiment: The experiment that performs the exploration.
      job: The job to run in every work unit.
      study_factory: Creates or loads the Vizier study. If its `display_name`
        is empty it is set to `X<experiment_id>` before the study is obtained.
      num_trials_total: Total number of trials the experiment should explore.
      num_parallel_trial_runs: Number of trials evaluated in parallel.
    """

    async def work_unit_generator(
        work_unit: xm.WorkUnit, vizier_params: Dict[str, Any]
    ):
      work_unit.add(job, self._to_job_params(vizier_params))

    if not study_factory.display_name:
      study_factory.display_name = f'X{experiment.experiment_id}'

    client = study_factory.vz_client
    study_name = study_factory.study()
    self._controller = vizier_controller.VizierController(
        experiment,
        work_unit_generator,
        client,
        study_name,
        num_trials_total,
        num_parallel_trial_runs,
    )

  def _to_job_params(self, vizier_params: Dict[str, Any]) -> Dict[str, Any]:
    """Wraps flat Vizier parameters as the `args` of an `xm.Job`.

    Args:
      vizier_params: Parameters suggested by Vizier for one trial.

    Returns:
      A dict of the form `{'args': vizier_params}`.
    """
    # TODO: unflatten parameters for JobGroup case (currently this
    # works for xm.Job).
    # For example: transform
    #   {'learner.args.learning_rate': 0.1}
    # to
    #   {'learner': {'args': {'learning_rate': 0.1}}}
    return dict(args=vizier_params)

  def launch(self, **kwargs) -> None:
    """Starts the exploration.

    Args:
      **kwargs: Forwarded unchanged to the controller's `run`.
    """
    self._controller.run(**kwargs)
75 |
--------------------------------------------------------------------------------
/xmanager/vizier/vizier_cloud/vizier_worker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Run Job as a Vizier worker to manage WorkUnit-Vizier interaction."""
15 | import re
16 | from typing import Dict, Optional
17 |
18 | from absl import logging
19 | from google.cloud import aiplatform_v1beta1 as aip
20 |
_TRIAL_NAME_REGEX = (
    r'projects\/[^\/]+\/locations\/[^\/]+\/studies\/[^\/]+\/trials\/[^\/]+'
)


class VizierWorker:
  """Reports progress of a single Vertex Vizier trial.

  Holds an `aip.VizierServiceClient` bound to the trial's regional endpoint and
  offers helpers to add measurements and to complete the trial.
  """

  def __init__(self, trial_name: str) -> None:
    """Creates a worker for one trial.

    Args:
      trial_name: Full trial resource name of the form
        `projects/{project}/locations/{location}/studies/{study}/trials/{trial}`.

    Raises:
      ValueError: If `trial_name` does not have exactly that form.
    """
    # fullmatch (not match): `re.match` only anchors at the start, so names
    # with extra trailing segments such as `.../trials/1/extra` used to pass.
    if not re.fullmatch(_TRIAL_NAME_REGEX, trial_name):
      raise ValueError(
          'The trial_name must be in the form: '
          'projects/{project}/locations/{location}/'
          'studies/{study}/trials/{trial}'
      )

    self._trial_name = trial_name

    # Segment 3 is `{location}`; the API endpoint is chosen per location.
    location = trial_name.split('/')[3]
    self._vz_client = aip.VizierServiceClient(
        client_options={
            'api_endpoint': f'{location}-aiplatform.googleapis.com',
        }
    )

  def add_trial_measurement(self, step: int, metrics: Dict[str, float]) -> None:
    """Reports one measurement for this trial to Vizier.

    Args:
      step: Step count the measurement corresponds to.
      metrics: Mapping from metric ID to its value at `step`.
    """
    self._vz_client.add_trial_measurement(
        request=aip.AddTrialMeasurementRequest(
            trial_name=self._trial_name,
            measurement=aip.Measurement(
                step_count=step,
                metrics=[
                    aip.Measurement.Metric(metric_id=k, value=v)
                    for k, v in metrics.items()
                ],
            ),
        )
    )
    logging.info('Step %d Metric %s is reported', step, metrics)

  def complete_trial(self, infeasible_reason: Optional[str] = None) -> None:
    """Marks this trial as completed in Vizier.

    Args:
      infeasible_reason: If given, the trial is completed as infeasible with
        this reason; if None, it is completed as feasible.
    """
    self._vz_client.complete_trial(
        request=aip.CompleteTrialRequest(
            name=self._trial_name,
            trial_infeasible=infeasible_reason is not None,
            infeasible_reason=infeasible_reason,
        )
    )
    logging.info('Trial %s is completed', self._trial_name)
72 |
--------------------------------------------------------------------------------
/xmanager/xm/README.md:
--------------------------------------------------------------------------------
1 | # `xm`
2 |
3 | This directory contains the Launch API base classes and common implementations
4 | that various experiment schedulers (local or provided as a service) and
5 | executors (Docker, Kubernetes, Vertex AI, etc.) implement.
6 |
--------------------------------------------------------------------------------
/xmanager/xm/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """XManager client API.
15 |
16 | Provides XManager public API for configuring and launching experiments.
17 | """
18 |
19 | from xmanager.xm import job_operators
20 | from xmanager.xm.compute_units import *
21 | from xmanager.xm.core import AuxiliaryUnitJob
22 | from xmanager.xm.core import AuxiliaryUnitRole
23 | from xmanager.xm.core import DebugInterrupt
24 | from xmanager.xm.core import Experiment
25 | from xmanager.xm.core import ExperimentUnit
26 | from xmanager.xm.core import ExperimentUnitError
27 | from xmanager.xm.core import ExperimentUnitFailedError
28 | from xmanager.xm.core import ExperimentUnitNotCompletedError
29 | from xmanager.xm.core import ExperimentUnitRole
30 | from xmanager.xm.core import ExperimentUnitStatus
31 | from xmanager.xm.core import Importance
32 | from xmanager.xm.core import LaunchedJob
33 | from xmanager.xm.core import NotFoundError
34 | from xmanager.xm.core import ReloadError
35 | from xmanager.xm.core import WorkUnit
36 | from xmanager.xm.core import WorkUnitCompletedAwaitable
37 | from xmanager.xm.core import WorkUnitRole
38 | from xmanager.xm.executables import BazelBinary
39 | from xmanager.xm.executables import BazelContainer
40 | from xmanager.xm.executables import Binary
41 | from xmanager.xm.executables import BinaryDependency
42 | from xmanager.xm.executables import CommandList
43 | from xmanager.xm.executables import Container
44 | from xmanager.xm.executables import Dockerfile
45 | from xmanager.xm.executables import ModuleName
46 | from xmanager.xm.executables import PythonContainer
47 | from xmanager.xm.job_blocks import Constraint
48 | from xmanager.xm.job_blocks import Executable
49 | from xmanager.xm.job_blocks import ExecutableSpec
50 | from xmanager.xm.job_blocks import Executor
51 | from xmanager.xm.job_blocks import ExecutorSpec
52 | from xmanager.xm.job_blocks import get_args_for_all_jobs
53 | from xmanager.xm.job_blocks import Job
54 | from xmanager.xm.job_blocks import JobConfig
55 | from xmanager.xm.job_blocks import JobGeneratorType
56 | from xmanager.xm.job_blocks import JobGroup
57 | from xmanager.xm.job_blocks import JobType
58 | from xmanager.xm.job_blocks import merge_args
59 | from xmanager.xm.job_blocks import Packageable
60 | from xmanager.xm.job_blocks import SequentialArgs
61 | from xmanager.xm.job_blocks import UserArgs
62 | from xmanager.xm.metadata_context import ContextAnnotations
63 | from xmanager.xm.metadata_context import MetadataContext
64 | from xmanager.xm.packagables import bazel_binary
65 | from xmanager.xm.packagables import bazel_container
66 | from xmanager.xm.packagables import binary
67 | from xmanager.xm.packagables import container
68 | from xmanager.xm.packagables import dockerfile_container
69 | from xmanager.xm.packagables import python_container
70 | from xmanager.xm.resources import AcceleratorType
71 | from xmanager.xm.resources import Architecture
72 | from xmanager.xm.resources import GpuType
73 | from xmanager.xm.resources import InvalidTpuTopologyError
74 | from xmanager.xm.resources import JobRequirements
75 | from xmanager.xm.resources import ResourceDict
76 | from xmanager.xm.resources import ResourceQuantity
77 | from xmanager.xm.resources import ResourceType
78 | from xmanager.xm.resources import ServiceTier
79 | from xmanager.xm.resources import Topology
80 | from xmanager.xm.resources import TpuType
81 | from xmanager.xm.utils import run_in_asyncio_loop
82 | from xmanager.xm.utils import ShellSafeArg
83 |
--------------------------------------------------------------------------------
/xmanager/xm/async_packager_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import pickle
16 | from typing import Sequence
17 | import unittest
18 |
19 | from xmanager.xm import async_packager
20 | from xmanager.xm import job_blocks
21 | from xmanager.xm import utils
22 |
23 |
def _package_batch(
    packageables: Sequence[job_blocks.Packageable],
) -> Sequence[job_blocks.Executable]:
  """Fake batch packager: one Executable per input, named after its spec."""
  spec_names = (item.executable_spec.name for item in packageables)
  return [job_blocks.Executable(name=spec_name) for spec_name in spec_names]
31 |
32 |
class _TestExecutableSpec(job_blocks.ExecutableSpec):
  """Minimal ExecutableSpec for tests that carries only a name."""

  def __init__(self, name: str) -> None:
    self._name = name

  @property
  def name(self) -> str:
    """The name passed to the constructor."""
    return self._name
41 |
42 |
def _make_packageable(name: str) -> job_blocks.Packageable:
  """Wraps a named test spec and a default ExecutorSpec into a Packageable."""
  spec = _TestExecutableSpec(name)
  executor_spec = job_blocks.ExecutorSpec()
  return job_blocks.Packageable(
      executable_spec=spec,
      executor_spec=executor_spec,
  )
48 |
49 |
class AsyncPackagerTest(unittest.TestCase):
  """Tests for AsyncPackager and the awaitables returned by `add`."""

  @utils.run_in_asyncio_loop
  async def test_async_packager_end_to_end(self):
    packager = async_packager.AsyncPackager(_package_batch)
    executable1 = packager.add(_make_packageable('1'))
    executable2 = packager.add(_make_packageable('2'))
    packager.package()
    self.assertEqual((await executable1).name, '1')
    self.assertEqual((await executable2).name, '2')

  @utils.run_in_asyncio_loop
  async def test_package_with_extra_packageables(self):
    packager = async_packager.AsyncPackager(_package_batch)
    async_executable = packager.add(_make_packageable('async'))
    # Extra packageables are returned directly by `package`, not through an
    # awaitable.
    [extra_executable] = packager.package((_make_packageable('extra'),))
    self.assertEqual((await async_executable).name, 'async')
    self.assertEqual(extra_executable.name, 'extra')

  @utils.run_in_asyncio_loop
  async def test_package_is_required(self):
    # Awaiting before `package()` has been called must fail loudly.
    packager = async_packager.AsyncPackager(_package_batch)
    executable = packager.add(_make_packageable(''))
    with self.assertRaises(async_packager.PackageHasNotBeenCalledError):
      await executable

  def test_awaitable_is_picklable(self):
    packager = async_packager.AsyncPackager(_package_batch)
    executable = packager.add(_make_packageable(''))
    packager.package()
    executable_str = pickle.dumps(executable)

    # Wait for the executable in a separate event loop, which did not even exist
    # when we requested packaging.
    @utils.run_in_asyncio_loop
    async def wait_for_it():
      # Check the resolved value too, not merely that awaiting succeeds.
      self.assertEqual((await pickle.loads(executable_str)).name, '')

    wait_for_it()

  def test_awaitable_is_repeatedly_picklable(self):
    # Pickling an unpickled awaitable must reproduce the same bytes.
    packager = async_packager.AsyncPackager(_package_batch)
    executable = packager.add(_make_packageable(''))
    packager.package()
    executable_str = pickle.dumps(executable)
    executable_reconstructed = pickle.loads(executable_str)
    executable_str2 = pickle.dumps(executable_reconstructed)
    self.assertEqual(executable_str, executable_str2)


if __name__ == '__main__':
  unittest.main()
102 |
--------------------------------------------------------------------------------
/xmanager/xm/compute_units.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Defines convenience constants/functions for converting various units."""
15 |
# pylint: disable=invalid-name
vCPU = 1.0  # Virtual CPU

# TODO: Explicit types are to work around a type inference issue.
# Binary (IEC) byte units: successive powers of 2**10.
KiB: int = 2**10  # kibibyte
MiB: int = 2**20  # mebibyte
GiB: int = 2**30  # gibibyte
TiB: int = 2**40  # tebibyte
PiB: int = 2**50  # pebibyte

# Decimal (SI) byte units: successive powers of 10**3.
KB = 10**3  # kilobyte
MB = 10**6  # megabyte
GB = 10**9  # gigabyte
TB = 10**12  # terabyte
PB = 10**15  # petabyte
# pylint: enable=invalid-name
32 |
--------------------------------------------------------------------------------
/xmanager/xm/executables_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for executables."""
15 |
16 | import os
17 | import unittest
18 |
19 | from xmanager.xm import executables
20 | from xmanager.xm import utils
21 |
22 |
class ExecutablesTest(unittest.TestCase):
  """Checks the names derived by each executable spec and Dockerfile defaults."""

  def test_python_container_name(self):
    container = executables.PythonContainer(
        entrypoint=executables.ModuleName('module'),
        path='/home/user/project/',
    )
    self.assertEqual(container.name, 'project')

  def test_container_name(self):
    container = executables.Container(
        image_path='/home/user/project/image.tar'
    )
    self.assertEqual(container.name, 'image_tar')

  def test_binary_name(self):
    self.assertEqual(executables.Binary(path='./binary').name, 'binary')

  def test_bazel_container_name(self):
    self.assertEqual(
        executables.BazelContainer(label='//container').name, 'container'
    )

  def test_bazel_binary_name(self):
    self.assertEqual(executables.BazelBinary(label=':binary').name, '_binary')

  def test_dockerfile_defaults(self):
    launcher_dir = utils.resolve_path_relative_to_launcher('.')
    dockerfile_spec = executables.Dockerfile()
    self.assertEqual(dockerfile_spec.path, launcher_dir)
    self.assertEqual(
        dockerfile_spec.dockerfile, os.path.join(launcher_dir, 'Dockerfile')
    )


if __name__ == '__main__':
  unittest.main()
65 |
--------------------------------------------------------------------------------
/xmanager/xm/id_predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """A utility to predict IDs that would be assigned upon object creation.
15 |
16 | This class helps to untie this chicken and egg problem. Sometimes to create an
17 | object (for example WorkUnit) one may need to know its ID beforehand (for
18 | example to generate a checkpoint path). But the ID would only be assigned by a
19 | backend upon object creation. In an ideal world we would rewrite the backend to
20 | allow ID reservation. This module provides a temporary solution which does the
21 | reservation on client-side. Following assumptions are made:
22 | * Ids are assigned sequentially by the backend, starting from some number.
23 | * Only one process at a time creates the objects. Any races are resolved only
24 | within that process.
25 |
26 | Usage:
27 | predictor = Predictor(next_id=1)
28 |
29 | # Obtain the ID and asynchronously construct the object.
30 | next_id = predictor.reserve_id()
31 | job = Job(args={'checkpoint_path': f'/tmp/{next_id}'})
32 |
33 | # Wait until all objects with smaller IDs are submitted.
34 | async with predictor.submit_id(next_id):
35 | # And submit it to the backend.
36 | submit(job)
37 |
38 | If the submission fails, the sequence is considered broken and following calls
39 | to the predictor would raise an error.
40 | """
41 |
42 | import asyncio
43 | import threading
44 | from typing import AsyncIterable
45 |
46 | import async_generator
47 |
48 |
class BrokenSequenceError(RuntimeError):
  """The ID would never be ready to submit."""


class Predictor:
  """Predicts IDs that would be assigned on object creation.

  This class is thread safe and async Python friendly. It must be constructed
  from inside asyncio event loop.

  NOTE(review): `reserve_id` is guarded by a `threading.Lock`, while
  `submit_id` relies on an `asyncio.Condition`; presumably all `submit_id`
  calls must therefore run on the event loop that constructed the predictor.
  """

  def __init__(self, next_id: int) -> None:
    """Initializes the predictor.

    Args:
      next_id: The first available ID that would be assigned to the next object.
    """
    self._next_id = next_id
    # We use threading.Lock to allow calling reserve_id from non async context.
    # Note that no long operations are done under this lock.
    self._next_id_lock = threading.Lock()

    # Set once any submission fails; every later submit_id then raises.
    self._is_broken = False
    # Highest ID whose submission completed; the next ID allowed to submit is
    # exactly one greater.
    self._last_created_id = next_id - 1
    self._last_created_id_condition = asyncio.Condition()

  def reserve_id(self) -> int:
    """Reserves and returns the next ID.

    Guarded by a thread lock, so it may also be called from non-async code.
    """
    with self._next_id_lock:
      next_id = self._next_id
      self._next_id += 1
      return next_id

  @async_generator.asynccontextmanager
  async def submit_id(self, id_to_submit: int) -> AsyncIterable[None]:
    """Waits until the ID can be submitted and marks it as such.

    A context manager which would wait for all smaller IDs to be submitted on
    entry and marks it as submitted on exit. As a result all submissions are
    serialized in the correct order and receive the right ID from the backend.

    Args:
      id_to_submit: The id to be submitted.

    Yields:
      Yields when it is time to send the request to the backend.

    Raises:
      BrokenSequenceError: If the submission of an earlier ID has failed.
    """
    async with self._last_created_id_condition:
      await self._last_created_id_condition.wait_for(
          lambda: self._is_broken or self._last_created_id == id_to_submit - 1
      )

      try:
        if self._is_broken:
          raise BrokenSequenceError(
              f'Id {id_to_submit} would never be ready to submit as'
              ' submission of the previous one has failed'
          )
        yield
        self._last_created_id = id_to_submit
      except BaseException:
        # Deliberately broad (includes cancellation): if this ID was not
        # submitted, no later ID can receive its predicted value.
        self._is_broken = True
        raise
      finally:
        # Wake all waiters: either the next ID may now proceed, or the sequence
        # is broken and every waiter must fail.
        self._last_created_id_condition.notify_all()
114 |
--------------------------------------------------------------------------------
/xmanager/xm/id_predictor_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import asyncio
16 | import unittest
17 |
18 | from xmanager.xm import id_predictor
19 | from xmanager.xm import utils
20 |
21 |
class IdPredictorTest(unittest.TestCase):
  """Tests for id_predictor.Predictor."""

  @utils.run_in_asyncio_loop
  async def test_first_id_is_correct(self):
    """Reserving and submitting the first ID yields `next_id`."""
    predictor = id_predictor.Predictor(next_id=1)

    reserved = predictor.reserve_id()
    async with predictor.submit_id(reserved):
      self.assertEqual(reserved, 1)

  @utils.run_in_asyncio_loop
  async def test_ids_are_submitted_in_order(self):
    predictor = id_predictor.Predictor(next_id=1)
    for expected_id in (1, 2, 3):
      self.assertEqual(predictor.reserve_id(), expected_id)

    submission_order = []

    async def submit(id_to_submit):
      async with predictor.submit_id(id_to_submit):
        submission_order.append(id_to_submit)

    # Start the submissions in reverse; the predictor must reorder them.
    await asyncio.gather(*(submit(i) for i in (3, 2, 1)))

    self.assertEqual(submission_order, [1, 2, 3])

  @utils.run_in_asyncio_loop
  async def test_broken_sequence(self):
    predictor = id_predictor.Predictor(next_id=1)
    self.assertEqual(predictor.reserve_id(), 1)
    self.assertEqual(predictor.reserve_id(), 2)

    # A failure while submitting id 1 propagates to the caller...
    with self.assertRaises(RuntimeError):
      async with predictor.submit_id(1):
        raise RuntimeError('Id was eaten by a giant space ant.')

    # ...and afterwards every later id is rejected.
    with self.assertRaises(id_predictor.BrokenSequenceError):
      async with predictor.submit_id(2):
        pass


if __name__ == '__main__':
  unittest.main()
69 |
--------------------------------------------------------------------------------
/xmanager/xm/job_operators_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import unittest
16 | from unittest import mock
17 | import uuid
18 |
19 | from xmanager import xm_mock
20 | from xmanager.xm import job_blocks
21 | from xmanager.xm import job_operators
22 |
23 |
TEST_UUID = uuid.UUID('40c2fc75-9424-4e33-a929-2e0cc631dccf')


def construct_job(name=None):
  """Builds a `job_blocks.Job` backed by mock executable and executor.

  Args:
    name: Optional job name.

  Returns:
    A new job whose executable and executor are `xm_mock` placeholders.
  """
  mock_executable = xm_mock.MockExecutable()
  mock_executor = xm_mock.MockExecutor()
  return job_blocks.Job(
      name=name, executable=mock_executable, executor=mock_executor
  )
33 |
34 |
class JobOperatorsTest(unittest.TestCase):
  """Tests for the job-group traversal helpers in `job_operators`."""

  def test_collect_jobs_by_filter_gathers_matches(self):
    job_group = job_blocks.JobGroup(
        foo=construct_job('foo'),
        bar=construct_job('bar'),
        baz=construct_job('baz'),
    )
    wanted_names = ['foo', 'baz']

    matches = job_operators.collect_jobs_by_filter(
        job_group,
        predicate=lambda job: job.name in wanted_names,
    )

    expected = [job_group.jobs[name] for name in wanted_names]
    self.assertEqual(matches, expected)

  def test_flatten_jobs_traverses_nested_groups(self):
    baz = construct_job('baz')
    foo = construct_job('foo')
    nested = job_blocks.JobGroup(baz=baz)
    job_group = job_blocks.JobGroup(foo=foo, bar=nested)

    self.assertEqual(job_operators.flatten_jobs(job_group), [foo, baz])

  @mock.patch.object(uuid, 'uuid4', return_value=TEST_UUID)
  def test_aggregate_constraint_cliques(self, _):
    outer_1 = construct_job('outer_1')
    inner_1 = construct_job('inner_1')
    inner_2 = construct_job('inner_2')
    constraint_a, constraint_b, constraint_c = (
        xm_mock.MockConstraint(label) for label in 'ABC'
    )
    inner_group = job_blocks.JobGroup(
        inner_1=inner_1,
        inner_2=inner_2,
        constraints=[constraint_b, constraint_c],
    )
    job_group = job_blocks.JobGroup(
        outer_1=outer_1,
        outer_2=inner_group,
        constraints=[constraint_a],
    )

    actual = job_operators.aggregate_constraint_cliques(job_group)

    # uuid4 is patched, so group names only differ by their index.
    group0_name = f'jobgroup_0_{TEST_UUID.hex}'
    group1_name = f'jobgroup_1_{TEST_UUID.hex}'
    expected = [
        job_operators.ConstraintClique(
            constraint=constraint_a,
            jobs=[outer_1, inner_1, inner_2],
            group_name=group0_name,
            size=2,
            parent_group_name=None,
        ),
        *(
            job_operators.ConstraintClique(
                constraint=constraint,
                jobs=[inner_1, inner_2],
                group_name=group1_name,
                size=2,
                parent_group_name=group0_name,
            )
            for constraint in (constraint_b, constraint_c)
        ),
    ]
    self.assertEqual(actual, expected)


if __name__ == '__main__':
  unittest.main()
115 |
--------------------------------------------------------------------------------
/xmanager/xm/metadata_context.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """An interface to manipulate and access experiment metadata.
15 |
16 | Metadata is attached to a context and the context may belong to an experiment
17 | or a work unit.
18 | """
19 |
20 | from typing import Any, Collection, List, Mapping, Set
21 | import attr
22 |
23 |
class ContextAnnotations:
  """Interface for managing annotations.

  Annotations are user-supplied attributes of a context, such as a title or
  tags. Every method has an inert default implementation, so a backend only
  needs to override the subset of annotations it actually supports.
  """

  @property
  def title(self) -> str:
    """The human-readable title of the context.

    Titles help tell contexts (e.g. experiments) apart; they need not be
    unique, so several experiments may share one. The default implementation
    reports an empty title.
    """
    return ''

  def set_title(self, title: str) -> None:
    """Sets the context title; the default implementation ignores it."""
    del title  # Unused by the default implementation.
43 |
44 |
@attr.s(auto_attribs=True)
class MetadataContext:
  """Interface for managing metadata.

  A metadata context is attached either to an experiment or to a work unit.

  Attributes:
    creator: Username of whoever created this context.
    annotations: User-modifiable annotations of this context.
  """

  creator: str = attr.ib()
  annotations: ContextAnnotations = attr.ib()
59 |
--------------------------------------------------------------------------------
/xmanager/xm/packagables_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for packagables."""
15 |
16 | import unittest
17 |
18 | from xmanager.xm import executables
19 | from xmanager.xm import job_blocks
20 | from xmanager.xm import packagables
21 | from xmanager.xm_local import executors
22 |
23 |
class PackagablesTest(unittest.TestCase):
  """Tests for the `packagables.bazel_binary` helper."""

  def _packageable(self, args, env_vars):
    """Builds the Packageable expected for a BazelBinary labelled 'label'."""
    return job_blocks.Packageable(
        executable_spec=executables.BazelBinary(label='label'),
        executor_spec=executors.Local.Spec(),
        args=args,
        env_vars=env_vars,
    )

  def test_minimal_executable_spec(self):
    expected = self._packageable(args=[], env_vars={})

    actual = packagables.bazel_binary(executors.Local.Spec(), label='label')

    self.assertEqual(actual, expected)

  def test_pkg_args_env_vars(self):
    expected = self._packageable(args=['-f'], env_vars={'KEY': 'value'})

    actual = packagables.bazel_binary(
        executors.Local.Spec(),
        label='label',
        args=['-f'],
        env_vars={'KEY': 'value'},
    )

    self.assertEqual(actual, expected)


if __name__ == '__main__':
  unittest.main()
58 |
--------------------------------------------------------------------------------
/xmanager/xm/utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import enum
16 | import unittest
17 |
18 | from xmanager.xm import utils
19 |
20 |
async def make_me_a_sandwich() -> str:
  """Trivial coroutine resolving to a fixed string, for the asyncio tests."""
  result = 'sandwich'
  return result
23 |
24 |
class ResourceType(enum.Enum):
  """Sample enum used to check how `utils.ARG_ESCAPER` renders members."""

  MINERALS = 1
  VESPEN = 2
28 |
29 |
class UtilsTest(unittest.TestCase):
  """Tests for helpers in `xmanager.xm.utils`."""

  @utils.run_in_asyncio_loop
  async def test_run_in_asyncio_loop(self):
    sandwich = await make_me_a_sandwich()
    self.assertEqual(sandwich, 'sandwich')

  def test_run_in_asyncio_loop_returns_value(self):
    run_sync = utils.run_in_asyncio_loop(make_me_a_sandwich)
    self.assertEqual(run_sync(), 'sandwich')

  def test_arg_escaper(self):
    cases = (
        (1.0, '1.0'),
        ('Jonny Droptable', "'Jonny Droptable'"),
        (ResourceType.VESPEN, 'VESPEN'),
    )
    for value, escaped in cases:
      self.assertEqual(utils.ARG_ESCAPER(value), escaped)

  def test_shell_safe_arg_in_f_string(self):
    # Formatting a ShellSafeArg through an f-string is unsupported and must
    # raise.
    with self.assertRaises(RuntimeError):
      f'{utils.ShellSafeArg("42")}'  # pylint: disable=expression-not-assigned


if __name__ == '__main__':
  unittest.main()
54 |
--------------------------------------------------------------------------------
/xmanager/xm_local/README.md:
--------------------------------------------------------------------------------
1 | # `xm_local`
2 |
3 | This directory contains the Launch API scheduler components for running
4 | experiments using the Local Scheduler.
5 |
--------------------------------------------------------------------------------
/xmanager/xm_local/executables.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Local backend executables."""
15 |
16 | from typing import Dict
17 |
18 | import attr
19 | from xmanager import xm
20 |
21 |
@attr.s(auto_attribs=True)
class LoadedContainerImage(xm.Executable):
  """A container image that has already been loaded locally.

  Attributes:
    image_id: Identifier of the loaded image.
    args: Arguments for the executable; an empty `xm.SequentialArgs` by
      default.
    env_vars: Environment variables for the executable; empty by default.
  """

  image_id: str
  args: xm.SequentialArgs = attr.ib(factory=xm.SequentialArgs)
  env_vars: Dict[str, str] = attr.ib(factory=dict)
29 |
30 |
@attr.s(auto_attribs=True)
class LocalBinary(xm.Executable):
  """A binary located on the local machine.

  Attributes:
    command: Command used to invoke the binary.
    args: Arguments for the executable; an empty `xm.SequentialArgs` by
      default.
    env_vars: Environment variables for the executable; empty by default.
  """

  command: str
  args: xm.SequentialArgs = attr.ib(factory=xm.SequentialArgs)
  env_vars: Dict[str, str] = attr.ib(factory=dict)
38 |
39 |
@attr.s(auto_attribs=True)
class GoogleContainerRegistryImage(xm.Executable):
  """A container image stored in Google Container Registry.

  Attributes:
    image_path: Path of the image inside the registry.
    args: Arguments for the executable; an empty `xm.SequentialArgs` by
      default.
    env_vars: Environment variables for the executable; empty by default.
  """

  image_path: str
  args: xm.SequentialArgs = attr.ib(factory=xm.SequentialArgs)
  env_vars: Dict[str, str] = attr.ib(factory=dict)
47 |
--------------------------------------------------------------------------------
/xmanager/xm_local/executors_test.py:
--------------------------------------------------------------------------------
1 | """Tests for xmanager.xm_local.executors."""
2 |
3 | import unittest
4 | from unittest import mock
5 |
6 | from xmanager.xm_local import executors
7 |
8 |
class ExecutorsTest(unittest.TestCase):
  """Tests for executor construction in `xmanager.xm_local.executors`."""

  @mock.patch.object(executors, 'importlib')
  def test_executor_missing_required_module(self, mock_importlib):
    # Simulate an environment where the Kubernetes backend is not installed.
    missing_module = ModuleNotFoundError(
        "No module named 'xmanager.cloud.kubernetes'"
    )
    mock_importlib.import_module.side_effect = missing_module

    with self.assertRaises(ModuleNotFoundError):
      executors.Kubernetes()


if __name__ == '__main__':
  unittest.main()
22 |
--------------------------------------------------------------------------------
/xmanager/xm_local/handles.py:
--------------------------------------------------------------------------------
1 | """Local execution handles."""
2 |
3 | import abc
4 | import asyncio
5 | from concurrent import futures
6 | import logging
7 |
8 | import attr
9 | import docker
10 | from docker.models import containers
11 | from xmanager.xm_local import status
12 |
13 |
14 | _DEFAULT_ENCODING = 'utf-8'
15 |
16 |
17 | def _print_chunk(name: str, line: str) -> None:
18 | print('[{}] {}'.format(name, line.strip()))
19 |
20 |
class ExecutionHandle(abc.ABC):
  """An interface for operating on executions."""

  @abc.abstractmethod
  async def wait(self) -> None:
    """Waits for the execution to finish. Must be overridden."""
    raise NotImplementedError()

  @abc.abstractmethod
  def get_status(self) -> status.LocalWorkUnitStatus:
    """Aggregates the statuses of all jobs in the work unit into one status."""
    raise NotImplementedError()

  def save_to_storage(self, experiment_id: int, work_unit_id: int) -> None:
    """Saves the handle to the local database.

    Not abstract: subclasses that cannot be persisted may leave this as is,
    in which case calling it raises `NotImplementedError`.
    """
    raise NotImplementedError()

  def stop(self) -> None:
    """Stops execution; raises `NotImplementedError` unless overridden."""
    raise NotImplementedError()
40 |
41 |
class LocalExecutionHandle(ExecutionHandle, abc.ABC):
  """An interface for operating on local executions.

  On top of `ExecutionHandle`, a local execution can have its output
  monitored and can be terminated.
  """

  @abc.abstractmethod
  async def monitor(self) -> None:
    """Follows the output of the execution. Must be overridden."""
    raise NotImplementedError()

  @abc.abstractmethod
  def terminate(self) -> None:
    """Terminates the execution. Must be overridden."""
    raise NotImplementedError()
52 |
53 |
@attr.s(auto_attribs=True)
class ContainerHandle(LocalExecutionHandle):
  """A handle for referring to the launched container.

  Attributes:
    name: Name used to prefix streamed output lines.
    model: The Docker container, or None. When None, `wait`, `monitor` and
      `terminate` return immediately.
    stream_output: Whether `monitor` prints the container logs.
    futures_executor: Runs the blocking Docker SDK calls off the event loop.
  """

  name: str
  model: containers.Container | None
  stream_output: bool
  futures_executor: futures.Executor = attr.Factory(futures.ThreadPoolExecutor)

  def _log_container_not_found(self) -> None:
    """Logs that the container no longer exists."""
    logging.info(
        'Container %s not found (it may have already been removed).',
        self.model.name,
    )

  async def wait(self) -> None:
    """Waits, without blocking the event loop, for the container to exit."""
    if self.model is None:
      return

    def _wait() -> None:
      try:
        self.model.wait()
      except docker.errors.NotFound:
        self._log_container_not_found()

    await asyncio.wrap_future(self.futures_executor.submit(_wait))

  def get_status(self) -> status.LocalWorkUnitStatus:
    raise NotImplementedError

  def terminate(self) -> None:
    """Stops the container and shuts down the helper thread pool.

    A container that no longer exists is treated as already stopped. This
    matches how `wait` and `monitor` handle `docker.errors.NotFound`.
    """
    if self.model is None:
      return

    try:
      self.model.stop()
    except docker.errors.NotFound:
      self._log_container_not_found()
    # Blocks until in-flight `wait`/`monitor` calls return; presumably they
    # finish once the container has stopped.
    self.futures_executor.shutdown(wait=True)

  async def monitor(self) -> None:
    """Prints the container logs if `stream_output` is set."""
    if self.model is None:
      return

    def _stream_chunks() -> None:
      try:
        for chunk in self.model.logs(stream=True, follow=True):
          # A chunk boundary can split a multi-byte character, and containers
          # may emit arbitrary bytes, so decoding must never abort the stream.
          _print_chunk(
              self.name, chunk.decode(_DEFAULT_ENCODING, errors='replace')
          )
      except docker.errors.NotFound:
        self._log_container_not_found()

    if self.stream_output:
      await asyncio.wrap_future(self.futures_executor.submit(_stream_chunks))
104 |
105 |
@attr.s(auto_attribs=True)
class BinaryHandle(LocalExecutionHandle):
  """A handle referring to the launched binary.

  Attributes:
    name: Name used to prefix streamed output lines.
    process: The running subprocess.
    stream_output: Whether `monitor` prints the process stdout.
  """

  name: str
  process: asyncio.subprocess.Process  # pytype: disable=module-attr
  stream_output: bool

  async def wait(self) -> None:
    """Waits for the process to exit.

    Raises:
      RuntimeError: If the process exits with a non-zero return code.
    """
    return_code = await self.process.wait()
    if return_code != 0:
      raise RuntimeError(
          f'Process {self.process!r} returned non-zero code: {return_code}'
      )

  def get_status(self) -> status.LocalWorkUnitStatus:
    raise NotImplementedError

  def terminate(self) -> None:
    """Asks the process to terminate unless it has already exited."""
    # `Process.terminate()` may raise ProcessLookupError once the process has
    # exited, so skip processes whose exit status is already known.
    if self.process.returncode is None:
      self.process.terminate()

  async def monitor(self) -> None:
    """Prints the process stdout line by line if `stream_output` is set.

    Raises:
      ValueError: If streaming is requested but the process has no stdout
        pipe.
    """
    if not self.stream_output:
      return
    if not self.process.stdout:
      raise ValueError(
          'No stdout available from process. Cannot stream output.'
      )
    while line := await self.process.stdout.readline():
      # Undecodable bytes must not abort output streaming.
      _print_chunk(self.name, line.decode(_DEFAULT_ENCODING, errors='replace'))
138 |
--------------------------------------------------------------------------------
/xmanager/xm_local/packaging/bazel_tools_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import unittest
16 |
17 | from xmanager.xm_local.packaging import bazel_tools
18 |
19 |
class BazelToolsTest(unittest.TestCase):
  """Tests for the Bazel label helpers in `bazel_tools`."""

  def _assert_lexes_to(self, label, expected):
    """Asserts that `label` lexes into the (package_parts, target) pair."""
    self.assertEqual(bazel_tools._lex_label(label), expected)

  def test_lex_full_label(self):
    self._assert_lexes_to(
        '//project/directory:target', (['project', 'directory'], 'target')
    )

  def test_lex_short_label(self):
    self._assert_lexes_to(
        '//project/package', (['project', 'package'], 'package')
    )

  def test_lex_root_target(self):
    self._assert_lexes_to('//:label', ([], 'label'))

  def test_lex_empty_label(self):
    with self.assertRaises(ValueError):
      bazel_tools._lex_label('//')

  def test_lex_relative_label(self):
    with self.assertRaises(ValueError):
      bazel_tools._lex_label('a/b:c')

  def test_assemble_label(self):
    label = bazel_tools._assemble_label((['a', 'b'], 'c'))
    self.assertEqual(label, '//a/b:c')

  def test_label_kind_lines_to_dict(self):
    lines = [
        'py_binary rule //:py_target',
        'cc_binary rule //:cc_target',
    ]
    expected = {
        '//:py_target': 'py_binary rule',
        '//:cc_target': 'cc_binary rule',
    }
    self.assertEqual(bazel_tools._label_kind_lines_to_dict(lines), expected)

  def test_absolute_label_with_extension_dot(self):
    self._assert_lexes_to(
        '//project/directory:image.tar',
        (['project', 'directory'], 'image.tar'),
    )

  def test_label_with_three_dots(self):
    with self.assertRaisesRegex(ValueError, 'is not an absolute Bazel label'):
      bazel_tools._lex_label('//project/directory/...')

  def test_label_with_star_target(self):
    with self.assertRaisesRegex(ValueError, 'is not an absolute Bazel label'):
      bazel_tools._lex_label('//project/directory:*')

  def test_label_with_all_target(self):
    with self.assertRaisesRegex(ValueError, '`:all` is not a valid target'):
      bazel_tools._lex_label('//project/directory:all')


if __name__ == '__main__':
  unittest.main()
78 |
--------------------------------------------------------------------------------
/xmanager/xm_local/packaging/router.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Methods for routing packageables to appropriate packagers."""
15 |
16 | import collections
17 | from typing import Dict, List, Sequence, Tuple
18 |
19 | from xmanager import xm
20 | from xmanager.bazel import client as bazel_client
21 | from xmanager.xm_local import executors
22 | from xmanager.xm_local.packaging import bazel_tools
23 | from xmanager.xm_local.packaging import cloud as cloud_packaging
24 | from xmanager.xm_local.packaging import local as local_packaging
25 |
26 |
def _packaging_router(
    built_targets: bazel_tools.TargetOutputs, packageable: xm.Packageable
) -> xm.Executable:
  """Packages `packageable` with the packager matching its executor spec.

  Vertex and Kubernetes specs go to the cloud packager; local specs go to the
  local packager. Specs are checked in that order: Vertex, Local, Kubernetes.

  Raises:
    TypeError: If the executor spec is of an unsupported type.
  """
  spec = packageable.executor_spec

  if isinstance(spec, executors.VertexSpec):
    return cloud_packaging.package_cloud_executable(
        built_targets, packageable, packageable.executable_spec
    )
  if isinstance(spec, executors.LocalSpec):
    return local_packaging.package_for_local_executor(
        built_targets, packageable, packageable.executable_spec
    )
  if isinstance(spec, executors.KubernetesSpec):
    return cloud_packaging.package_cloud_executable(
        built_targets, packageable, packageable.executable_spec
    )
  raise TypeError(
      f'Unsupported executor specification: {spec!r}. '
      f'Packageable: {packageable!r}'
  )
54 |
55 |
_ArgsToTargets = Dict[Tuple[str, ...], List[bazel_client.BazelTarget]]


def package(packageables: Sequence[xm.Packageable]) -> List[xm.Executable]:
  """Packages each packageable with the packager matching its executor.

  All Bazel targets referenced by `packageables` are built up front. Targets
  that share the same Bazel arguments are built in a single call.

  Args:
    packageables: Packageables to package.

  Returns:
    One executable per packageable, in the same order.

  Raises:
    ValueError: If a Bazel build reports a different number of outputs than
      the number of labels it was asked to build.
  """
  built_targets: bazel_tools.TargetOutputs = {}
  bazel_targets = bazel_tools.collect_bazel_targets(packageables)

  if bazel_targets:
    bazel_service = bazel_tools.local_bazel_service()

    args_to_targets: _ArgsToTargets = collections.defaultdict(list)
    for target in bazel_targets:
      args_to_targets[target.bazel_args].append(target)
    for args, targets in args_to_targets.items():
      outputs = bazel_service.build_targets(
          labels=tuple(target.label for target in targets),
          bazel_args=args,
      )
      # strict=True: a length mismatch would otherwise silently leave some
      # targets without outputs, which would presumably surface later as a
      # confusing lookup failure during packaging.
      for target, output in zip(targets, outputs, strict=True):
        built_targets[target] = output

  return [
      _packaging_router(built_targets, packageable)
      for packageable in packageables
  ]
82 |
--------------------------------------------------------------------------------
/xmanager/xm_local/registry.py:
--------------------------------------------------------------------------------
1 | """Registry of execution logic for each executor type.
2 |
3 | Execution logic is automatically registered when the corresponding Executor
4 | class is instantiated.
5 |
6 | The registry allows users to avoid heavy dependencies for executors that are not
7 | used in the experiment (e.g. the heavy Kubernetes dependencies of the
8 | `xm_local.Kubernetes` executor).
9 | """
10 |
11 | from typing import Awaitable, Callable, Generic, Sequence, Type, TypeVar
12 |
13 | import attr
14 | from xmanager import xm
15 | from xmanager.xm_local import handles
16 |
17 |
_Handle = TypeVar('_Handle', bound=handles.ExecutionHandle)


@attr.s(auto_attribs=True)
class _ExecutorInfo(Generic[_Handle]):
  """Execution logic registered for a single executor type.

  Attributes:
    launch: Launches a job group and returns the resulting execution handles.
    create_handle: Optional factory that builds an execution handle from data
      stored in the local database.
  """

  launch: Callable[..., Awaitable[Sequence[_Handle]]]
  create_handle: Callable[..., _Handle] | None = None
27 |
28 |
# Maps each executor class to the execution logic registered for it.
_REGISTRY: dict[Type[xm.Executor], _ExecutorInfo] = {}


def register(
    executor_type: Type[xm.Executor],
    launch: Callable[..., Awaitable[Sequence[_Handle]]],
    create_handle: Callable[..., _Handle] | None = None,
):
  """Registers, or replaces, the execution logic for `executor_type`.

  Args:
    executor_type: Executor class the logic applies to.
    launch: Callable returning an awaitable that yields the execution handles
      of a launched job group.
    create_handle: Optional factory that recreates an execution handle from
      data stored in the local database.
  """
  info = _ExecutorInfo(launch=launch, create_handle=create_handle)
  _REGISTRY[executor_type] = info
40 |
41 |
def is_registered(executor_type: Type[xm.Executor]) -> bool:
  """Returns True if `executor_type` has registered execution logic."""
  registered = executor_type in _REGISTRY
  return registered
44 |
45 |
def get_launch_method(
    executor_type: Type[xm.Executor],
) -> Callable[..., Awaitable[Sequence[_Handle]]]:
  """Returns the launch callable registered for `executor_type`.

  Raises:
    KeyError: If nothing is registered for `executor_type`.
  """
  info = _REGISTRY[executor_type]
  return info.launch
50 |
51 |
def get_create_handle_method(
    executor_type: Type[xm.Executor],
) -> Callable[..., _Handle] | None:
  """Returns the handle factory registered for `executor_type`, if any.

  Raises:
    KeyError: If nothing is registered for `executor_type`.
  """
  info = _REGISTRY[executor_type]
  return info.create_handle
56 |
--------------------------------------------------------------------------------
/xmanager/xm_local/status.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Implementation of local work unit statuses."""
15 | import enum
16 |
17 | from xmanager import xm
18 |
19 |
class LocalWorkUnitStatusEnum(enum.Enum):
  """Status of a local experiment job."""

  # Work unit was created, but has not terminated yet.
  RUNNING = 1
  # Work unit terminated and was successful.
  COMPLETED = 2
  # Work unit terminated and was not successful.
  FAILED = 3
  # Work unit terminated because it was cancelled by the user.
  CANCELLED = 4
31 |
32 |
class LocalWorkUnitStatus(xm.ExperimentUnitStatus):
  """Status of a local experiment job.

  Pairs a `LocalWorkUnitStatusEnum` value with an optional message. A
  CANCELLED status reports False for `is_active`, `is_completed` and
  `is_failed`.
  """

  def __init__(
      self, status: LocalWorkUnitStatusEnum, message: str = ''
  ) -> None:
    super().__init__()
    self._status = status
    self._message = message

  def _has_status(self, candidate: LocalWorkUnitStatusEnum) -> bool:
    """Returns whether this status equals `candidate`."""
    return self._status == candidate

  @property
  def is_active(self) -> bool:
    """Whether the work unit is still running."""
    return self._has_status(LocalWorkUnitStatusEnum.RUNNING)

  @property
  def is_completed(self) -> bool:
    """Whether the work unit terminated successfully."""
    return self._has_status(LocalWorkUnitStatusEnum.COMPLETED)

  @property
  def is_failed(self) -> bool:
    """Whether the work unit terminated unsuccessfully."""
    return self._has_status(LocalWorkUnitStatusEnum.FAILED)

  @property
  def message(self) -> str:
    """Message attached to the status; empty unless one was given."""
    return self._message
58 |
--------------------------------------------------------------------------------
/xmanager/xm_local/storage/alembic.ini:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # A generic, single database configuration.
16 |
17 | [alembic]
18 | # path to migration scripts
19 | script_location = SET_AT_RUNTIME_BY_XMANAGER
20 |
21 | # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
22 | # Uncomment the line below if you want the files to be prepended with date and time
23 | # see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
24 | # for all available tokens
25 | # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
26 |
27 | # sys.path path, will be prepended to sys.path if present.
28 | # defaults to the current working directory.
29 | prepend_sys_path = .
30 |
31 | # timezone to use when rendering the date within the migration file
32 | # as well as the filename.
33 | # If specified, requires the python-dateutil library that can be
34 | # installed by adding `alembic[tz]` to the pip requirements
35 | # string value is passed to dateutil.tz.gettz()
36 | # leave blank for localtime
37 | # timezone =
38 |
39 | # max length of characters to apply to the
40 | # "slug" field
41 | # truncate_slug_length = 40
42 |
43 | # set to 'true' to run the environment during
44 | # the 'revision' command, regardless of autogenerate
45 | # revision_environment = false
46 |
47 | # set to 'true' to allow .pyc and .pyo files without
48 | # a source .py file to be detected as revisions in the
49 | # versions/ directory
50 | # sourceless = false
51 |
52 | # version location specification; This defaults
53 | # to alembic/versions. When using multiple version
54 | # directories, initial revisions must be specified with --version-path.
55 | # The path separator used here should be the separator specified by "version_path_separator" below.
56 | # version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
57 |
58 | # version path separator; As mentioned above, this is the character used to split
59 | # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
60 | # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
61 | # Valid values for version_path_separator are:
62 | #
63 | # version_path_separator = :
64 | # version_path_separator = ;
65 | # version_path_separator = space
66 | version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
67 |
68 | # the output encoding used when revision files
69 | # are written from script.py.mako
70 | # output_encoding = utf-8
71 |
72 | sqlalchemy.url = SET_AT_RUNTIME_BY_XMANAGER
73 |
74 |
75 | [post_write_hooks]
76 | # post_write_hooks defines scripts or Python functions that are run
77 | # on newly generated revision scripts. See the documentation for further
78 | # detail and examples
79 |
80 | # format using "black" - use the console_scripts runner, against the "black" entrypoint
81 | # hooks = black
82 | # black.type = console_scripts
83 | # black.entrypoint = black
84 | # black.options = -l 79 REVISION_SCRIPT_FILENAME
85 |
--------------------------------------------------------------------------------
/xmanager/xm_local/storage/alembic/README.md:
--------------------------------------------------------------------------------
1 | # Database Migration
2 |
3 | ## Alembic
4 |
5 | New XManager versions may change the schema of the database used to store
6 | experiment metadata, so an existing database must be migrated to the new
7 | schema. XManager performs these migrations with
8 | [`alembic`](https://alembic.sqlalchemy.org/en/latest/).
9 |
10 | XManager comes with an already configured `alembic` environment in the
11 | `xm_local.storage` package. Migration scripts are provided in the
12 | `alembic/versions` directory present there.
13 |
14 | A new revision can be created using the `alembic revision` command:
15 |
16 | ```
17 | $ cd xmanager/xm_local/storage
18 | $ alembic revision -m "Create Example table."
19 | Generating ${PATH_TO_XMANAGER}/xm_local/storage/alembic/versions/106c21f078d5_create_example_table.py ... done
20 | ```
21 |
22 | A new file `106c21f078d5_create_example_table.py` is generated. This file
23 | contains details about the migration, like the revision ID `106c21f078d5` and
24 | the `upgrade` and `downgrade` functions. One has to populate these functions
25 | with directives that will apply changes to the database.
26 |
27 | ```py
28 | """Create Example table.
29 |
30 | Revision ID: 106c21f078d5
31 | Revises: 5cbe03fe7ed1
32 | Create Date: 2022-09-20 09:23:48.029756
33 |
34 | """
35 | from alembic import op
36 | import sqlalchemy as sa
37 |
38 |
39 | # revision identifiers, used by Alembic.
40 | revision = '106c21f078d5'
41 | down_revision = '5cbe03fe7ed1'
42 | branch_labels = None
43 | depends_on = None
44 |
45 |
46 | def upgrade() -> None:
47 | pass
48 |
49 |
50 | def downgrade() -> None:
51 | pass
52 | ```
53 |
54 | We can add some directives to our script, e.g. adding a new table `Example`:
55 |
56 | ```py
57 | def upgrade() -> None:
58 | op.create_table(
59 | 'Example',
60 | sa.Column('Id', sa.Integer, primary_key=True)
61 | )
62 |
63 | def downgrade() -> None:
64 | op.drop_table('Example')
65 | ```
66 |
67 | Launching XManager now using the `--xm_upgrade_db` flag will attempt to update
68 | the database. For example, running
69 |
70 | ```sh
71 | $ xmanager launch examples/cifar10_tensorflow/launcher.py -- --xm_upgrade_db
72 | ```
73 | performs the upgrade and prints information similar to the following:
74 |
75 | ```
76 | INFO [alembic.runtime.migration] Context impl SQLiteImpl.
77 | INFO [alembic.runtime.migration] Will assume non-transactional DDL.
78 | INFO [alembic.runtime.migration] Running upgrade 5cbe03fe7ed1 -> 106c21f078d5, Create Example table.
79 | ```
80 |
81 | For more details on how `alembic` works and how it can be used, check out the
82 | [tutorial](https://alembic.sqlalchemy.org/en/latest/tutorial.html).
83 |
84 | If multiple developers work on migrations, make sure to sync with the latest
85 | version of the code, check whether there are multiple heads or other issues in the
86 | `alembic` history, and create a merge revision if needed. Check
87 | [this tutorial](https://alembic.sqlalchemy.org/en/latest/branches.html#working-with-branches)
88 | for more details on branching.
89 |
90 | NOTE: Migrations created by users are not explicitly supported by XManager.
91 | When XManager is updated, users are responsible for adapting their own migrations to the new revisions.
92 |
93 | ## Automatic Migration on Update
94 |
95 | If XManager is updated and a new database schema version is available (i.e. a
96 | new migration script is provided), the database currently in use
97 | (local SQLite or specified through the YAML config file) must be upgraded
98 | before XManager can run.
99 |
100 | When using `--xm_upgrade_db`, XManager attempts to update the database
101 | to the latest version. Taking a backup of the database
102 | before using this flag is recommended, since migrations may fail.
103 |
104 | This doesn't apply when a new database is used.
105 | The initial tables will be created automatically for a new database without
106 | having to use the flag.
107 |
108 | NOTE: Modifying the database from outside the XManager client (e.g. with SQL
109 | queries or manual `alembic` upgrades/downgrades) is discouraged and not supported.
110 | All database operations should be performed
111 | automatically by the client.
--------------------------------------------------------------------------------
/xmanager/xm_local/storage/alembic/env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Alembic env.py."""
15 |
16 | from alembic import context
17 | from sqlalchemy import engine_from_config
18 | from sqlalchemy import pool
19 |
# `config` is the Alembic Config for the current run: it exposes the options of
# the ini file and `config.attributes`, where a caller may supply a ready-made
# connection (see run_migrations_online). `target_metadata` is the MetaData
# that autogenerate would compare against; it is None here, so revision
# scripts are written by hand.
config, target_metadata = context.config, None
23 |
24 |
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    The migration context is configured from the `sqlalchemy.url` option alone.
    No Engine is created, so a DBAPI driver does not need to be available.
    `literal_binds=True` asks SQLAlchemy to render parameter values inline
    rather than as bind placeholders.
    """
    context.configure(
        url=config.get_main_option('sqlalchemy.url'),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={'paramstyle': 'named'},
    )
    with context.begin_transaction():
        context.run_migrations()
43 |
44 |
def run_migrations_online() -> None:
    """Run migrations in 'online' mode, against a live database connection.

    If the caller stored a connection in `config.attributes['connection']`,
    the migrations run on that connection. Otherwise an Engine is built from
    the `sqlalchemy.*` options of the main ini section, using NullPool so that
    connections are not pooled. A connection is then opened for the duration
    of the migrations.
    """

    def _migrate_on(connection) -> None:
        # Bind the migration context to `connection` and apply the revisions
        # inside a single begin_transaction() block.
        context.configure(connection=connection, target_metadata=target_metadata)
        with context.begin_transaction():
            context.run_migrations()

    shared_connection = config.attributes.get('connection', None)
    if shared_connection is not None:
        # The externally provided connection is used as-is and is not closed
        # here; its lifetime belongs to whoever created it.
        _migrate_on(shared_connection)
        return

    engine = engine_from_config(
        config.get_section(config.config_ini_section),
        prefix='sqlalchemy.',
        poolclass=pool.NullPool,
    )
    with engine.connect() as connection:
        _migrate_on(connection)
70 |
71 |
# Entry point: dispatch to the runner matching the mode Alembic was invoked in.
if not context.is_offline_mode():
    run_migrations_online()
else:
    run_migrations_offline()
76 |
--------------------------------------------------------------------------------
/xmanager/xm_local/storage/alembic/script.py.mako:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Generated script names don't follow the module naming convention
16 | # pylint: disable=invalid-name
17 | """${message}
18 |
19 | Revision ID: ${up_revision}
20 | Revises: ${down_revision | comma,n}
21 | Create Date: ${create_date}
22 |
23 | """
24 | from alembic import op
25 | import sqlalchemy as sa
26 | ${imports if imports else ""}
27 |
28 | # revision identifiers, used by Alembic.
29 | revision = ${repr(up_revision)}
30 | down_revision = ${repr(down_revision)}
31 | branch_labels = ${repr(branch_labels)}
32 | depends_on = ${repr(depends_on)}
33 |
34 |
35 | def upgrade() -> None:
36 | """Upgrades DB."""
37 | ${upgrades if upgrades else "pass"}
38 |
39 |
40 | def downgrade() -> None:
41 | """Downgrades DB."""
42 | ${downgrades if downgrades else "pass"}
43 |
--------------------------------------------------------------------------------
/xmanager/xm_local/storage/data.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2021 DeepMind Technologies Limited
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto3";
16 |
17 | package xmanager;
18 |
// Backend-specific identity of a launched job. Presumably this is what the
// local storage persists for each job (TODO confirm against the storage code).
message Job {
  // Which backend the job runs on. At most one of these fields is set
  // (proto3 oneof).
  oneof kind {
    // Job running as a process on the local machine.
    LocalJob local = 1;
    // Job submitted to AI Platform as a custom job.
    AIPlatformJob caip = 2;
    // Job running on a Kubernetes cluster.
    KubernetesJob kubernetes = 3;
  }
}
26 |
// A job running as a process on the local machine.
message LocalJob {
  // The identity of the job as identified by `ps -o pid,cmd,etime`.
  // etime is then converted to a Unix timestamp.

  // Process ID as reported by `ps`, stored as a string.
  string pid = 1;
  // Command line of the process as reported by `ps`.
  string cmd = 2;
  // Unix timestamp derived from `etime`. Presumably this is the process start
  // time; TODO confirm the units (likely seconds).
  int64 timestamp = 3;
}
34 |
// A job submitted to AI Platform as a custom job.
message AIPlatformJob {
  // The resource name of the job. This will be in the form:
  // projects/{project}/locations/{location}/customJobs/{customJobId}
  string resource_name = 1;
}
40 |
// A job running on a Kubernetes cluster, identified by its namespace and name.
message KubernetesJob {
  // Namespace of the Kubernetes job.
  string namespace = 1;
  // Name of the Kubernetes job within `namespace`.
  string job_name = 2;
}
47 |
--------------------------------------------------------------------------------