├── .deepsource.toml ├── setup.py ├── LICENSE ├── .gitignore ├── examples ├── sweep.py └── pod.py ├── README.md └── tpucare └── __init__.py /.deepsource.toml: -------------------------------------------------------------------------------- 1 | version = 1 2 | 3 | [[analyzers]] 4 | name = "python" 5 | enabled = true 6 | 7 | [analyzers.meta] 8 | runtime_version = "3.x.x" 9 | max_line_length = 120 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | 4 | with open('README.md') as f: 5 | README = f.read() 6 | 7 | setuptools.setup( 8 | author="Lucas Nestler", 9 | author_email="github.tpucare@nestler.sh", 10 | name='tpucare', 11 | license='BSD', 12 | description='Automatically take good care of your preemptible TPUs', 13 | version='0.5.0', 14 | long_description=README, 15 | url='https://github.com/clashluke/tpucare', 16 | packages=setuptools.find_packages(), 17 | python_requires=">=3.7", 18 | long_description_content_type="text/markdown", 19 | install_requires=[], 20 | classifiers=[ 21 | # Trove classifiers 22 | # (https://pypi.python.org/pypi?%3Aaction=list_classifiers) 23 | 'Development Status :: 5 - Production/Stable', 24 | 'License :: OSI Approved :: BSD License', 25 | 'Programming Language :: Python', 26 | 'Programming Language :: Python :: 3.7', 27 | 'Programming Language :: Python :: 3.8', 28 | 'Programming Language :: Python :: 3.9', 29 | 'Topic :: Software Development :: Libraries', 30 | 'Topic :: Software Development :: Libraries :: Python Modules', 31 | 'Intended Audience :: Developers', 32 | ], 33 | ) 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2022, Lucas Nestler 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /examples/sweep.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | import typing 4 | from netrc import netrc 5 | 6 | import wandb 7 | import yaml 8 | 9 | from tpucare import delete_all, exec_command, exec_on_tpu, send_to_tpu, start_multiple 10 | 11 | _, _, wandb_key = netrc().authenticators("api.wandb.ai") 12 | 13 | 14 | @dataclasses.dataclass 15 | class Context: 16 | zone: str 17 | host: str 18 | sweep_id: str 19 | 20 | 21 | def start_fn(ctx: Context, worker: int): 22 | cmd = exec_command(repository="https://github.com/HomebrewNLP/HomebrewNLP-Jax", wandb_key=wandb_key, 23 | run_command=f"/home/ubuntu/.local/bin/wandb agent {ctx.sweep_id}") 24 | send_to_tpu(ctx.host, ctx.zone, "setup.sh", cmd, worker) 25 | exec_on_tpu(ctx.host, ctx.zone, "bash setup.sh", worker) 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--prefix", type=str, help="Prefix used to identify TPUs") 31 | parser.add_argument("--tpu-version", type=int, default=3, help="Which TPU version to create (v2-8 or v3-8)") 32 | parser.add_argument("--zone", type=str, default="europe-west4-a", help="GCP Zone TPUs get created in") 33 | parser.add_argument("--preemptible", default=1, type=int, 34 | help="Whether to create preemptible or non-preemptible TPUs") 35 | parser.add_argument("--service-account", type=str, 36 | help="Service account that controls permissions of TPU (for example, to ensure EU TPUs " 37 | "won't use US data)") 38 | parser.add_argument("--branch", type=str, default="main", help="Branch on github to use") 39 | parser.add_argument("--slices", default=1, type=int, 40 | help="How many TPU slices each TPU should have (1=>vX-8, 4=>vX-32)") 41 | parser.add_argument("--config-path", type=str, help="Path to sweep's config.yaml") 42 | parser.add_argument("--cleanup", default=0, type=int, 43 | help="Instead of running something new, kill all tpus. 
1 or 0 for y/n")
44 |     parser.add_argument("--tpus", default=1, type=int, help="How many TPUs to create and babysit")
45 |     args = parser.parse_args()
46 |     return args
47 | 
48 | 
49 | def main():
50 |     args = parse_args()
51 | 
52 |     if args.cleanup:
53 |         return delete_all(args.prefix, args.zone)
54 | 
55 |     with open(args.config_path, 'r') as f:
56 |         config = yaml.safe_load(f.read())
57 |     sweep_id = wandb.sweep(config, entity="homebrewnlp", project="gpt")
58 | 
59 |     def creation_callback(host: str, ctx: typing.Optional[Context]) -> Context:
60 |         if ctx is None:
61 |             return Context(zone=args.zone, host=host, sweep_id=sweep_id)
62 |         return ctx
63 | 
64 |     return start_multiple(args.prefix, args.tpu_version, args.zone, args.preemptible, args.service_account,
65 |                           args.slices, start_fn, creation_callback, args.tpus)
66 | 
67 | 
68 | if __name__ == '__main__':
69 |     main()
70 | 
--------------------------------------------------------------------------------
/examples/pod.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import dataclasses
3 | import typing
4 | from netrc import netrc
5 | 
6 | import wandb
7 | import yaml
8 | 
9 | from tpucare import exec_command, exec_on_tpu, send_to_tpu, start_single, synchronous_deletion
10 | 
11 | _, _, wandb_key = netrc().authenticators("api.wandb.ai")
12 | 
13 | 
14 | @dataclasses.dataclass
15 | class Context:
16 |     retry: int
17 |     zone: str
18 |     host: str
19 |     branch: str
20 |     run_name: str
21 |     data_path: str
22 |     config_path: str
23 | 
24 | 
25 | def load_config(ctx: Context):
26 |     with open(ctx.config_path, 'r') as f:
27 |         config = f.read()
28 |     config = yaml.safe_load(config)
29 | 
30 |     wandb_api = wandb.Api()
31 |     config["training"]["do_checkpoint"] = True
32 |     base_checkpoint_path = config["training"]["checkpoint_path"]
33 | 
34 |     start_step = 0
35 |     for run in wandb_api.runs(f"{config['wandb']['entity']}/{config['wandb']['project']}"):
36 |         if run.name == config['wandb']['name']:
37 |             start_step = run.summary["_step"]
38 |             break
39 |     start_step -= start_step % config["training"]["checkpoint_interval"]
40 | 
41 |     config["training"]["start_step"] = start_step
42 |     config["data"]["path"] = ctx.data_path
43 |     config["wandb"]["name"] = f"{ctx.run_name}-{ctx.retry}"
44 |     if ctx.retry > 0:
45 |         config["training"]["checkpoint_load_path"] = config["training"]["checkpoint_path"]
46 |         config["training"]["checkpoint_path"] = f"{base_checkpoint_path}-{ctx.retry}"
47 |     return yaml.dump(config)
48 | 
49 | 
50 | def start_fn(ctx: Context, worker: int):
51 |     """
52 |     This function gets executed in threads to start a run on a new TPU. It receives the context object returned by
53 |     `creation_callback` as well as the worker id which corresponds to the slice id this code was executed on in a
54 |     multi-host setup. For single-host setups, such as v3-8s, the "worker" will always be set to 0.
55 |     Ideally, it'd copy necessary files to the TPU and then run those. Here, `exec_command` can be used to create an
56 |     execution command that automatically spawns a `screen` session which persists even when the SSH connection gets cut.
57 | """ 58 | send_to_tpu(ctx.host, ctx.zone, "config.yaml", load_config(ctx), worker) 59 | cmd = exec_command(repository="https://github.com/HomebrewNLP/HomebrewNLP-Jax", wandb_key=wandb_key, 60 | branch=ctx.branch) 61 | send_to_tpu(ctx.host, ctx.zone, "setup.sh", cmd, worker) 62 | exec_on_tpu(ctx.host, ctx.zone, "bash setup.sh", worker) 63 | 64 | 65 | def parse_args(): 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument("--host", type=str, help="Name of the TPU") 68 | parser.add_argument("--tpu-version", type=int, default=3, help="Which TPU version to create (v2-8 or v3-8)") 69 | parser.add_argument("--zone", type=str, default="europe-west4-a", help="GCP Zone TPUs get created in") 70 | parser.add_argument("--data-path", type=str, default="gs://ggpt4/the-char-pile/", 71 | help="Where the data is stored. Should be changed to a bucket in the correct region") 72 | parser.add_argument("--preemptible", default=1, type=int, 73 | help="Whether to create preemptible or non-preemptible TPUs") 74 | parser.add_argument("--service-account", type=str, 75 | help="Service account that controls permissions of TPU (for example, to ensure EU TPUs " 76 | "won't " 77 | "use US data)") 78 | parser.add_argument("--branch", type=str, default="main", help="Branch on github to use") 79 | parser.add_argument("--slices", default=1, type=int, 80 | help="How many TPU slices each TPU should have (1=>vX-8, 4=>vX-32)") 81 | parser.add_argument("--run-name", type=str, help="Prefix to use for all runs on WandB") 82 | parser.add_argument("--config-path", type=str, help="Path to config.yaml") 83 | parser.add_argument("--cleanup", default=0, type=int, 84 | help="Instead of running something new, kill all tpus. 1 or 0 for y/n") 85 | args = parser.parse_args() 86 | return args 87 | 88 | 89 | def main(): 90 | args = parse_args() 91 | if args.cleanup: 92 | synchronous_deletion("", args.host, args.zone) 93 | return 94 | 95 | def creation_callback(host: str, ctx: typing.Optional[Context]) -> Context: 96 | if ctx is None: # first invocation 97 | return Context(retry=0, zone=args.zone, host=args.host, branch=args.branch, run_name=args.run_name, 98 | data_path=args.data_path, config_path=args.config_path) 99 | ctx.retry += 1 100 | return ctx 101 | 102 | return start_single(args.host, args.tpu_version, args.zone, args.preemptible, args.service_account, 103 | args.slices, start_fn, creation_callback) 104 | 105 | 106 | if __name__ == '__main__': 107 | main() 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TPU Care 2 | 3 | Automatically take good care of your preemptible TPUs 4 | 5 | ## Table of Contents 6 | 7 | * [TPU Care](#tpu-care) 8 | * [Table of Contents](#table-of-contents) 9 | * [Features](#features) 10 | * [Getting Started](#getting-started) 11 | * [Installation](#installation) 12 | * [Examples](#examples) 13 | * [Long-running preemptible training](#long-running-preemptible-training) 14 | * [Sweeps](#sweeps) 15 | * [Citation](#citation) 16 | 17 | ## Features 18 | 19 | * **Reliable code execution**: TPU Care starts a TPU, ensures it's set up as specified and continues the experiment 20 | whenever the node dies. Think of it like [TerraForm](https://www.terraform.io/) + [Ansible](https://www.ansible.com/) 21 | for machine learning. 
22 | * **Maintenance of large swarms**: When running multiple nodes, TPU Care will automatically delete dead instances
23 |   while keeping as many alive as possible.
24 | * **Code generation**: To simplify setup, TPU Care efficiently clones your git repository and ensures reliable
25 |   execution of your `run_command`, which continues even during outages.
26 | * **Optimized management**: When a node dies, TPU Care deletes it within five minutes and creates a new one as soon
27 |   as capacity becomes available.
28 | 
29 | ## Getting Started
30 | 
31 | ### Installation
32 | 
33 | ```BASH
34 | python3 -m pip install tpucare
35 | ```
36 | 
37 | ## Examples
38 | 
39 | We've been using TPU Care for a while at [HomebrewNLP](https://github.com/HomebrewNLP/). In fact, this library is the
40 | branched-out core of the original production-ready HomebrewNLP code. At HomebrewNLP, it has served two major use
41 | cases: massive hyperparameter sweeps, which consumed 900,000 TPU-core hours within three months, and stable training
42 | on large TPU pods. Below, you can see a list of TPUs that are largely managed by TPU Care:
43 | 
*Screenshot from TPUnicorn, a CLI-based TPU management tool*
45 | In the following sections, you'll learn how we use it at massive scale with minimal code effort.
46 | 
47 | ### Long-running preemptible training
48 | 
49 | For example, the following code can be used to create a production-ready v3-256 using
50 | the [HomebrewNLP-Jax](https://github.com/HomebrewNLP/HomebrewNLP-Jax) codebase (
51 | see [examples/pod.py](https://github.com/clashluke/tpucare/blob/main/examples/pod.py) for an executable version):
52 | 
53 | ```PYTHON
54 | import dataclasses
55 | import typing
56 | from netrc import netrc
57 | import wandb
58 | import yaml
59 | 
60 | from tpucare import exec_command, exec_on_tpu, send_to_tpu, start_single
61 | 
62 | 
63 | @dataclasses.dataclass
64 | class Context:
65 |     retry: int
66 | 
67 | 
68 | ZONE = "europe-west4-a"
69 | HOST = "big-pod"
70 | RUN_NAME = "256-core-tpu"
71 | _, _, wandb_key = netrc().authenticators("api.wandb.ai")
72 | 
73 | def load_config(ctx: Context):
74 |     with open("config.yaml", 'r') as f:
75 |         config = f.read()
76 |     config = yaml.safe_load(config)
77 | 
78 |     wandb_api = wandb.Api()
79 |     config["training"]["do_checkpoint"] = True
80 |     base_checkpoint_path = config["training"]["checkpoint_path"]
81 | 
82 |     start_step = 0
83 |     for run in wandb_api.runs(f"{config['wandb']['entity']}/{config['wandb']['project']}"):
84 |         if run.name == config['wandb']['name']:
85 |             start_step = run.summary["_step"]
86 |             break
87 |     start_step -= start_step % config["training"]["checkpoint_interval"]
88 | 
89 |     config["training"]["start_step"] = start_step
90 |     config["wandb"]["name"] = f"{RUN_NAME}-{ctx.retry}"
91 |     if ctx.retry > 0:
92 |         config["training"]["checkpoint_load_path"] = config["training"]["checkpoint_path"]
93 |         config["training"]["checkpoint_path"] = f"{base_checkpoint_path}-{ctx.retry}"
94 |     return yaml.dump(config)
95 | 
96 | 
97 | def start_fn(ctx: Context, worker: int):
98 |     """
99 |     This function gets executed in threads to start a run on a new TPU. It receives the context object returned by
100 |     `creation_callback` as well as the worker id which corresponds to the slice id this code was executed on in a
101 |     multi-host setup. For single-host setups, such as v3-8s, the "worker" will always be set to 0.
102 |     Ideally, it'd copy necessary files to the TPU and then run those. Here, `exec_command` can be used to create an
103 |     execution command that automatically spawns a `screen` session which persists even when the SSH connection gets cut.
104 |     """
105 |     send_to_tpu(HOST, ZONE, "config.yaml", load_config(ctx), worker)
106 |     cmd = exec_command(repository="https://github.com/HomebrewNLP/HomebrewNLP-Jax", wandb_key=wandb_key)
107 |     send_to_tpu(HOST, ZONE, "setup.sh", cmd, worker)
108 |     exec_on_tpu(HOST, ZONE, "bash setup.sh", worker)
109 | 
110 | 
111 | def creation_callback(host: str, ctx: typing.Optional[Context]) -> Context:
112 |     """
113 |     The `creation_callback` is called once whenever a new TPU gets created and can be used to persist state
114 |     (such as retry counters) across multiple invocations.
115 |     """
116 |     if ctx is None:  # first invocation
117 |         return Context(0)
118 |     ctx.retry += 1
119 |     return ctx
120 | 
121 | 
122 | def main(service_account: str, tpu_version: int = 3, slices: int = 32, preemptible: bool = True):
123 |     start_single(host=HOST, tpu_version=tpu_version, zone=ZONE, preemptible=preemptible,
124 |                  service_account=service_account, slices=slices, start_fn=start_fn,
125 |                  creation_callback=creation_callback)
126 | ```
127 | 
128 | ### Sweeps
129 | 
130 | Similarly, large swarms of instances can be launched trivially using tpucare.
Here, we largely do the same setup as
131 | above, but call `start_multiple` instead of `start_single`, which takes the additional argument `tpus` specifying the
132 | number of TPUs that should be launched and babysat. Depending on capacity and quota, the actual number of TPUs you get
133 | might be lower than the number specified.
134 | 
135 | ```PYTHON
136 | def main(service_account: str, tpus: int, tpu_version: int = 3, slices: int = 32, preemptible: bool = True):
137 |     start_multiple(prefix=HOST, tpu_version=tpu_version, zone=ZONE, preemptible=preemptible,
138 |                    service_account=service_account, slices=slices, start_fn=start_fn,
139 |                    creation_callback=creation_callback, tpus=tpus)
140 | ```
141 | 
142 | However, this would simply launch the same run many times. If you instead plan to register the runs with a
143 | [WandB Sweep](https://docs.wandb.ai/guides/sweeps/configuration), modify the `start_fn` to join the wandb
144 | sweep.\
145 | By patching in the code below, tpucare will start and maintain a large swarm of TPUs, all working towards the same
146 | hyperparameter optimization problem.
147 | 
148 | ```PYTHON
149 | import wandb
150 | 
151 | with open("sweep.yaml", 'r') as f:  # sweep config passed straight to wandb
152 |     config = yaml.safe_load(f.read())
153 | sweep_id = wandb.sweep(config, entity="homebrewnlp", project="gpt")
154 | 
155 | 
156 | def start_fn(ctx: Context, worker: int):
157 |     cmd = exec_command(repository="https://github.com/HomebrewNLP/HomebrewNLP-Jax", wandb_key=wandb_key,
158 |                        run_command=f"/home/ubuntu/.local/bin/wandb agent {sweep_id}")
159 |     send_to_tpu(HOST, ZONE, "setup.sh", cmd, worker)
160 |     exec_on_tpu(HOST, ZONE, "bash setup.sh", worker)
161 | ```
162 | 
163 | The full executable code can be found
164 | in [examples/sweep.py](https://github.com/clashluke/tpucare/blob/main/examples/sweep.py).
165 | 
166 | Similarly, the `start_fn` could be adapted to start an inference server
167 | for [HomebrewNLP](https://github.com/HomebrewNLP/HomebrewNLP-Jax/)
168 | or [Craiyon](https://huggingface.co/spaces/dalle-mini/dalle-mini), or even to execute machine learning unit tests in
169 | parallel, as sketched below.
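
For instance, a minimal test-runner `start_fn` could look like the sketch below. It follows the same `exec_command` → `send_to_tpu` → `exec_on_tpu` pattern as the snippets above; the repository URL and the `pytest` invocation are placeholders rather than anything tpucare prescribes, so adapt them to your own project:

```PYTHON
def start_fn(ctx: Context, worker: int):
    # Same pattern as before: generate a setup script, copy it to the TPU and execute it.
    # Only run_command changes -- it now runs a (hypothetical) test suite instead of a
    # training run or a sweep agent.
    cmd = exec_command(repository="https://github.com/your-org/your-ml-project",  # placeholder repository
                       run_command="python3 -m pytest tests/")  # placeholder test command
    send_to_tpu(HOST, ZONE, "setup.sh", cmd, worker)
    exec_on_tpu(HOST, ZONE, "bash setup.sh", worker)
```

Because `exec_command` wraps the `run_command` in a detached `screen` session, the tests keep running if the SSH connection drops, and `start_single`/`start_multiple` will recreate the TPU and restart them after a preemption.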
170 | 
171 | ## Citation
172 | 
173 | ```BIBTEX
174 | @software{nestler_lucas_2022_6837312,
175 |   author       = {Nestler, Lucas},
176 |   title        = {TPU Care},
177 |   month        = jul,
178 |   year         = 2022,
179 |   publisher    = {Zenodo},
180 |   version      = {0.0.2},
181 |   doi          = {10.5281/zenodo.6837312},
182 |   url          = {https://doi.org/10.5281/zenodo.6837312}
183 | }
184 | ```
--------------------------------------------------------------------------------
/tpucare/__init__.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import inspect
3 | import json
4 | import logging
5 | import multiprocessing
6 | import os
7 | import signal
8 | import subprocess
9 | import tempfile
10 | import threading
11 | import time
12 | import types
13 | import typing
14 | from contextlib import nullcontext
15 | from typing import Callable, List, Optional
16 | 
17 | 
18 | def call(cmd: str) -> str:
19 |     return subprocess.check_output(cmd.split()).rstrip().decode()
20 | 
21 | 
22 | class CallReturnError(ValueError):
23 |     pass
24 | 
25 | 
26 | def retry_call(cmd: typing.List[str], retries: int = -2):  # retries=-2 never reaches -1, i.e. retry forever
27 |     while retries != -1:
28 |         retries -= 1
29 |         popen = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
30 |         out = ""
31 |         for stdout_line in iter(popen.stdout.readline, ""):
32 |             out += stdout_line
33 |             log(stdout_line.rstrip(), log_level=logging.DEBUG)
34 |         popen.stdout.close()
35 |         if not popen.wait():
36 |             return out
37 |     raise CallReturnError  # caught by retry_delete, which then deletes the broken TPU
38 | 
39 | 
40 | PROJECT = call("gcloud config get project")
41 | TPU_CMD = "gcloud alpha compute tpus tpu-vm"
42 | GLOBAL_DICT = {}
43 | CACHE_TIME = 10
44 | LOG_LEVEL = logging.INFO
45 | Context = typing.TypeVar("Context")
46 | All = typing.Literal["all"]
47 | SliceIndex = typing.Union[All, int]
48 | 
49 | 
50 | def log(*message, log_level=1e9):
51 |     if log_level >= LOG_LEVEL:
52 |         print(f'{datetime.datetime.now()} | {" ".join(map(str, message))}', flush=True)
53 | 
54 | 
55 | def exec_command(repository: str, wandb_key: typing.Optional[str] = None, branch: str = "main",
56 |                  setup_command: str = "(bash setup.sh; exit 0)", run_command: str = "bash run.sh",
57 |                  install_python: bool = True, pkilled: Optional[List[str]] = None):
58 |     path = repository.split('/')[-1]
59 |     if path.endswith('.git'):
60 |         path = path[:-len('.git')]
61 |     script = []
62 |     if install_python:
63 |         script.append("sudo apt-get -o DPkg::Lock::Timeout=-1 update")
64 |         script.append("sudo apt-get -o DPkg::Lock::Timeout=-1 --fix-missing --fix-broken install -y git python3 "
65 |                       "python3-pip")
66 |     script.append(f"(rm -rf {path} ; pkill -f python3 ; exit 0)")
67 |     script.append(f"python3 -m pip install --upgrade pip")
68 |     script.append(f"git clone --depth 1 --branch {branch} {repository}")
69 |     script.append(f"cd {path}")
70 |     if wandb_key is not None:
71 |         script.append("python3 -m pip install --upgrade wandb typer click")
72 |         script.append(f"/home/ubuntu/.local/bin/wandb login {wandb_key}")
73 |     script.append(setup_command)
74 |     start_command = f'screen -dmS model bash -c "cd {path} ; git pull ; {run_command}"'  # detached screen session survives SSH drops
75 |     script.append(start_command)
76 |     if pkilled is None:
77 |         pkilled = []
78 |     if pkilled:
79 |         pkilled = ' ; '.join(f'pkill -f {k}' for k in pkilled)
80 |         script.append(f"echo '{pkilled} ; sleep 30 ; {pkilled} ; {start_command} ; screen -r model' >> .bashrc")
81 |     else:
82 |         script.append(f"echo '{start_command} ; screen -r model' >> .bashrc")
83 |     return ' &&\\\n'.join(script)
84 | 
85 | 
86 | def retry_delete(host: str, zone: str, cmd: typing.List[str], retries: int
= -2): 87 | try: 88 | return retry_call(cmd, retries) 89 | except CallReturnError: 90 | delete_one_tpu(host, host, zone) 91 | return "" 92 | 93 | 94 | def log_entry(prefix: str, fmt: Callable): 95 | def _fn(fn: Callable): 96 | signature = inspect.signature(fn).parameters 97 | 98 | def _inner(*args, **kwargs): 99 | for a, s in zip(args, signature.keys()): 100 | kwargs[s] = a 101 | txt = f"{prefix} {fmt(kwargs)}" 102 | log(f"{txt} ...", log_level=logging.INFO) 103 | start_time = time.time() 104 | out = fn(**kwargs) 105 | log(f"finished {txt} after {time.time() - start_time:.1f}s", log_level=logging.INFO) 106 | return out 107 | 108 | return _inner 109 | 110 | return _fn 111 | 112 | 113 | @log_entry("sending", lambda x: f"'{x['filename_on_tpu']}'") 114 | def send_to_tpu(host: str, zone: str, filename_on_tpu: str, command: str, worker: SliceIndex = 0): 115 | with tempfile.NamedTemporaryFile(mode='w+') as f: 116 | f.write(command) 117 | f.flush() 118 | cmd = TPU_CMD.split(' ') + ["scp", f.name, f"ubuntu@{host}:~/{filename_on_tpu}", "--zone", zone, "--worker", 119 | str(worker)] 120 | retry_delete(host, zone, cmd, 4) 121 | 122 | 123 | @log_entry("running", lambda x: f"'{x['command']}'") 124 | def exec_on_tpu(host: str, zone: str, command: str, worker: SliceIndex = 0) -> str: 125 | cmd = TPU_CMD.split(' ') + ["ssh", f"ubuntu@{host}", f"--zone", zone, "--command", command, "--worker", str(worker)] 126 | return retry_delete(host, zone, cmd, 2) 127 | 128 | 129 | def all_tpus(zone: str): 130 | zone = 'projects/' + PROJECT + '/locations/' + zone 131 | if GLOBAL_DICT.get(f"last_write_{zone}", 0) < time.time() - CACHE_TIME: 132 | GLOBAL_DICT[f"last_write_{zone}"] = time.time() 133 | GLOBAL_DICT[f"tpus_{zone}"] = json.loads(call(f"{TPU_CMD} list --zone {zone} --format json")) 134 | return GLOBAL_DICT[f"tpus_{zone}"] 135 | 136 | 137 | def valid_tpu(tpu: dict, preempted: bool = True, deleting: bool = False, unhealthy: bool = True) -> bool: 138 | state = "state" in tpu and (deleting or tpu['state'] != "DELETING") and (preempted or tpu['state'] != "PREEMPTED") 139 | state |= deleting and preempted 140 | healthy = unhealthy or "health" not in tpu or tpu["health"] == "HEALTHY" # we assume no health info == good 141 | return state and healthy 142 | 143 | 144 | def tpu_names(zone: str, preempted: bool = True, deleting: bool = False, unhealthy: bool = False, 145 | no_filter: bool = False, prefix: str = ''): 146 | while True: 147 | try: 148 | tpus = all_tpus(zone) 149 | if no_filter: 150 | tpus = [t['name'].split('/')[-1] for t in tpus] 151 | else: 152 | tpus = [t['name'].split('/')[-1] for t in tpus if valid_tpu(t, preempted, deleting, unhealthy)] 153 | return [t for t in tpus if t.startswith(prefix)] 154 | except KeyboardInterrupt as exc: 155 | raise exc 156 | except: 157 | pass 158 | 159 | 160 | def delete_no_check(host: str, zone: str, asynchronous: bool): 161 | os.system(f"echo y | {TPU_CMD} delete {host} --zone {zone} {'--async' * asynchronous}") 162 | 163 | 164 | def tpu_ips(host: str, zone: str) -> typing.List[str]: 165 | out = call(f"{TPU_CMD} describe {host} --format json --zone {zone}") 166 | return [host["accessConfig"]["externalIp"] for host in json.loads(out)["networkEndpoints"]] 167 | 168 | 169 | def delete_one_tpu(prefix: str, host: str, zone: str, asynchronous: bool = True): 170 | if prefix not in host or host not in tpu_names(zone, no_filter=True): 171 | return 172 | log(f"\x1b[32;1m DELETING {host}\x1b[0m", log_level=logging.INFO) 173 | delete_no_check(host, zone, asynchronous) 174 | while not 
asynchronous and host in tpu_names(zone, no_filter=True): 175 | delete_no_check(host, zone, asynchronous) 176 | 177 | 178 | def delete_all(prefix: str, zone: str): 179 | while tpu_names(zone, prefix=prefix, no_filter=True): 180 | threads = [threading.Thread(target=delete_one_tpu, args=(prefix, host, zone, False), daemon=True) for host in 181 | tpu_names(zone, prefix=prefix)] 182 | for t in threads: 183 | t.start() 184 | for t in threads: 185 | t.join() 186 | 187 | 188 | def create_tpu(host: str, zone: str, tpu_version: int, preemptible: bool, service_account: str, 189 | semaphore: typing.Optional[typing.ContextManager], slices: int = 1): 190 | with semaphore: 191 | os.system(f'while ! gcloud alpha compute tpus tpu-vm create {host} --service-account {service_account} ' 192 | f'--zone {zone} --accelerator-type v{tpu_version}-{slices * 8} --version v2-alpha ' 193 | f'{"--preemptible" * preemptible}; do echo; done') 194 | 195 | 196 | def recreate(host: str, zone: str, tpu_version: int, preemptible: bool, service_account: str, slices: int, 197 | creation_semaphore: typing.Optional[typing.ContextManager] = None): 198 | delete_one_tpu("", host, zone, False) 199 | create_tpu(host, zone, tpu_version, preemptible, service_account, creation_semaphore, slices) 200 | 201 | 202 | def get_name(fn: typing.Callable, base: str): 203 | if hasattr(fn, '__name__'): 204 | return f"{base} ({fn.__name__})" 205 | if not isinstance(fn, types.FunctionType): 206 | return get_name(type(fn), base) 207 | return base 208 | 209 | 210 | def start_single(host: str, tpu_version: int, zone: str, preemptible: bool, service_account: str, slices: int, 211 | start_fn: typing.Callable[[Context, SliceIndex], None], 212 | creation_callback: typing.Callable[[str, typing.Optional[Context]], Context], 213 | creation_semaphore: typing.Optional[typing.ContextManager] = None, all_workers: bool = False, 214 | unhealthy_timeout_seconds=3600): 215 | if creation_semaphore is None: 216 | creation_semaphore = nullcontext() 217 | 218 | ctx = None 219 | 220 | creation_callback_name = get_name(creation_callback, "creation_callback") 221 | start_fn_name = get_name(start_fn, "start_fn") 222 | unhealthy_timeout_seconds /= CACHE_TIME 223 | 224 | while True: 225 | try: 226 | log("Recreating TPU", log_level=logging.INFO) 227 | recreate(host, zone, tpu_version, preemptible, service_account, slices, creation_semaphore) 228 | log(f"TPU Created. Calling {creation_callback_name}.", log_level=logging.INFO) 229 | ctx = creation_callback(host, ctx) 230 | log(f"Callback returned. Launching {start_fn_name}", log_level=logging.INFO) 231 | if all_workers: 232 | threads = [multiprocessing.Process(target=start_fn, args=(ctx, "all"), daemon=True)] 233 | else: 234 | threads = [multiprocessing.Process(target=start_fn, args=(ctx, i), daemon=True) for i in range(slices)] 235 | for t in threads: 236 | t.start() 237 | log("Started start_fn. Babysitting TPU..", log_level=logging.INFO) 238 | retries = unhealthy_timeout_seconds # sometimes "unhealthy" resolves itself. Let's wait 239 | while host in tpu_names(zone, preempted=False, unhealthy=True): 240 | if retries <= 0: 241 | break 242 | time.sleep(CACHE_TIME) 243 | if host in tpu_names(zone, preempted=False, unhealthy=False): 244 | retries = unhealthy_timeout_seconds 245 | else: 246 | retries -= 1 247 | log(f"TPU is {'unhealthy' if retries <= 0 else 'preempted'}. 
Recreating it now.", log_level=logging.INFO)
248 |             while any(t.is_alive() for t in threads):
249 |                 for t in threads:
250 |                     if t.is_alive():
251 |                         os.kill(t.pid, signal.SIGINT)
252 |                 time.sleep(0.1)
253 |             log("Sent SIGINT to all workers", log_level=logging.INFO)
254 |         except KeyboardInterrupt:
255 |             log(f"{host} - {datetime.datetime.now()}: KeyboardInterrupt received. Killing TPU, then self.",
256 |                 log_level=logging.WARN)
257 |             delete_one_tpu("", host, zone, False)
258 |             return
259 | 
260 | 
261 | def start_multiple(prefix: str, tpu_version: int, zone: str, preemptible: bool, service_account: str, slices: int,
262 |                    start_fn: typing.Callable[[Context, SliceIndex], None],
263 |                    creation_callback: typing.Callable[[str, typing.Optional[Context]], Context], tpus: int):
264 |     procs = []
265 |     creation_semaphore = threading.Semaphore(2)
266 |     for tpu_id in range(tpus):
267 |         proc = threading.Thread(target=start_single, daemon=True, args=(
268 |             f'{prefix}-{tpu_id}', tpu_version, zone, preemptible, service_account, slices, start_fn, creation_callback,
269 |             creation_semaphore))
270 |         proc.start()
271 |         procs.append(proc)
272 |     while all(t.is_alive() for t in procs):
273 |         try:
274 |             time.sleep(10)
275 |         except KeyboardInterrupt:
276 |             log(f"MAIN - {datetime.datetime.now()}: KeyboardInterrupt received. Killing all TPUs, then self.",
277 |                 log_level=logging.WARN)
278 |             delete_all(prefix, zone)
279 |             return
280 | 
--------------------------------------------------------------------------------