├── tests ├── __init__.py ├── node_test.py ├── secret_test.py ├── account_test.py ├── dataset_test.py ├── organization_test.py ├── image_test.py ├── config_test.py ├── client_test.py ├── cluster_test.py ├── group_test.py ├── util_test.py ├── job_test.py ├── workspace_test.py ├── experiment_test.py └── data_model_test.py ├── docs ├── .gitignore ├── source │ ├── _static │ │ ├── css │ │ │ └── custom.css │ │ ├── favicon.ico │ │ └── beaker-500px-transparent.png │ ├── CHANGELOG.md │ ├── CONTRIBUTING.md │ ├── overview.rst │ ├── api │ │ ├── config.rst │ │ ├── exceptions.rst │ │ ├── experiment_spec.rst │ │ ├── client.rst │ │ └── data_models.rst │ ├── quickstart.md │ ├── installation.md │ ├── examples.md │ ├── faq.md │ ├── index.rst │ └── conf.py ├── Makefile └── make.bat ├── beaker ├── py.typed ├── aliases.py ├── version.py ├── data_model │ ├── secret.py │ ├── account.py │ ├── organization.py │ ├── node.py │ ├── _base_v2.py │ ├── group.py │ ├── _base_v1.py │ ├── __init__.py │ ├── experiment.py │ ├── cluster.py │ ├── image.py │ ├── workspace.py │ ├── base.py │ ├── dataset.py │ └── job.py ├── services │ ├── __init__.py │ ├── node.py │ ├── account.py │ ├── secret.py │ ├── organization.py │ └── group.py ├── conftest.py ├── config.py ├── exceptions.py └── util.py ├── examples └── sweep │ ├── .dockerignore │ ├── Dockerfile │ ├── README.md │ ├── entrypoint.py │ └── run.py ├── Dockerfile ├── .github ├── dependabot.yml ├── ISSUE_TEMPLATE │ ├── documentation.yml │ ├── feature_request.yml │ └── bug_report.yml ├── workflows │ ├── pr_checks.yml │ └── main.yml └── actions │ └── setup-venv │ └── action.yml ├── .readthedocs.yaml ├── test_fixtures └── docker │ └── Dockerfile ├── .dockerignore ├── scripts ├── release.sh ├── add_pr_comments_on_release.sh ├── prepare_changelog.py └── release_notes.py ├── Makefile ├── .gitignore ├── RELEASE_PROCESS.md ├── integration_tests ├── jobs_test.py ├── images_test.py ├── experiments_test.py └── datasets_test.py ├── pyproject.toml ├── README.md ├── 
conftest.py └── CONTRIBUTING.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ../../CHANGELOG.md -------------------------------------------------------------------------------- /docs/source/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ../../CONTRIBUTING.md -------------------------------------------------------------------------------- /beaker/py.typed: -------------------------------------------------------------------------------- 1 | # Marker file for type checking 2 | -------------------------------------------------------------------------------- /examples/sweep/.dockerignore: -------------------------------------------------------------------------------- 1 | * 2 | !Dockerfile 3 | !entrypoint.py 4 | -------------------------------------------------------------------------------- /docs/source/overview.rst: -------------------------------------------------------------------------------- 1 | Overview 2 | ======== 3 | 4 | .. automodule:: beaker 5 | -------------------------------------------------------------------------------- /docs/source/api/config.rst: -------------------------------------------------------------------------------- 1 | Config 2 | ------ 3 | 4 | .. 
# Version components for beaker-py. Bump these for a release; the release
# tooling reads VERSION from here to create the git tag.
_MAJOR = "1"
_MINOR = "38"
_PATCH = "1"
_SUFFIX = ""

# "MAJOR.MINOR", e.g. "1.38".
VERSION_SHORT = f"{_MAJOR}.{_MINOR}"
# Full version string, e.g. "1.38.1" (plus an optional suffix such as "rc1").
VERSION = f"{_MAJOR}.{_MINOR}.{_PATCH}{_SUFFIX}"
def test_node_get(client: Beaker, beaker_node_id: str):
    """A node fetched by ID should report a positive GPU count."""
    node = client.node.get(beaker_node_id)
    count = node.limits.gpu_count
    assert count is not None
    assert count > 0
class Secret(BaseModel):
    """A Beaker workspace secret (metadata only — the value is read/written separately)."""

    # Unique name of the secret within its workspace.
    name: str
    # Creation timestamp.
    created: datetime
    # Last-updated timestamp.
    updated: datetime
    # ID of the account that wrote the secret; presumably absent on older
    # records — TODO confirm against the Beaker API.
    author_id: Optional[str] = None
class Account(BaseModel):
    """A Beaker user account."""

    # Stable account ID.
    id: str
    # Account username.
    name: str
    # Human-readable display name.
    display_name: str
    # Optional profile fields; None when unset.
    institution: Optional[str] = None
    pronouns: Optional[str] = None
    email: Optional[str] = None
    created: Optional[datetime] = None
def test_organization_get(client: Beaker, beaker_org_name: str):
    """An organization should be retrievable by name and then by its ID."""
    by_name = client.organization.get(beaker_org_name)
    assert by_name.name == beaker_org_name
    # Now get by ID.
    client.organization.get(by_name.id)


def test_organization_list_members(client: Beaker, beaker_org_name: str):
    """Listing members of a known organization should not raise."""
    client.organization.list_members(beaker_org_name)


def test_organization_get_member(client: Beaker):
    """The current account should be resolvable as an org member."""
    client.organization.get_member(client.account.name)
def test_image_get(client: Beaker, hello_world_image_name: str):
    """An image should be retrievable by full name, by ID, and by bare name."""
    fetched = client.image.get(hello_world_image_name)
    client.image.get(fetched.id)
    assert fetched.name is not None
    client.image.get(fetched.name)


def test_image_url(client: Beaker, hello_world_image_name: str):
    """The image URL should point at the canonical beaker.org page."""
    expected = "https://beaker.org/im/01FPB7XCX3GHKW5PS9J4623EBN"
    assert client.image.url(hello_world_image_name) == expected
def test_str_method(client: Beaker):
    """The config's string form must mask the user token."""
    rendered = str(client.config)
    assert "user_token=***" in rendered
    assert client.config.user_token not in rendered


def test_config_from_path_unknown_field(tmp_path):
    """Unknown keys in a config file should trigger a RuntimeWarning."""
    config_file = tmp_path / "config.yml"
    with open(config_file, "w") as f:
        yaml.dump({"user_token": "foo-bar", "baz": 1}, f)

    with pytest.warns(RuntimeWarning, match="Unknown field 'baz' found in config"):
        Config.from_path(config_file)
class Organization(BaseModel):
    """A Beaker organization."""

    id: str
    name: str
    description: str
    created: datetime
    display_name: str
    pronouns: Optional[str] = None


class OrganizationRole(StrEnum):
    """Role an account can hold within an organization."""

    admin = "admin"
    member = "member"


class OrganizationMember(BaseModel):
    """Membership record tying an account to an organization with a role."""

    role: OrganizationRole
    organization: Organization
    user: Account
import beaker_pb2 as beaker__pb2/' beaker/beaker_pb2_grpc.py 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # build artifacts 2 | 3 | .eggs/ 4 | .mypy_cache 5 | *.egg-info/ 6 | build/ 7 | dist/ 8 | pip-wheel-metadata/ 9 | 10 | 11 | # dev tools 12 | 13 | .envrc 14 | .python-version 15 | .idea 16 | .venv/ 17 | .vscode/ 18 | /*.iml 19 | 20 | 21 | # jupyter notebooks 22 | 23 | .ipynb_checkpoints 24 | 25 | 26 | # miscellaneous 27 | 28 | .cache/ 29 | doc/_build/ 30 | *.swp 31 | .DS_Store 32 | 33 | 34 | # python 35 | 36 | *.pyc 37 | *.pyo 38 | __pycache__ 39 | 40 | 41 | # testing and continuous integration 42 | 43 | .coverage 44 | .pytest_cache/ 45 | .benchmarks 46 | 47 | # documentation build artifacts 48 | 49 | docs/build 50 | site/ 51 | 52 | # local integration testing 53 | 54 | beaker.yml 55 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.yml: -------------------------------------------------------------------------------- 1 | name: 📚 Documentation 2 | description: Report an issue related to https://beaker-py.readthedocs.io/latest 3 | labels: 'documentation' 4 | 5 | body: 6 | - type: textarea 7 | attributes: 8 | label: 📚 The doc issue 9 | description: > 10 | A clear and concise description of what content in https://beaker-py.readthedocs.io/latest is an issue. 11 | validations: 12 | required: true 13 | - type: textarea 14 | attributes: 15 | label: Suggest a potential alternative/fix 16 | description: > 17 | Tell us how we could improve the documentation in this regard. 18 | - type: markdown 19 | attributes: 20 | value: > 21 | Thanks for contributing 🎉! 
22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= -W 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /tests/client_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from flaky import flaky 3 | 4 | from beaker import Beaker 5 | from beaker.config import InternalConfig 6 | 7 | 8 | @flaky # this can fail if the request to GitHub fails 9 | def test_warn_for_newer_version(monkeypatch): 10 | import beaker.client 11 | import beaker.version 12 | 13 | InternalConfig().save() 14 | 15 | monkeypatch.setattr(Beaker, "CLIENT_VERSION", "1.0.0") 16 | monkeypatch.setattr(beaker.client, "_LATEST_VERSION_CHECKED", False) 17 | 18 | with pytest.warns(UserWarning, match="Please upgrade with"): 19 | Beaker.from_env() 20 | 21 | # Shouldn't warn a second time. 
"""
Entrypoint script run on Beaker as the Docker image's command.

It writes a small JSON file containing a seeded pseudo-random number to the
experiment's result directory, simulating the output of a training/evaluation
pipeline.
"""

import json
import random
import sys

# NOTE: it's important that this file is called 'metrics.json'. That tells Beaker
# to collect metrics for the task from this file.
OUTPUT_PATH = "/output/metrics.json"


def main(x: int):
    """Seed the RNG with ``x`` and dump a single "metric" to OUTPUT_PATH."""
    random.seed(x)
    payload = {"result": random.random()}
    with open(OUTPUT_PATH, "w") as out_file:
        json.dump(payload, out_file)


if __name__ == "__main__":
    main(int(sys.argv[1]))
After you've pushed a fix, delete the tag from your local clone with 19 | 20 | ```bash 21 | git tag -l | xargs git tag -d && git fetch -t 22 | ``` 23 | 24 | Then repeat the steps above. 25 | -------------------------------------------------------------------------------- /beaker/services/__init__.py: -------------------------------------------------------------------------------- 1 | from .account import AccountClient 2 | from .cluster import ClusterClient 3 | from .dataset import DatasetClient 4 | from .experiment import ExperimentClient 5 | from .group import GroupClient 6 | from .image import ImageClient 7 | from .job import JobClient 8 | from .node import NodeClient 9 | from .organization import OrganizationClient 10 | from .secret import SecretClient 11 | from .service_client import ServiceClient 12 | from .workspace import WorkspaceClient 13 | 14 | __all__ = [ 15 | "AccountClient", 16 | "ClusterClient", 17 | "DatasetClient", 18 | "ExperimentClient", 19 | "GroupClient", 20 | "ImageClient", 21 | "JobClient", 22 | "NodeClient", 23 | "OrganizationClient", 24 | "SecretClient", 25 | "ServiceClient", 26 | "WorkspaceClient", 27 | ] 28 | -------------------------------------------------------------------------------- /.github/workflows/pr_checks.yml: -------------------------------------------------------------------------------- 1 | name: PR Checks 2 | 3 | concurrency: 4 | group: ${{ github.workflow }}-${{ github.ref }} 5 | cancel-in-progress: true 6 | 7 | on: 8 | pull_request: 9 | branches: 10 | - main 11 | paths: 12 | - 'beaker/**' 13 | 14 | jobs: 15 | changelog: 16 | name: CHANGELOG 17 | runs-on: ubuntu-latest 18 | if: github.event_name == 'pull_request' 19 | 20 | steps: 21 | - uses: actions/checkout@v1 # needs v1 for now 22 | 23 | - name: Check that CHANGELOG has been updated 24 | run: | 25 | # If this step fails, this means you haven't updated the CHANGELOG.md 26 | # file with notes on your contribution. 
class NodeClient(ServiceClient):
    """
    Accessed via :data:`Beaker.node <beaker.Beaker.node>`.
    """

    def get(self, node_id: str) -> Node:
        """
        Get information about a node.

        :param node_id: The ID of the node.

        :raises NodeNotFound: If the node doesn't exist.
        :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur.
        :raises RequestException: Any other exception that can occur when contacting the
            Beaker server.
        """
        response = self.request(
            f"nodes/{node_id}",
            exceptions_for_status={404: NodeNotFound(node_id)},
        )
        return Node.from_json(response.json())
def test_job_stop_and_finalize(client: Beaker, experiment_name: str):
    """End-to-end check that a running job can be stopped and then finalized."""
    start = time.time()
    # Minimal spec: a single task running the stock "hello-world" image.
    spec = ExperimentSpec(budget="ai2/allennlp").with_task(
        TaskSpec.new(
            "main",
            docker_image="hello-world",
        ),
    )
    print(f"Creating experiment {experiment_name}")
    experiment = client.experiment.create(experiment_name, spec)
    print("Waiting for jobs to register", end="")
    # Poll until the scheduler attaches at least one job, with a 5 minute cap.
    while not experiment.jobs:
        if time.time() - start > (60 * 5):
            raise TimeoutError
        time.sleep(2)
        print(".", end="")
        experiment = client.experiment.get(experiment.id)
    print("\nStopping job")
    client.job.stop(experiment.jobs[0])
    print("Finalizing job")
    job = client.job.finalize(experiment.jobs[0])
    assert job.is_finalized
| fi 17 | 18 | commits_since_last_release=$(git log "${last_tag}..${current_tag}" --format=format:%H) 19 | 20 | echo "Commits/PRs since last release:" 21 | for commit in $commits_since_last_release; do 22 | pr_number=$(gh pr list --search "$commit" --state merged --json number --jq '.[].number') 23 | if [ -z "$pr_number" ]; then 24 | echo "$commit" 25 | else 26 | echo "$commit (PR #$pr_number)" 27 | gh pr comment "$pr_number" --body "This PR has been released in [${current_tag}](${repo_url}/releases/tag/${current_tag})." 28 | fi 29 | done 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: 🚀 Feature request 2 | description: Submit a proposal/request for a new feature 3 | labels: 'feature request' 4 | 5 | body: 6 | - type: textarea 7 | attributes: 8 | label: 🚀 The feature, motivation and pitch 9 | description: > 10 | A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. 11 | validations: 12 | required: true 13 | - type: textarea 14 | attributes: 15 | label: Alternatives 16 | description: > 17 | A description of any alternative solutions or features you've considered, if any. 18 | - type: textarea 19 | attributes: 20 | label: Additional context 21 | description: > 22 | Add any other context or screenshots about the feature request. 23 | - type: markdown 24 | attributes: 25 | value: > 26 | Thanks for contributing 🎉! 
27 | -------------------------------------------------------------------------------- /docs/source/examples.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | ```{include} ../../examples/sweep/README.md 4 | :start-after: 5 | :end-before: 6 | ``` 7 | 8 | ```{note} 9 | You can find the source code for this example [on GitHub](https://github.com/allenai/beaker-py/tree/main/examples/sweep). 10 | ``` 11 | 12 | ### Setup 13 | 14 | Add the following files to a directory of your choosing: 15 | 16 | ````{tab} Dockerfile 17 | ```{literalinclude} ../../examples/sweep/Dockerfile 18 | :language: Dockerfile 19 | ``` 20 | ```` 21 | 22 | ````{tab} entrypoint.py 23 | ```{literalinclude} ../../examples/sweep/entrypoint.py 24 | :language: py 25 | ``` 26 | ```` 27 | 28 | ````{tab} run.py 29 | ```{literalinclude} ../../examples/sweep/run.py 30 | :language: py 31 | ``` 32 | ```` 33 | 34 | ### Running it 35 | 36 | ```{include} ../../examples/sweep/README.md 37 | :start-after: 38 | :end-before: 39 | ``` 40 | 41 | 42 | -------------------------------------------------------------------------------- /beaker/data_model/node.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Optional, Tuple 3 | 4 | from .base import BaseModel 5 | 6 | __all__ = ["NodeResources", "Node", "NodeUtilization"] 7 | 8 | 9 | class NodeResources(BaseModel): 10 | cpu_count: Optional[float] = None 11 | memory: Optional[str] = None 12 | gpu_count: Optional[int] = None 13 | gpu_type: Optional[str] = None 14 | gpu_ids: Optional[Tuple[str, ...]] = None 15 | 16 | 17 | class Node(BaseModel): 18 | id: str 19 | hostname: str 20 | created: datetime 21 | limits: NodeResources 22 | expiry: Optional[datetime] = None 23 | cordoned: Optional[datetime] = None 24 | cordon_reason: Optional[str] = None 25 | cordon_agent_id: Optional[str] = None 26 | cluster_id: Optional[str] = None 27 | 
account_id: Optional[str] = None 28 | 29 | 30 | class NodeUtilization(BaseModel): 31 | id: str 32 | hostname: str 33 | limits: NodeResources 34 | running_jobs: int 35 | running_preemptible_jobs: int 36 | used: NodeResources 37 | free: NodeResources 38 | cordoned: bool = False 39 | -------------------------------------------------------------------------------- /tests/cluster_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from beaker import Beaker, Organization 4 | 5 | 6 | def test_cluster_get_on_prem(client: Beaker, beaker_on_prem_cluster_name: str): 7 | cluster = client.cluster.get(beaker_on_prem_cluster_name) 8 | assert cluster.autoscale is False 9 | assert cluster.is_cloud is False 10 | assert cluster.is_active is True 11 | assert cluster.node_spec is None 12 | assert cluster.node_shape is None 13 | 14 | 15 | @pytest.mark.skip(reason="Takes too long") 16 | def test_cluster_utilization(client: Beaker, beaker_on_prem_cluster_name: str): 17 | client.cluster.utilization(beaker_on_prem_cluster_name) 18 | 19 | 20 | def test_cluster_list(client: Beaker, beaker_org: Organization): 21 | client.cluster.list(beaker_org) 22 | 23 | 24 | def test_cluster_nodes(client: Beaker, beaker_on_prem_cluster_name: str): 25 | client.cluster.nodes(beaker_on_prem_cluster_name) 26 | 27 | 28 | def test_cluster_url(client: Beaker): 29 | assert ( 30 | client.cluster.url("ai2/jupiter-cirrascale-2") 31 | == "https://beaker.org/cl/ai2/jupiter/details" 32 | ) 33 | -------------------------------------------------------------------------------- /docs/source/api/experiment_spec.rst: -------------------------------------------------------------------------------- 1 | Experiment Spec 2 | --------------- 3 | 4 | .. autoclass:: beaker.ExperimentSpec 5 | :members: 6 | :undoc-members: 7 | 8 | .. autoclass:: beaker.RetrySpec 9 | :members: 10 | :undoc-members: 11 | 12 | .. 
autoclass:: beaker.TaskSpec 13 | :members: 14 | :undoc-members: 15 | 16 | .. autoclass:: beaker.ImageSource 17 | :members: 18 | :undoc-members: 19 | 20 | .. autoclass:: beaker.EnvVar 21 | :members: 22 | :undoc-members: 23 | 24 | .. autoclass:: beaker.DataMount 25 | :members: 26 | :undoc-members: 27 | 28 | .. autoclass:: beaker.DataSource 29 | :members: 30 | :undoc-members: 31 | 32 | .. autoclass:: beaker.TaskResources 33 | :members: 34 | :undoc-members: 35 | 36 | .. autoclass:: beaker.TaskContext 37 | :members: 38 | :undoc-members: 39 | 40 | .. autoclass:: beaker.Constraints 41 | :members: 42 | :undoc-members: 43 | 44 | .. autoclass:: beaker.ResultSpec 45 | :members: 46 | :undoc-members: 47 | 48 | .. autoclass:: beaker.Priority 49 | :members: 50 | :undoc-members: 51 | 52 | .. autoclass:: beaker.SpecVersion 53 | :members: 54 | :undoc-members: 55 | -------------------------------------------------------------------------------- /scripts/prepare_changelog.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from pathlib import Path 3 | 4 | from beaker.version import VERSION 5 | 6 | 7 | def main(): 8 | changelog = Path("CHANGELOG.md") 9 | 10 | with changelog.open() as f: 11 | lines = f.readlines() 12 | 13 | insert_index: int = -1 14 | for i in range(len(lines)): 15 | line = lines[i] 16 | if line.startswith("## Unreleased"): 17 | insert_index = i + 1 18 | elif line.startswith(f"## [v{VERSION}]"): 19 | print("CHANGELOG already up-to-date") 20 | return 21 | elif line.startswith("## [v"): 22 | break 23 | 24 | if insert_index < 0: 25 | raise RuntimeError("Couldn't find 'Unreleased' section") 26 | 27 | lines.insert(insert_index, "\n") 28 | lines.insert( 29 | insert_index + 1, 30 | f"## [v{VERSION}](https://github.com/allenai/beaker-py/releases/tag/v{VERSION}) - " 31 | f"{datetime.now().strftime('%Y-%m-%d')}\n", 32 | ) 33 | 34 | with changelog.open("w") as f: 35 | f.writelines(lines) 36 | 37 | 38 | if 
__name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /docs/source/faq.md: -------------------------------------------------------------------------------- 1 | # FAQ 2 | 3 | ```{rubric} Do I need to install the Beaker [command-line client](https://github.com/allenai/beaker) for beaker-py to work? 4 | ``` 5 | 6 | > No, **beaker-py** is written in pure Python. It communicates with the Beaker server directly through HTTP requests, so you don't need to have the command-line client installed. 7 | 8 | ```{rubric} Do I need Docker? 9 | ``` 10 | 11 | > Not necessarily. **beaker-py** will work fine without Docker, unless you want to do something that requires Docker, like uploading an image to Beaker ({meth}`Beaker.image.create `). 12 | 13 | ```{rubric} Is there way to suppress the progress bars that certain methods print out? 14 | ``` 15 | 16 | > Yes, just pass the `quiet=True` parameter to those methods. 17 | 18 | ```{rubric} I keep getting warnings that I should upgrade beaker-py, but I don't want to. Can I turn those warnings off? 19 | ``` 20 | 21 | > Yes, just pass `check_for_upgrades=False` to {class}`~beaker.Beaker()` or {meth}`Beaker.from_env() `. 22 | 23 | ```{rubric} What's the different between a task and a job? 24 | ``` 25 | 26 | > In Beaker, tasks are the fundamental unit of work. A {class}`~beaker.data_model.job.Job` is just an execution of a task. So a {class}`~beaker.data_model.experiment.Task` can have any number of {data}`~beaker.data_model.experiment.Task.jobs` associated with it, but a job is always associated with at most a single task (only "session" type jobs won't be associated with a task). 
class BaseModelV2(_BaseModel):
    """
    The Pydantic v2 base class for all Beaker data models.
    """

    # Models are immutable (frozen), validate on assignment, store enum
    # members as their values, and silently ignore unknown fields (after
    # warning about them in the validator below).
    model_config = ConfigDict(
        validate_assignment=True, use_enum_values=True, frozen=True, extra="ignore"
    )

    # Field names that subclasses want excluded from the unknown-field
    # warning below, even though they are not declared model fields.
    IGNORE_FIELDS: ClassVar[Set[str]] = set()

    @model_validator(mode="before")
    def _validate_and_rename_to_snake_case(  # type: ignore
        cls: Type["BaseModelV2"], values: Dict[str, Any]  # type: ignore
    ) -> Dict[str, Any]:
        """
        Raw data from the Beaker server will use lower camel case.

        This pre-validator renames all incoming keys to snake case so they
        match the declared field names, and warns about any key that is
        neither a declared field nor listed in ``IGNORE_FIELDS``.
        """
        # In some cases we get an instance instead of a dict.
        # We'll just punt there and hope for the best.
        if not isinstance(values, dict):
            return values

        as_snake_case = {to_snake_case(k): v for k, v in values.items()}
        for key, value in as_snake_case.items():
            if (
                cls.model_config["extra"] != "allow"  # type: ignore
                and key not in cls.model_fields
                and key not in cls.IGNORE_FIELDS
            ):
                # Surface unexpected server fields instead of dropping them
                # silently, so schema drift is noticed.
                issue_data_model_warning(cls, key, value)
        return as_snake_case
autoclass:: beaker.services.OrganizationClient 19 | :members: 20 | :member-order: bysource 21 | 22 | Workspace 23 | ~~~~~~~~~ 24 | 25 | .. autoclass:: beaker.services.WorkspaceClient 26 | :members: 27 | :member-order: bysource 28 | 29 | Cluster 30 | ~~~~~~~ 31 | 32 | .. autoclass:: beaker.services.ClusterClient 33 | :members: 34 | :member-order: bysource 35 | 36 | Node 37 | ~~~~ 38 | 39 | .. autoclass:: beaker.services.NodeClient 40 | :members: 41 | :member-order: bysource 42 | 43 | Dataset 44 | ~~~~~~~ 45 | 46 | .. autoclass:: beaker.services.DatasetClient 47 | :members: 48 | :member-order: bysource 49 | 50 | Image 51 | ~~~~~ 52 | 53 | .. autoclass:: beaker.services.ImageClient 54 | :members: 55 | :member-order: bysource 56 | 57 | Job 58 | ~~~ 59 | 60 | .. autoclass:: beaker.services.JobClient 61 | :members: 62 | :member-order: bysource 63 | 64 | Experiment 65 | ~~~~~~~~~~ 66 | 67 | .. autoclass:: beaker.services.ExperimentClient 68 | :members: 69 | :member-order: bysource 70 | 71 | Secret 72 | ~~~~~~ 73 | 74 | .. autoclass:: beaker.services.SecretClient 75 | :members: 76 | :member-order: bysource 77 | 78 | Group 79 | ~~~~~ 80 | 81 | .. autoclass:: beaker.services.GroupClient 82 | :members: 83 | :member-order: bysource 84 | -------------------------------------------------------------------------------- /tests/group_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from beaker import Beaker, GroupConflict, GroupNotFound 4 | 5 | 6 | def test_group_methods( 7 | client: Beaker, group_name: str, alternate_group_name: str, hello_world_experiment_id: str 8 | ): 9 | # Create a new group. 10 | group = client.group.create(group_name) 11 | assert group.name == group_name 12 | 13 | # Add an experiment to the group. 
14 | client.group.add_experiments(group, hello_world_experiment_id) 15 | assert len(client.group.list_experiments(group)) == 1 16 | 17 | # Export the experiments from the group 18 | # (expect a three line CSV: the header, one experiment, and a trailing newline) 19 | export = list(client.group.export_experiments(group)) 20 | assert len(export) == 1 21 | assert len(export[0].decode().split("\n")) == 3 22 | 23 | # Remove the experiment from the group. 24 | client.group.remove_experiments(group, hello_world_experiment_id) 25 | assert len(client.group.list_experiments(group)) == 0 26 | 27 | # Rename the group. 28 | group = client.group.rename(group, alternate_group_name) 29 | assert group.name == alternate_group_name 30 | 31 | # Test group not found error. 32 | with pytest.raises(GroupNotFound): 33 | client.group.get(group_name) 34 | 35 | # Test group conflict error. 36 | with pytest.raises(GroupConflict): 37 | client.group.create(alternate_group_name) 38 | 39 | # List groups in the workspace. 
40 | group_names = [group.name for group in client.workspace.groups()] 41 | assert alternate_group_name in group_names 42 | -------------------------------------------------------------------------------- /beaker/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.fixture(autouse=True) 5 | def doctest_fixtures( 6 | doctest_namespace, 7 | client, 8 | workspace_name, 9 | docker_image_name, 10 | beaker_image_name, 11 | beaker_cluster_name, 12 | beaker_on_prem_cluster_name, 13 | experiment_name, 14 | dataset_name, 15 | download_path, 16 | beaker_org_name, 17 | beaker_node_id, 18 | secret_name, 19 | group_name, 20 | hello_world_experiment_name, 21 | squad_dataset_name, 22 | squad_dataset_file_name, 23 | tmp_path, 24 | ): 25 | doctest_namespace["beaker"] = client 26 | doctest_namespace["workspace_name"] = workspace_name 27 | doctest_namespace["docker_image_name"] = docker_image_name 28 | doctest_namespace["beaker_image_name"] = beaker_image_name 29 | doctest_namespace["beaker_cluster_name"] = beaker_cluster_name 30 | doctest_namespace["beaker_on_prem_cluster_name"] = beaker_on_prem_cluster_name 31 | doctest_namespace["experiment_name"] = experiment_name 32 | doctest_namespace["dataset_name"] = dataset_name 33 | doctest_namespace["download_path"] = download_path 34 | doctest_namespace["beaker_org_name"] = beaker_org_name 35 | doctest_namespace["beaker_node_id"] = beaker_node_id 36 | doctest_namespace["secret_name"] = secret_name 37 | doctest_namespace["group_name"] = group_name 38 | doctest_namespace["hello_world_experiment_name"] = hello_world_experiment_name 39 | doctest_namespace["squad_dataset_name"] = squad_dataset_name 40 | doctest_namespace["squad_dataset_file_name"] = squad_dataset_file_name 41 | doctest_namespace["tmp_path"] = tmp_path 42 | -------------------------------------------------------------------------------- /integration_tests/images_test.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Tests creating, pushing, and pulling images to/from Beaker. 3 | 4 | This requires building a test image called "beaker-py-test" using the Dockerfile 5 | at "test_fixtures/docker/Dockerfile". 6 | """ 7 | 8 | from beaker import Beaker 9 | 10 | LOCAL_IMAGE_TAG = "beaker-py-test" 11 | 12 | 13 | def test_image_create_workflow( 14 | client: Beaker, beaker_image_name: str, alternate_beaker_image_name: str 15 | ): 16 | # Create and push the image. 17 | print(f"Creating image '{beaker_image_name}'") 18 | image = client.image.create(beaker_image_name, LOCAL_IMAGE_TAG) 19 | assert image.name == beaker_image_name 20 | assert image.original_tag == LOCAL_IMAGE_TAG 21 | 22 | # Rename the image. 23 | print(f"Renaming image to '{alternate_beaker_image_name}'") 24 | image = client.image.rename(image, alternate_beaker_image_name) 25 | assert image.name == alternate_beaker_image_name 26 | 27 | # Test with budget parameter 28 | print(f"Creating image with budget '{beaker_image_name}-budget'") 29 | image_with_budget = client.image.create( 30 | beaker_image_name + "-budget", 31 | LOCAL_IMAGE_TAG, 32 | budget="ai2/compute", 33 | description="Test image with budget", 34 | ) 35 | assert image_with_budget.name == beaker_image_name + "-budget" 36 | assert image_with_budget.original_tag == LOCAL_IMAGE_TAG 37 | client.image.delete(image_with_budget) 38 | 39 | 40 | def test_image_pull_workflow(client: Beaker, beaker_python_image_name: str): 41 | print(f"Pulling image '{beaker_python_image_name}' from Beaker") 42 | local_image = client.image.pull(beaker_python_image_name) 43 | print(f"Pull complete: {local_image}") 44 | -------------------------------------------------------------------------------- /integration_tests/experiments_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from beaker import ( 4 | Beaker, 5 | ExperimentConflict, 6 | 
ExperimentNotFound, 7 | ExperimentSpec, 8 | ImageSource, 9 | ResultSpec, 10 | TaskContext, 11 | TaskSpec, 12 | ) 13 | 14 | 15 | def test_experiment_workflow( 16 | client: Beaker, 17 | experiment_name: str, 18 | alternate_experiment_name: str, 19 | hello_world_experiment_name: str, 20 | ): 21 | spec = ExperimentSpec( 22 | budget="ai2/allennlp", 23 | tasks=[ 24 | TaskSpec( 25 | name="main", 26 | image=ImageSource(docker="hello-world"), 27 | context=TaskContext(preemptible=True), 28 | result=ResultSpec(path="/unused"), # required even if the task produces no output. 29 | ), 30 | ], 31 | ) 32 | # Create the experiment. 33 | experiment = client.experiment.create(experiment_name, spec) 34 | 35 | # Wait for it to finish. 36 | experiment = client.experiment.wait_for(experiment, timeout=60 * 5)[0] 37 | 38 | # Get the logs. 39 | logs = "".join([line.decode() for line in client.experiment.logs(experiment)]) 40 | assert logs 41 | 42 | # Test experiment conflict error with rename. 43 | with pytest.raises(ExperimentConflict): 44 | client.experiment.rename(experiment, hello_world_experiment_name) 45 | 46 | # Rename the experiment. 47 | experiment = client.experiment.rename(experiment, alternate_experiment_name) 48 | assert experiment.name == alternate_experiment_name 49 | 50 | # Test experiment not found error. 
51 | with pytest.raises(ExperimentNotFound): 52 | client.experiment.get(experiment_name) 53 | -------------------------------------------------------------------------------- /beaker/data_model/group.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List, Optional, Tuple 3 | 4 | from .account import Account 5 | from .base import BaseModel, BasePage, StrEnum 6 | from .workspace import WorkspaceRef 7 | 8 | __all__ = [ 9 | "Group", 10 | "GroupSpec", 11 | "GroupParameterType", 12 | "GroupParameter", 13 | "GroupPatch", 14 | "GroupsPage", 15 | "GroupSort", 16 | ] 17 | 18 | 19 | class Group(BaseModel): 20 | id: str 21 | name: Optional[str] = None 22 | full_name: Optional[str] = None 23 | owner: Account 24 | author: Account 25 | created: datetime 26 | modified: datetime 27 | workspace_ref: Optional[WorkspaceRef] = None 28 | description: Optional[str] = None 29 | 30 | @property 31 | def workspace(self) -> Optional[WorkspaceRef]: 32 | return self.workspace_ref 33 | 34 | 35 | class GroupSpec(BaseModel): 36 | workspace: Optional[str] = None 37 | name: Optional[str] = None 38 | description: Optional[str] = None 39 | experiments: Optional[List[str]] = None 40 | 41 | 42 | class GroupParameterType(StrEnum): 43 | metric = "metric" 44 | env = "env" 45 | 46 | 47 | class GroupParameter(BaseModel): 48 | type: GroupParameterType 49 | name: str 50 | 51 | 52 | class GroupPatch(BaseModel): 53 | name: Optional[str] = None 54 | description: Optional[str] = None 55 | add_experiments: Optional[List[str]] = None 56 | remove_experiments: Optional[List[str]] = None 57 | parameters: Optional[List[GroupParameter]] = None 58 | 59 | 60 | class GroupsPage(BasePage[Group]): 61 | data: Tuple[Group, ...] 
62 | 63 | 64 | class GroupSort(StrEnum): 65 | created = "created" 66 | modified = "modified" 67 | author = "author" 68 | group_name = "name" 69 | group_name_or_description = "nameOrDescription" 70 | -------------------------------------------------------------------------------- /tests/util_test.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import time 3 | 4 | import pytest 5 | 6 | from beaker.client import Beaker 7 | from beaker.services.service_client import ServiceClient 8 | from beaker.util import * 9 | 10 | 11 | @pytest.mark.parametrize( 12 | "camel_case, snake_case", 13 | [ 14 | ("hostPath", "host_path"), 15 | ("fooBarBaz", "foo_bar_baz"), 16 | ("docker", "docker"), 17 | ("replicaGroupID", "replica_group_id"), 18 | ], 19 | ) 20 | def test_to_lower_camel_and_back(camel_case: str, snake_case: str): 21 | assert to_lower_camel(snake_case) == camel_case 22 | assert to_snake_case(camel_case) == snake_case 23 | 24 | 25 | def test_cached_property(client: Beaker, alternate_workspace_name): 26 | class FakeService(ServiceClient): 27 | def __init__(self, *args, **kwargs): 28 | super().__init__(*args, **kwargs) 29 | self._x = 0 30 | 31 | @cached_property(ttl=0.5) 32 | def x(self) -> int: 33 | self._x += 1 34 | return self._x 35 | 36 | service_client = FakeService(client) 37 | 38 | assert service_client.x == 1 39 | assert service_client.x == 1 40 | 41 | time.sleep(1.0) 42 | assert service_client.x == 2 43 | 44 | client.config.default_workspace = alternate_workspace_name 45 | assert service_client.x == 3 46 | 47 | 48 | def test_format_cursor(): 49 | cursor = 100 50 | formatted = format_cursor(100) 51 | assert int.from_bytes(base64.urlsafe_b64decode(formatted), "little") == cursor 52 | 53 | 54 | def test_parse_duration(): 55 | assert parse_duration("1") == 1_000_000_000 56 | assert parse_duration("1s") == 1_000_000_000 57 | assert parse_duration("1sec") == 1_000_000_000 58 | assert parse_duration("1m") == 60 * 
# encoding: utf-8

"""
Prepares markdown release notes for GitHub releases.
"""

import os
from typing import List

# The release tag to generate notes for (e.g. "v1.2.3").
# Must be set in the environment by the release workflow; a missing
# TAG fails fast here with a KeyError.
TAG = os.environ["TAG"]

# Canonical section headers used in the published release notes.
ADDED_HEADER = "### Added 🎉"
CHANGED_HEADER = "### Changed ⚠️"
FIXED_HEADER = "### Fixed ✅"
REMOVED_HEADER = "### Removed 👋"


def get_change_log_notes() -> str:
    """
    Extract the CHANGELOG.md section for ``TAG`` and return it as markdown,
    prefixed with a "What's new" header.

    Raises ``AssertionError`` if the changelog has no non-empty section
    for the tag.
    """
    in_current_section = False
    current_section_notes: List[str] = []
    with open("CHANGELOG.md") as changelog:
        for line in changelog:
            if line.startswith("## "):
                if line.startswith("## Unreleased"):
                    continue
                if line.startswith(f"## [{TAG}]"):
                    in_current_section = True
                    continue
                # Any other "## " header ends the scan: either we've already
                # collected the tag's section, or it doesn't exist.
                break
            if in_current_section:
                # Normalize sub-headers to the canonical emoji versions.
                if line.startswith("### Added"):
                    line = ADDED_HEADER + "\n"
                elif line.startswith("### Changed"):
                    line = CHANGED_HEADER + "\n"
                elif line.startswith("### Fixed"):
                    line = FIXED_HEADER + "\n"
                elif line.startswith("### Removed"):
                    line = REMOVED_HEADER + "\n"
                current_section_notes.append(line)
    assert current_section_notes
    return "## What's new\n\n" + "".join(current_section_notes).strip() + "\n"


def get_commit_history() -> str:
    """
    Return a markdown "Commits" section listing one-line summaries of all
    commits between the previous tag and ``TAG``.
    """
    # NOTE(review): this builds a shell command string containing TAG.
    # TAG comes from the environment and is presumably trusted in the
    # release workflow — confirm it is never user-controlled.
    stream = os.popen(
        f"git log $(git describe --always --tags --abbrev=0 {TAG}^^)..{TAG} --oneline --pretty='%h %s'"
    )
    return "## Commits\n\n" + stream.read()


def main():
    # Print the combined release notes (changelog section + commit list)
    # to stdout for the release workflow to capture.
    print(get_change_log_notes())
    print(get_commit_history())


if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- 1 | .. beaker-py documentation master file, created by 2 | sphinx-quickstart on Tue Sep 21 08:07:48 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | **beaker-py** 7 | =============== 8 | 9 | **beaker-py** is a lightweight, standalone, pure Python client for `Beaker `_. 10 | 11 | Features 12 | -------- 13 | 14 | .. include:: ../../README.md 15 | :start-after: 16 | :end-before: 17 | 18 | Contents 19 | -------- 20 | 21 | .. toctree:: 22 | :maxdepth: 2 23 | :caption: Getting started: 24 | 25 | installation 26 | quickstart 27 | overview 28 | examples 29 | faq 30 | 31 | .. toctree:: 32 | :hidden: 33 | :caption: API Reference 34 | 35 | api/client 36 | api/data_models 37 | api/experiment_spec 38 | api/config 39 | api/exceptions 40 | 41 | .. toctree:: 42 | :hidden: 43 | :caption: Development 44 | 45 | CHANGELOG 46 | License 47 | CONTRIBUTING 48 | GitHub Repository 49 | 50 | Team 51 | ---- 52 | 53 | **beaker-py** is developed and maintained at 54 | `the Allen Institute for Artificial Intelligence (AI2) `_. 55 | AI2 is a non-profit institute with the mission to contribute to humanity through high-impact AI research and engineering. 56 | 57 | To learn more about who specifically contributed to this codebase, see 58 | `our contributors `_ page. 59 | 60 | License 61 | ------- 62 | 63 | **beaker-py** is licensed under `Apache 2.0 `_. 64 | A full copy of the license can be found `on GitHub `_. 
from typing import Any, ClassVar, Dict, Optional, Set, Type

from pydantic import BaseModel as _BaseModel
from pydantic import root_validator, validator

from ..util import issue_data_model_warning, to_snake_case


def field_validator(*fields: str, mode: str = "after"):
    # Shim emulating pydantic v2's ``field_validator`` on top of v1's
    # ``validator`` (v1 uses ``pre=True`` where v2 uses ``mode="before"``).
    return validator(*fields, pre=mode == "before")


def model_validator(mode: str = "after"):
    # Shim emulating pydantic v2's ``model_validator`` on top of v1's
    # ``root_validator``.
    return root_validator(pre=mode == "before")  # type: ignore


class BaseModelV1(_BaseModel):
    """
    The Pydantic v1 base class for all Beaker data models.
    """

    class Config:
        # Mirrors the ConfigDict used by the v2 base class: immutable
        # models that validate on assignment, store enum values, and
        # ignore unknown fields (after warning about them below).
        validate_assignment = True
        use_enum_values = True
        frozen = True
        extra = "ignore"

    # Field names that subclasses want excluded from the unknown-field
    # warning below, even though they are not declared model fields.
    IGNORE_FIELDS: ClassVar[Set[str]] = set()

    @root_validator(pre=True)
    def _validate_and_rename_to_snake_case(  # type: ignore
        cls: Type["BaseModelV1"], values: Dict[str, Any]  # type: ignore
    ) -> Dict[str, Any]:
        """
        Raw data from the Beaker server will use lower camel case.

        This pre-validator renames all incoming keys to snake case so they
        match the declared field names, and warns about any key that is
        neither a declared field nor listed in ``IGNORE_FIELDS``.
        """
        # In some cases we get an instance instead of a dict.
        # We'll just punt there and hope for the best.
        if not isinstance(values, dict):
            return values

        as_snake_case = {to_snake_case(k): v for k, v in values.items()}
        for key, value in as_snake_case.items():
            if (
                cls.__config__.extra != "allow"  # type: ignore
                and key not in cls.__fields__  # type: ignore
                and key not in cls.IGNORE_FIELDS
            ):
                # Surface unexpected server fields instead of dropping them
                # silently, so schema drift is noticed.
                issue_data_model_warning(cls, key, value)
        return as_snake_case

    # The two methods below expose pydantic v2 method names on top of the
    # v1 API so calling code can be version-agnostic.

    def model_copy(self, update: Optional[Dict[str, Any]] = None, deep: bool = False):  # type: ignore
        return self.copy(update=update, deep=deep)  # type: ignore

    def model_dump(self, *args, **kwargs):
        return self.dict(*args, **kwargs)  # type: ignore
36 | :raises RequestException: Any other exception that can occur when contacting the 37 | Beaker server. 38 | """ 39 | return [Organization.from_json(d) for d in self.request("user/orgs").json()["data"]] 40 | 41 | def get(self, account: str) -> Account: 42 | """ 43 | Get information about an account. 44 | 45 | :param account: The account name or ID. 46 | 47 | :raises AccountNotFound: If the account doesn't exist. 48 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 49 | :raises RequestException: Any other exception that can occur when contacting the 50 | Beaker server. 51 | """ 52 | return Account.from_json( 53 | self.request( 54 | f"users/{self.url_quote(account)}", 55 | method="GET", 56 | exceptions_for_status={404: AccountNotFound(account)}, 57 | ).json() 58 | ) 59 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: 🐛 Bug Report 2 | description: Create a report to help us reproduce and fix the bug 3 | labels: 'bug' 4 | 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: > 9 | #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/allenai/beaker-py/issues?q=is%3Aissue+sort%3Acreated-desc+). 10 | - type: textarea 11 | attributes: 12 | label: 🐛 Describe the bug 13 | description: | 14 | Please provide a clear and concise description of what the bug is. 15 | 16 | If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. 
For example: 17 | 18 | ```python 19 | # All necessary imports at the beginning 20 | from beaker import * 21 | 22 | # A succinct reproducing example trimmed down to the essential parts: 23 | beaker = Beaker.from_env() 24 | assert False is True, "Oh no!" 25 | ``` 26 | 27 | If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. 28 | 29 | Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. 30 | placeholder: | 31 | A clear and concise description of what the bug is. 32 | validations: 33 | required: true 34 | - type: textarea 35 | attributes: 36 | label: Versions 37 | description: | 38 | Please run the following and paste the output below. 39 | ```sh 40 | python --version && pip freeze 41 | ``` 42 | validations: 43 | required: true 44 | - type: markdown 45 | attributes: 46 | value: > 47 | Thanks for contributing 🎉! 
48 | -------------------------------------------------------------------------------- /tests/job_test.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from typing import Optional, Union 3 | 4 | import pytest 5 | 6 | from beaker import Beaker, CurrentJobStatus, JobKind, JobNotFound 7 | 8 | 9 | def test_job_get(client: Beaker, hello_world_job_id: str): 10 | job = client.job.get(hello_world_job_id) 11 | assert job.id == hello_world_job_id 12 | assert job.status.current == CurrentJobStatus.finalized 13 | assert job.kind == JobKind.execution 14 | assert job.to_json()["kind"] == "execution" 15 | 16 | 17 | def test_job_results(client: Beaker, hello_world_job_id: str): 18 | client.job.results(hello_world_job_id) 19 | 20 | 21 | def test_job_logs(client: Beaker, hello_world_job_id: str): 22 | logs = "\n".join( 23 | [ 24 | line.strip() 25 | for line in b"".join(list(client.job.logs(hello_world_job_id, quiet=True))) 26 | .decode() 27 | .split("\n") 28 | ] 29 | ) 30 | assert "Hello from Docker!" in logs 31 | 32 | 33 | @pytest.mark.parametrize( 34 | "since, tail_lines", 35 | [ 36 | (None, None), 37 | (None, 10), 38 | (datetime.utcnow(), None), 39 | (timedelta(hours=1), None), 40 | ], 41 | ) 42 | def test_structured_job_logs( 43 | client: Beaker, 44 | hello_world_job_id: str, 45 | since: Optional[Union[datetime, timedelta]], 46 | tail_lines: Optional[int], 47 | ): 48 | list( 49 | client.job.structured_logs( 50 | hello_world_job_id, quiet=True, since=since, tail_lines=tail_lines 51 | ) 52 | ) 53 | 54 | 55 | def test_job_logs_since(client: Beaker, hello_world_job_id: str): 56 | logs = "\n".join( 57 | [ 58 | line.strip() 59 | for line in b"".join( 60 | list( 61 | client.job.logs( 62 | hello_world_job_id, quiet=True, since="2023-02-11T00:34:19.938308862Z" 63 | ) 64 | ) 65 | ) 66 | .decode() 67 | .split("\n") 68 | ] 69 | ) 70 | assert "Hello from Docker!" 
not in logs 71 | 72 | 73 | def test_summarized_job_events(client: Beaker): 74 | client.job.summarized_events("01JSCFT1563SA35GXS206J575B") 75 | 76 | try: 77 | client.job.summarized_events("blah") 78 | except JobNotFound: 79 | pass 80 | -------------------------------------------------------------------------------- /.github/actions/setup-venv/action.yml: -------------------------------------------------------------------------------- 1 | name: Python virtualenv 2 | description: Set up a Python virtual environment with caching 3 | inputs: 4 | python-version: 5 | description: The Python version to use 6 | required: true 7 | cache-prefix: 8 | description: Update this to invalidate the cache 9 | required: true 10 | default: v0 11 | packages: 12 | description: Extra packages to install or pin 13 | required: false 14 | default: '' 15 | runs: 16 | using: composite 17 | steps: 18 | - name: Setup Python 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: ${{ inputs.python-version }} 22 | 23 | - shell: bash 24 | run: | 25 | # Install prerequisites. 26 | pip install --upgrade pip setuptools wheel build virtualenv 27 | 28 | - shell: bash 29 | run: | 30 | # Get the exact Python version to use in the cache key. 31 | echo "PYTHON_VERSION=$(python --version)" >> $GITHUB_ENV 32 | 33 | - uses: actions/cache@v4 34 | id: virtualenv-cache 35 | with: 36 | path: .venv 37 | key: ${{ inputs.cache-prefix }}-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('pyproject.toml') }} 38 | 39 | - if: steps.virtualenv-cache.outputs.cache-hit != 'true' 40 | shell: bash 41 | run: | 42 | # Set up virtual environment without cache hit. 43 | test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv 44 | 45 | - if: steps.virtualenv-cache.outputs.cache-hit != 'true' && inputs.packages == '' 46 | shell: bash 47 | run: | 48 | . 
.venv/bin/activate 49 | pip install -e .[dev] 50 | 51 | - if: steps.virtualenv-cache.outputs.cache-hit != 'true' && inputs.packages != '' 52 | shell: bash 53 | run: | 54 | . .venv/bin/activate 55 | pip install '${{ inputs.packages }}' -e .[dev] 56 | 57 | - if: steps.virtualenv-cache.outputs.cache-hit == 'true' 58 | shell: bash 59 | run: | 60 | # Set up virtual environment from cache hit. 61 | . .venv/bin/activate 62 | pip install --no-deps -e .[dev] 63 | 64 | - shell: bash 65 | run: | 66 | # Show environment info. 67 | . .venv/bin/activate 68 | echo "✓ Installed $(python --version) virtual environment to $(which python)" 69 | pip freeze 70 | -------------------------------------------------------------------------------- /beaker/data_model/__init__.py: -------------------------------------------------------------------------------- 1 | from .account import * 2 | from .cluster import * 3 | from .dataset import * 4 | from .experiment import * 5 | from .experiment_spec import * 6 | from .group import * 7 | from .image import * 8 | from .job import * 9 | from .node import * 10 | from .organization import * 11 | from .secret import * 12 | from .workspace import * 13 | 14 | __all__ = [ 15 | "Account", 16 | "ClusterStatus", 17 | "Cluster", 18 | "ClusterUtilization", 19 | "ClusterSpec", 20 | "ClusterPatch", 21 | "DatasetStorage", 22 | "DatasetSize", 23 | "Dataset", 24 | "DatasetInfo", 25 | "DatasetInfoPage", 26 | "Digest", 27 | "DigestHashAlgorithm", 28 | "FileInfo", 29 | "DatasetsPage", 30 | "DatasetSpec", 31 | "DatasetPatch", 32 | "DatasetSort", 33 | "Experiment", 34 | "Task", 35 | "Tasks", 36 | "ExperimentsPage", 37 | "ExperimentPatch", 38 | "ExperimentSort", 39 | "ImageSource", 40 | "EnvVar", 41 | "DataSource", 42 | "DataMount", 43 | "ResultSpec", 44 | "TaskResources", 45 | "Priority", 46 | "TaskContext", 47 | "TaskSpec", 48 | "SpecVersion", 49 | "RetrySpec", 50 | "ExperimentSpec", 51 | "Constraints", 52 | "Group", 53 | "GroupSpec", 54 | "GroupParameterType", 55 | 
"GroupParameter", 56 | "GroupPatch", 57 | "GroupsPage", 58 | "GroupSort", 59 | "Image", 60 | "ImagesPage", 61 | "ImageRepoAuth", 62 | "ImageRepo", 63 | "DockerLayerProgress", 64 | "DockerLayerUploadStatus", 65 | "DockerLayerDownloadStatus", 66 | "DockerLayerUploadState", 67 | "DockerLayerDownloadState", 68 | "ImageSpec", 69 | "ImagePatch", 70 | "ImageSort", 71 | "CurrentJobStatus", 72 | "CanceledCode", 73 | "JobStatus", 74 | "ExecutionResult", 75 | "JobRequests", 76 | "JobLimits", 77 | "JobExecution", 78 | "JobKind", 79 | "Job", 80 | "Jobs", 81 | "JobStatusUpdate", 82 | "JobPatch", 83 | "Session", 84 | "SummarizedJobEvent", 85 | "JobLog", 86 | "NodeResources", 87 | "Node", 88 | "NodeUtilization", 89 | "Organization", 90 | "OrganizationRole", 91 | "OrganizationMember", 92 | "Secret", 93 | "WorkspaceSize", 94 | "Workspace", 95 | "WorkspaceRef", 96 | "WorkspacePage", 97 | "WorkspaceSpec", 98 | "WorkspaceTransferSpec", 99 | "Permission", 100 | "WorkspacePermissions", 101 | "WorkspacePatch", 102 | "WorkspacePermissionsPatch", 103 | "WorkspaceClearResult", 104 | "WorkspaceSort", 105 | ] 106 | -------------------------------------------------------------------------------- /beaker/data_model/experiment.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List, Optional, Tuple 3 | 4 | from pydantic import Field 5 | 6 | from .account import Account 7 | from .base import BaseModel, BasePage, MappedSequence, StrEnum 8 | from .job import Job 9 | from .workspace import WorkspaceRef 10 | 11 | __all__ = ["Experiment", "Task", "Tasks", "ExperimentsPage", "ExperimentPatch", "ExperimentSort"] 12 | 13 | 14 | class Experiment(BaseModel): 15 | id: str 16 | name: Optional[str] = None 17 | full_name: Optional[str] = None 18 | description: Optional[str] = None 19 | budget: Optional[str] = None 20 | author: Account 21 | created: datetime 22 | canceled: Optional[datetime] = None 23 | canceled_for: 
Optional[str] = None 24 | workspace_ref: WorkspaceRef 25 | jobs: Tuple[Job, ...] = Field(default_factory=tuple) 26 | 27 | @property 28 | def display_name(self) -> str: 29 | return self.name if self.name is not None else self.id 30 | 31 | @property 32 | def workspace(self) -> WorkspaceRef: 33 | return self.workspace_ref 34 | 35 | 36 | class Task(BaseModel): 37 | id: str 38 | name: Optional[str] = None 39 | experiment_id: str 40 | author: Account 41 | created: datetime 42 | schedulable: bool = False 43 | jobs: Tuple[Job, ...] = Field(default_factory=tuple) 44 | owner: Optional[Account] = None 45 | # replica_rank: Optional[int] = None 46 | 47 | @property 48 | def replica_rank(self) -> Optional[int]: 49 | if ( 50 | (job := self.latest_job) is not None 51 | and job.execution is not None 52 | and (replica_rank := job.execution.replica_rank) is not None 53 | ): 54 | return replica_rank 55 | return None 56 | 57 | @property 58 | def display_name(self) -> str: 59 | return self.name if self.name is not None else self.id 60 | 61 | @property 62 | def latest_job(self) -> Optional[Job]: 63 | if not self.jobs: 64 | return None 65 | return sorted(self.jobs, key=lambda job: job.status.created)[-1] 66 | 67 | 68 | class Tasks(MappedSequence[Task]): 69 | """ 70 | A sequence of :class:`Task` that also behaves like a mapping of task names to tasks, 71 | i.e. you can use ``get()`` or ``__getitem__()`` with the name of the task. 72 | """ 73 | 74 | def __init__(self, tasks: List[Task]): 75 | super().__init__(tasks, {task.name: task for task in tasks if task.name is not None}) 76 | 77 | 78 | class ExperimentsPage(BasePage[Experiment]): 79 | data: Tuple[Experiment, ...] 
80 | 81 | 82 | class ExperimentPatch(BaseModel): 83 | name: Optional[str] = None 84 | description: Optional[str] = None 85 | 86 | 87 | class ExperimentSort(StrEnum): 88 | created = "created" 89 | author = "author" 90 | experiment_name = "name" 91 | experiment_name_or_description = "nameOrDescription" 92 | -------------------------------------------------------------------------------- /docs/source/api/data_models.rst: -------------------------------------------------------------------------------- 1 | Data Models 2 | ----------- 3 | 4 | .. autoclass:: beaker.data_model.base.BaseModel 5 | :members: 6 | 7 | Account 8 | ~~~~~~~ 9 | 10 | .. autoclass:: beaker.Account 11 | :members: 12 | 13 | Organization 14 | ~~~~~~~~~~~~ 15 | 16 | .. autoclass:: beaker.Organization 17 | :members: 18 | 19 | .. autoclass:: beaker.OrganizationRole 20 | :members: 21 | 22 | .. autoclass:: beaker.OrganizationMember 23 | :members: 24 | 25 | Workspace 26 | ~~~~~~~~~ 27 | 28 | .. autoclass:: beaker.Workspace 29 | :members: 30 | 31 | .. autoclass:: beaker.WorkspaceSize 32 | :members: 33 | 34 | .. autoclass:: beaker.WorkspaceRef 35 | :members: 36 | 37 | .. autoclass:: beaker.Permission 38 | :members: 39 | 40 | .. autoclass:: beaker.WorkspacePermissions 41 | :members: 42 | 43 | Cluster 44 | ~~~~~~~ 45 | 46 | .. autoclass:: beaker.Cluster 47 | :members: 48 | 49 | .. autoclass:: beaker.ClusterStatus 50 | :members: 51 | 52 | .. autoclass:: beaker.ClusterUtilization 53 | :members: 54 | 55 | .. autoclass:: beaker.ClusterSpec 56 | :members: 57 | 58 | .. autoclass:: beaker.ClusterPatch 59 | :members: 60 | 61 | Node 62 | ~~~~ 63 | 64 | .. autoclass:: beaker.Node 65 | :members: 66 | 67 | .. autoclass:: beaker.NodeResources 68 | :members: 69 | 70 | .. autoclass:: beaker.NodeUtilization 71 | :members: 72 | 73 | Dataset 74 | ~~~~~~~ 75 | 76 | .. autoclass:: beaker.Dataset 77 | :members: 78 | 79 | .. autoclass:: beaker.DatasetStorage 80 | :members: 81 | 82 | .. 
autoclass:: beaker.FileInfo 83 | :members: 84 | 85 | .. autoclass:: beaker.Digest 86 | :members: 87 | 88 | .. autoclass:: beaker.DigestHashAlgorithm 89 | :members: 90 | 91 | Image 92 | ~~~~~ 93 | 94 | .. autoclass:: beaker.Image 95 | :members: 96 | 97 | Job 98 | ~~~ 99 | 100 | .. autoclass:: beaker.Job 101 | :members: 102 | 103 | .. autoclass:: beaker.JobKind 104 | :members: 105 | 106 | .. autoclass:: beaker.CurrentJobStatus 107 | :members: 108 | 109 | .. autoclass:: beaker.CanceledCode 110 | :members: 111 | 112 | .. autoclass:: beaker.JobStatus 113 | :members: 114 | 115 | .. autoclass:: beaker.ExecutionResult 116 | :members: 117 | 118 | .. autoclass:: beaker.JobRequests 119 | :members: 120 | 121 | .. autoclass:: beaker.JobExecution 122 | :members: 123 | 124 | .. autoclass:: beaker.JobLimits 125 | :members: 126 | 127 | .. autoclass:: beaker.Session 128 | :members: 129 | 130 | .. autoclass:: beaker.SummarizedJobEvent 131 | :members: 132 | 133 | .. autoclass:: beaker.JobLog 134 | :members: 135 | 136 | Experiment 137 | ~~~~~~~~~~ 138 | 139 | .. autoclass:: beaker.Experiment 140 | :members: 141 | 142 | .. autoclass:: beaker.Task 143 | :members: 144 | 145 | .. autoclass:: beaker.Tasks 146 | :members: 147 | 148 | Secret 149 | ~~~~~~ 150 | 151 | .. autoclass:: beaker.Secret 152 | :members: 153 | 154 | Group 155 | ~~~~~ 156 | 157 | .. autoclass:: beaker.Group 158 | :members: 159 | -------------------------------------------------------------------------------- /beaker/data_model/cluster.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Optional, Tuple 3 | 4 | from .base import BaseModel, StrEnum, field_validator 5 | from .job import Job 6 | from .node import NodeResources, NodeUtilization 7 | 8 | __all__ = ["ClusterStatus", "Cluster", "ClusterUtilization", "ClusterSpec", "ClusterPatch"] 9 | 10 | 11 | class ClusterStatus(StrEnum): 12 | """ 13 | Current status of a cluster. 
14 | """ 15 | 16 | pending = "pending" 17 | active = "active" 18 | terminated = "terminated" 19 | failed = "failed" 20 | 21 | 22 | class Cluster(BaseModel): 23 | id: str 24 | name: str 25 | full_name: str 26 | created: datetime 27 | autoscale: bool 28 | capacity: int 29 | preemptible: bool 30 | status: ClusterStatus 31 | status_message: Optional[str] = None 32 | aliases: Optional[Tuple[str, ...]] = None 33 | node_spec: Optional[NodeResources] = None 34 | """ 35 | The requested node configuration. 36 | """ 37 | node_shape: Optional[NodeResources] = None 38 | """ 39 | The actual node configuration. 40 | """ 41 | node_cost: Optional[str] = None 42 | validated: Optional[datetime] = None 43 | user_restrictions: Optional[Tuple[str, ...]] = None 44 | allow_preemptible_restriction_exceptions: Optional[bool] = None 45 | compute_source: Optional[str] = None 46 | max_job_timeout: Optional[int] = None 47 | max_session_timeout: Optional[int] = None 48 | require_preemptible_tasks: Optional[bool] = None 49 | 50 | @field_validator("validated") 51 | def _validate_datetime(cls, v: Optional[datetime]) -> Optional[datetime]: 52 | if v is not None and v.year == 1: 53 | return None 54 | return v 55 | 56 | @field_validator("node_spec") 57 | def _validate_node_spec(cls, v: Optional[NodeResources]) -> Optional[NodeResources]: 58 | if v is not None and not v.to_json(): 59 | return None 60 | return v 61 | 62 | @property 63 | def is_cloud(self) -> bool: 64 | """ 65 | Returns ``True`` is the cluster is a cloud cluster, otherwise ``False``. 66 | """ 67 | return self.node_shape is not None and self.node_spec is not None 68 | 69 | @property 70 | def is_active(self) -> bool: 71 | """ 72 | Returns ``True`` if the cluster is ready to be used. 
"""
This script will upload an image to Beaker and then submit a bunch
of experiments with different inputs. It will wait for all experiments to finish
and then collect the results.

See the output of 'python run.py --help' for usage.
"""

import argparse
import uuid

import petname
from rich import print, progress, table, traceback

from beaker import *


def unique_name() -> str:
    """Helper function to generate a unique name for the image, group, and each experiment."""
    return petname.generate() + "-" + str(uuid.uuid4())[:8]  # type: ignore


def main(image: str, workspace: str):
    """
    Upload ``image`` to Beaker, launch a small sweep of experiments in
    ``workspace``, wait for them to finalize, and print a table of results.

    :param image: Tag of the local Docker image to upload.
    :param workspace: The Beaker workspace to run the sweep in.
    """
    beaker = Beaker.from_env(default_workspace=workspace)
    sweep_name = unique_name()
    print(f"Starting sweep '{sweep_name}'...\n")

    # Using the `beaker.session()` context manager is not necessary, but it does
    # speed things up since it allows the Beaker client to reuse the same TCP connection
    # for all requests made within-context.
    with beaker.session():
        # Upload image to Beaker.
        print(f"Uploading image '{image}' to Beaker...")
        beaker_image = beaker.image.create(unique_name(), image)
        print(
            f"Image uploaded as '{beaker_image.full_name}', view at {beaker.image.url(beaker_image)}\n"
        )

        # Launch experiments.
        experiments = []
        for x in progress.track(range(5), description="Launching experiments..."):
            spec = ExperimentSpec.new(
                "ai2/allennlp",
                description=f"Run {x+1} of sweep {sweep_name}",
                beaker_image=beaker_image.full_name,
                result_path="/output",
                priority=Priority.preemptible,
                arguments=[str(x)],
            )
            experiment = beaker.experiment.create(f"{sweep_name}-{x+1}", spec)
            experiments.append(experiment)
        print()

        # Create group.
        print("Creating group for sweep...")
        group = beaker.group.create(
            # BUG FIX: this description string was missing its 'f' prefix, so the
            # literal text "{sweep_name}" was stored instead of the actual name.
            sweep_name, *experiments, description=f"Group for sweep {sweep_name}"
        )
        print(f"Group '{group.full_name}' created, view at {beaker.group.url(group)}\n")

        # Wait for experiments to finish.
        print("Waiting for experiments to finalize...\n")
        experiments = beaker.experiment.wait_for(*experiments)
        print()

        # Display results as a table.
        results_table = table.Table(title="Results for sweep")
        results_table.add_column("Input")
        results_table.add_column("Output")
        for x, experiment in enumerate(
            progress.track(experiments, description="Gathering results...")
        ):
            metrics = beaker.experiment.metrics(experiment)
            assert metrics is not None
            results_table.add_row(f"x = {x}", f"{metrics['result']:.4f}")
        print()
        print(results_table)


if __name__ == "__main__":
    traceback.install()

    parser = argparse.ArgumentParser(description="Run a hyperparameter sweep in Beaker")
    parser.add_argument(
        "image", type=str, help="""The tag of the local Docker image built from the Dockerfile."""
    )
    parser.add_argument("workspace", type=str, help="""The Beaker workspace to use.""")
    opts = parser.parse_args()

    main(image=opts.image, workspace=opts.workspace)
| "protobuf>=5.0", 31 | ] 32 | 33 | [project.optional-dependencies] 34 | dev = [ 35 | "ruff", 36 | "mypy>=1.0,<1.6", 37 | "types-requests", 38 | "types-cachetools", 39 | "types-PyYAML", 40 | "types-protobuf", 41 | "black>=23.0,<24.0", 42 | "isort>=5.12,<5.13", 43 | "pytest<8.0", 44 | "pytest-sphinx", 45 | "flaky", 46 | "twine>=1.11.0", 47 | "build", 48 | "setuptools", 49 | "wheel", 50 | "Sphinx>=6.0,<7.0.2", 51 | "furo==2023.5.20", 52 | "myst-parser>=1.0,<2.1", 53 | "sphinx-copybutton==0.5.2", 54 | "sphinx-autobuild==2021.3.14", 55 | "sphinx-autodoc-typehints==1.23.3", 56 | "sphinx-inline-tabs==2022.1.2b11", 57 | "packaging", 58 | "petname==2.6", 59 | "grpcio-tools", 60 | ] 61 | 62 | [project.urls] 63 | Homepage = "https://github.com/allenai/beaker-py" 64 | Repository = "https://github.com/allenai/beaker-py" 65 | Changelog = "https://github.com/allenai/beaker-py/blob/main/CHANGELOG.md" 66 | Documentation = "https://beaker-py.readthedocs.io/" 67 | 68 | [tool.setuptools] 69 | include-package-data = true 70 | 71 | [tool.setuptools.packages.find] 72 | exclude = [ 73 | "tests*", 74 | "docs*", 75 | "scripts*", 76 | "examples*", 77 | "integration_tests*", 78 | ] 79 | 80 | [tool.setuptools.package-data] 81 | beaker = ["py.typed"] 82 | 83 | [tool.setuptools.dynamic] 84 | version = {attr = "beaker.version.VERSION"} 85 | 86 | [tool.black] 87 | line-length = 100 88 | include = '\.pyi?$' 89 | exclude = ''' 90 | ( 91 | __pycache__ 92 | | \.git 93 | | \.mypy_cache 94 | | \.pytest_cache 95 | | \.vscode 96 | | \.venv 97 | | \bdist\b 98 | | \bdoc\b 99 | | beaker_pb2.* 100 | ) 101 | ''' 102 | 103 | [tool.isort] 104 | profile = "black" 105 | multi_line_output = 3 106 | skip_glob = ["beaker/beaker_pb2*"] 107 | 108 | [tool.ruff] 109 | line-length = 115 110 | exclude = ["beaker/beaker_pb2*"] 111 | 112 | [tool.ruff.lint] 113 | ignore = ["E501", "F403", "F405"] 114 | 115 | [tool.ruff.lint.per-file-ignores] 116 | "__init__.py" = ["F401"] 117 | 118 | [tool.mypy] 119 | ignore_missing_imports 
= true 120 | no_site_packages = false 121 | check_untyped_defs = true 122 | 123 | [[tool.mypy.overrides]] 124 | module = "tests.*" 125 | strict_optional = false 126 | 127 | [tool.pytest.ini_options] 128 | testpaths = [ 129 | "tests/", 130 | "integration_tests/", 131 | ] 132 | python_classes = [ 133 | "Test*", 134 | "*Test", 135 | ] 136 | log_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" 137 | log_level = "DEBUG" 138 | filterwarnings = [ 139 | 'ignore:.*distutils Version classes are deprecated.*:DeprecationWarning:docker\.utils\.utils', 140 | ] 141 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |
3 | 4 |
5 |
6 |

beaker-py

7 |

A lightweight, standalone, pure Python client for Beaker

8 |
9 |
10 | 11 | ❗ NOTICE: this project has moved! The latest version of beaker-py is now maintained in [allenai/beaker](https://github.com/allenai/beaker/tree/main/bindings/python), with documentation at [beaker-py-docs.allen.ai](https://beaker-py-docs.allen.ai/index.html). Version 1 will continue to survive here for folks that need it. We also provide a [migration guide](https://github.com/allenai/beaker/blob/main/bindings/python/MIGRATION_GUIDE.md) to make upgrading easier. 12 | 13 |
14 | 15 | ## Features 16 | 17 | 18 | 19 | 🪶 *Lightweight* 20 | 21 | - Minimal dependencies. 22 | - Only pure-Python dependencies. 23 | - Communicates directly with the Beaker server via HTTP requests (Beaker CLI not required). 24 | 25 | 💪 *Robust* 26 | 27 | - Automatically retries failed HTTP requests with exponential backoff. 28 | - Runtime data validation. 29 | - High test coverage. 30 | 31 | 📓 *Exhaustively-typed and documented* 32 | 33 | - Thorough data model for all input / output types. 34 | - Every expected HTTP error from the Beaker server is translated into a specific exception type. 35 | 36 | 37 | 38 | ## Quick links 39 | 40 | - [PyPI package](https://pypi.org/project/beaker-py/) 41 | - [Contributing](https://github.com/allenai/beaker-py/blob/main/CONTRIBUTING.md) 42 | - [License](https://github.com/allenai/beaker-py/blob/main/LICENSE) 43 | 44 | *See also 👇* 45 | 46 | - [Beaker project](https://github.com/allenai/beaker) 47 | - [Beaker Gantry](https://github.com/allenai/beaker-gantry) 48 | - Beaker-relevant *GitHub Actions* 49 | - [setup-beaker](https://github.com/marketplace/actions/setup-beaker) 50 | - [beaker-command](https://github.com/marketplace/actions/beaker-command) 51 | - [beaker-run](https://github.com/marketplace/actions/beaker-run) 52 | 53 | ## Installing 54 | 55 | ### Installing with `pip` 56 | 57 | **beaker-py** is available [on PyPI](https://pypi.org/project/beaker-py/). Just run 58 | 59 | ```bash 60 | pip install 'beaker-py<2.0' 61 | ``` 62 | 63 | ### Installing from source 64 | 65 | To install **beaker-py** from source, first clone [the repository](https://github.com/allenai/beaker-py): 66 | 67 | ```bash 68 | git clone https://github.com/allenai/beaker-py.git 69 | cd beaker-py 70 | ``` 71 | 72 | Then run 73 | 74 | ```bash 75 | pip install -e . 
76 | ``` 77 | 78 | ## Quick start 79 | 80 | 81 | 82 | If you've already configured the [Beaker command-line client](https://github.com/allenai/beaker/), **beaker-py** will 83 | find and use the existing configuration file (usually located at `$HOME/.beaker/config.yml`). 84 | Otherwise just set the environment variable `BEAKER_TOKEN` to your Beaker [user token](https://beaker.org/user). 85 | 86 | Either way, you should then instantiate the Beaker client with `.from_env()`: 87 | 88 | ```python 89 | from beaker import Beaker 90 | 91 | beaker = Beaker.from_env(default_workspace="my_org/my_workspace") 92 | ``` 93 | 94 | The API of **beaker-py** is meant to mirror - as closely as possible - the API of the Beaker CLI. 95 | For example, when you do this with the CLI: 96 | 97 | ```bash 98 | beaker dataset create --name foo . 99 | ``` 100 | 101 | The **beaker-py** equivalent would be: 102 | 103 | ```python 104 | beaker.dataset.create("foo", ".") 105 | ``` 106 | 107 | -------------------------------------------------------------------------------- /beaker/data_model/image.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Optional, Tuple 3 | 4 | from .account import Account 5 | from .base import BaseModel, BasePage, StrEnum, field_validator 6 | from .workspace import WorkspaceRef 7 | 8 | __all__ = [ 9 | "Image", 10 | "ImagesPage", 11 | "ImageRepoAuth", 12 | "ImageRepo", 13 | "DockerLayerProgress", 14 | "DockerLayerUploadStatus", 15 | "DockerLayerDownloadStatus", 16 | "DockerLayerUploadState", 17 | "DockerLayerDownloadState", 18 | "ImageSpec", 19 | "ImagePatch", 20 | "ImageSort", 21 | ] 22 | 23 | 24 | class Image(BaseModel): 25 | id: str 26 | owner: Account 27 | author: Account 28 | created: datetime 29 | workspace_ref: WorkspaceRef 30 | original_tag: Optional[str] = None 31 | docker_tag: Optional[str] = None 32 | name: Optional[str] = None 33 | full_name: Optional[str] = None 34 | 
description: Optional[str] = None 35 | committed: Optional[datetime] = None 36 | size: Optional[int] = None 37 | budget: Optional[str] = None 38 | 39 | @property 40 | def display_name(self) -> str: 41 | return self.name if self.name is not None else self.id 42 | 43 | @property 44 | def workspace(self) -> WorkspaceRef: 45 | return self.workspace_ref 46 | 47 | @field_validator("committed") 48 | def _validate_datetime(cls, v: Optional[datetime]) -> Optional[datetime]: 49 | if v is not None and v.year == 1: 50 | return None 51 | return v 52 | 53 | 54 | class ImagesPage(BasePage[Image]): 55 | data: Tuple[Image, ...] 56 | 57 | 58 | class ImageRepoAuth(BaseModel): 59 | user: str 60 | password: str 61 | server_address: str 62 | 63 | 64 | class ImageRepo(BaseModel): 65 | image_tag: str 66 | auth: ImageRepoAuth 67 | 68 | 69 | class DockerLayerProgress(BaseModel): 70 | current: Optional[int] = None 71 | total: Optional[int] = None 72 | 73 | 74 | class DockerLayerUploadStatus(StrEnum): 75 | preparing = "preparing" 76 | waiting = "waiting" 77 | pushing = "pushing" 78 | pushed = "pushed" 79 | already_exists = "layer already exists" 80 | 81 | 82 | class DockerLayerDownloadStatus(StrEnum): 83 | waiting = "waiting" 84 | downloading = "downloading" 85 | download_complete = "download complete" 86 | verifying_checksum = "verifying checksum" 87 | extracting = "extracting" 88 | pull_complete = "pull complete" 89 | already_exists = "already exists" 90 | 91 | 92 | class DockerLayerUploadState(BaseModel): 93 | id: str 94 | status: str 95 | progress_detail: DockerLayerProgress 96 | progress: Optional[str] = None 97 | 98 | @field_validator("status", mode="before") 99 | def _validate_status(cls, v: str) -> str: 100 | return v.lower() 101 | 102 | 103 | class DockerLayerDownloadState(BaseModel): 104 | id: str 105 | status: str 106 | progress_detail: DockerLayerProgress 107 | progress: Optional[str] = None 108 | 109 | @field_validator("status", mode="before") 110 | def _validate_status(cls, v: 
str) -> str: 111 | return v.lower() 112 | 113 | 114 | class ImageSpec(BaseModel): 115 | workspace: Optional[str] = None 116 | image_id: Optional[str] = None 117 | image_tag: Optional[str] = None 118 | description: Optional[str] = None 119 | budget: Optional[str] = None 120 | 121 | 122 | class ImagePatch(BaseModel): 123 | name: Optional[str] = None 124 | description: Optional[str] = None 125 | commit: Optional[bool] = None 126 | 127 | 128 | class ImageSort(StrEnum): 129 | created = "created" 130 | author = "author" 131 | image_name = "name" 132 | -------------------------------------------------------------------------------- /beaker/data_model/workspace.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Any, Dict, List, Optional, Tuple 3 | 4 | from .account import Account 5 | from .base import BaseModel, BasePage, StrEnum, model_validator 6 | from .experiment_spec import Priority 7 | from .organization import Organization 8 | 9 | __all__ = [ 10 | "WorkspaceSize", 11 | "Workspace", 12 | "WorkspaceRef", 13 | "WorkspacePage", 14 | "WorkspaceSpec", 15 | "WorkspaceTransferSpec", 16 | "Permission", 17 | "WorkspacePermissions", 18 | "WorkspacePatch", 19 | "WorkspacePermissionsPatch", 20 | "WorkspaceClearResult", 21 | "WorkspaceSort", 22 | ] 23 | 24 | 25 | class WorkspaceSize(BaseModel): 26 | datasets: int = 0 27 | experiments: int = 0 28 | groups: int = 0 29 | images: int = 0 30 | environments: int = 0 31 | 32 | 33 | class Workspace(BaseModel): 34 | id: str 35 | name: str 36 | full_name: str 37 | description: Optional[str] = None 38 | size: Optional[WorkspaceSize] = None 39 | owner: Optional[Account] = None 40 | owner_org: Optional[Organization] = None 41 | budget: Optional[str] = None 42 | author: Account 43 | created: datetime 44 | modified: datetime 45 | archived: bool = False 46 | max_workload_priority: Optional[Priority] = None 47 | budget_id: Optional[str] = None 48 | 
slot_limit_preemptible: Optional[int] = None 49 | slot_limit_non_preemptible: Optional[int] = None 50 | assigned_slots_preemptible: Optional[int] = None 51 | 52 | @model_validator(mode="before") 53 | def _adjust_new_field_names_for_compat_with_rpc_api( 54 | cls, values: Dict[str, Any] 55 | ) -> Dict[str, Any]: 56 | if ( 57 | "maxWorkloadPriority" not in values 58 | and (priority := values.pop("maximumWorkloadPriority", None)) is not None 59 | ): 60 | values["maxWorkloadPriority"] = priority.lower().replace("job_priority_", "") 61 | 62 | values.setdefault("author", values.pop("authorUser", None)) 63 | 64 | if (name := values.get("name")) is not None and "/" in name: 65 | values["name"] = name.split("/")[1] 66 | 67 | if ( 68 | "fullName" not in values 69 | and (name := values.get("name")) is not None 70 | and (org := values.get("ownerOrg")) is not None 71 | ): 72 | values["fullName"] = f"{org['name']}/{name}" 73 | 74 | return values 75 | 76 | 77 | class WorkspaceRef(BaseModel): 78 | id: str 79 | name: str 80 | full_name: str 81 | 82 | 83 | class WorkspacePage(BasePage[Workspace]): 84 | data: Tuple[Workspace, ...] 85 | org: Optional[str] = None 86 | 87 | 88 | class WorkspaceSpec(BaseModel): 89 | name: Optional[str] = None 90 | description: Optional[str] = None 91 | public: Optional[bool] = None 92 | org: Optional[str] = None 93 | 94 | 95 | class WorkspaceTransferSpec(BaseModel): 96 | ids: List[str] 97 | 98 | 99 | class Permission(StrEnum): 100 | """ 101 | Workspace permission levels. 102 | """ 103 | 104 | no_permission = "none" 105 | read = "read" 106 | write = "write" 107 | full_control = "all" 108 | 109 | 110 | class WorkspacePermissions(BaseModel): 111 | requester_auth: str 112 | public: bool 113 | authorizations: Optional[Dict[str, Permission]] = None 114 | """ 115 | A dictionary of account IDs to authorizations. 
116 | """ 117 | 118 | 119 | class WorkspacePatch(BaseModel): 120 | name: Optional[str] = None 121 | description: Optional[str] = None 122 | archive: Optional[bool] = None 123 | 124 | 125 | class WorkspacePermissionsPatch(BaseModel): 126 | public: Optional[bool] = None 127 | authorizations: Optional[Dict[str, Permission]] = None 128 | 129 | 130 | class WorkspaceClearResult(BaseModel): 131 | groups_deleted: int = 0 132 | experiments_deleted: int = 0 133 | images_deleted: int = 0 134 | datasets_deleted: int = 0 135 | secrets_deleted: int = 0 136 | 137 | 138 | class WorkspaceSort(StrEnum): 139 | created = "created" 140 | modified = "modified" 141 | workspace_name = "name" 142 | -------------------------------------------------------------------------------- /beaker/data_model/base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from enum import Enum 3 | from typing import ( 4 | Any, 5 | Dict, 6 | Generic, 7 | Iterator, 8 | Mapping, 9 | Optional, 10 | Sequence, 11 | Tuple, 12 | Type, 13 | TypeVar, 14 | Union, 15 | ) 16 | 17 | from pydantic import ValidationError 18 | 19 | from ..util import to_lower_camel, to_snake_case 20 | 21 | try: 22 | from pydantic import field_validator, model_validator 23 | 24 | from ._base_v2 import BaseModelV2 as _BaseModel 25 | except ImportError: 26 | from ._base_v1 import BaseModelV1 as _BaseModel # type: ignore 27 | from ._base_v1 import field_validator, model_validator # type: ignore 28 | 29 | T = TypeVar("T") 30 | 31 | logger = logging.getLogger("beaker") 32 | 33 | 34 | __all__ = [ 35 | "BaseModel", 36 | "MappedSequence", 37 | "StrEnum", 38 | "IntEnum", 39 | "BasePage", 40 | "field_validator", 41 | "model_validator", 42 | ] 43 | 44 | 45 | class BaseModel(_BaseModel): # type: ignore 46 | """ 47 | The base class for all Beaker data models. 
48 | """ 49 | 50 | def __str__(self) -> str: 51 | return self.__repr__() 52 | 53 | def __getitem__(self, key): 54 | try: 55 | return self.model_dump()[key] # type: ignore 56 | except KeyError: 57 | if not key.islower(): 58 | snake_case_key = to_snake_case(key) 59 | try: 60 | return self.model_dump()[snake_case_key] # type: ignore 61 | except KeyError: 62 | pass 63 | raise 64 | 65 | @classmethod 66 | def from_json(cls: Type[T], json_data: Dict[str, Any]) -> T: 67 | try: 68 | return cls(**json_data) 69 | except ValidationError: 70 | logger.error("Error validating raw JSON data for %s: %s", cls.__name__, json_data) 71 | raise 72 | 73 | def to_json(self) -> Dict[str, Any]: 74 | return self.jsonify(self) 75 | 76 | @classmethod 77 | def jsonify(cls, x: Any) -> Any: 78 | if isinstance(x, BaseModel): 79 | return { 80 | to_lower_camel(key): cls.jsonify(value) for key, value in x if value is not None # type: ignore 81 | } 82 | elif isinstance(x, Enum): 83 | return cls.jsonify(x.value) 84 | elif isinstance(x, (str, float, int, bool)): 85 | return x 86 | elif isinstance(x, dict): 87 | return {key: cls.jsonify(value) for key, value in x.items()} 88 | elif isinstance(x, (list, tuple, set)): 89 | return [cls.jsonify(x_i) for x_i in x] 90 | else: 91 | return x 92 | 93 | 94 | class MappedSequence(Sequence[T], Mapping[str, T]): 95 | def __init__(self, sequence: Sequence[T], mapping: Mapping[str, T]): 96 | self._sequence = sequence 97 | self._mapping = mapping 98 | 99 | def __getitem__(self, k) -> Union[T, Sequence[T]]: # type: ignore[override] 100 | if isinstance(k, (int, slice)): 101 | return self._sequence[k] 102 | elif isinstance(k, str): 103 | return self._mapping[k] 104 | else: 105 | raise TypeError("keys must be integers, slices, or strings") 106 | 107 | def __contains__(self, k) -> bool: 108 | if isinstance(k, str): 109 | return k in self._mapping 110 | else: 111 | return k in self._sequence 112 | 113 | def __iter__(self) -> Iterator[T]: 114 | return iter(self._sequence) 115 
| 116 | def __len__(self) -> int: 117 | return len(self._sequence) 118 | 119 | def keys(self): 120 | return self._mapping.keys() 121 | 122 | def values(self): 123 | return self._mapping.values() 124 | 125 | 126 | class StrEnum(str, Enum): 127 | def __str__(self) -> str: 128 | return self.value 129 | 130 | 131 | class IntEnum(int, Enum): 132 | def __str__(self) -> str: 133 | return str(self.value) 134 | 135 | 136 | class BasePage(BaseModel, Generic[T]): 137 | data: Tuple[T, ...] 138 | next_cursor: Optional[str] = None 139 | next: Optional[str] = None 140 | -------------------------------------------------------------------------------- /beaker/services/secret.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from ..data_model import * 4 | from ..exceptions import * 5 | from .service_client import ServiceClient 6 | 7 | 8 | class SecretClient(ServiceClient): 9 | """ 10 | Accessed via :data:`Beaker.secret `. 11 | """ 12 | 13 | def get(self, secret: str, workspace: Optional[Union[str, Workspace]] = None) -> Secret: 14 | """ 15 | Get metadata about a secret. 16 | 17 | :param secret: The name of the secret. 18 | :param workspace: The Beaker workspace ID, full name, or object. If not specified, 19 | :data:`Beaker.config.default_workspace ` is used. 20 | 21 | :raises SecretNotFound: If the secret doesn't exist. 22 | :raises WorkspaceNotSet: If neither ``workspace`` nor 23 | :data:`Beaker.config.default_workspace ` are set. 24 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 25 | :raises RequestException: Any other exception that can occur when contacting the 26 | Beaker server. 
27 | """ 28 | workspace = self.resolve_workspace(workspace, read_only_ok=True) 29 | return Secret.from_json( 30 | self.request( 31 | f"workspaces/{workspace.id}/secrets/{self.url_quote(secret)}", 32 | method="GET", 33 | exceptions_for_status={404: SecretNotFound(secret)}, 34 | ).json() 35 | ) 36 | 37 | def read( 38 | self, secret: Union[str, Secret], workspace: Optional[Union[str, Workspace]] = None 39 | ) -> str: 40 | """ 41 | Read the value of a secret. 42 | 43 | :param secret: The secret name or object. 44 | :param workspace: The Beaker workspace ID, full name, or object. If not specified, 45 | :data:`Beaker.config.default_workspace ` is used. 46 | 47 | :raises SecretNotFound: If the secret doesn't exist. 48 | :raises WorkspaceNotSet: If neither ``workspace`` nor 49 | :data:`Beaker.config.default_workspace ` are set. 50 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 51 | :raises RequestException: Any other exception that can occur when contacting the 52 | Beaker server. 53 | """ 54 | workspace = self.resolve_workspace(workspace, read_only_ok=True) 55 | name = secret.name if isinstance(secret, Secret) else secret 56 | return self.request( 57 | f"workspaces/{workspace.id}/secrets/{self.url_quote(name)}/value", 58 | method="GET", 59 | ).content.decode() 60 | 61 | def write( 62 | self, name: str, value: str, workspace: Optional[Union[str, Workspace]] = None 63 | ) -> Secret: 64 | """ 65 | Write a new secret or update an existing one. 66 | 67 | :param name: The name of the secret. 68 | :param value: The value to write to the secret. 69 | :param workspace: The Beaker workspace ID, full name, or object. If not specified, 70 | :data:`Beaker.config.default_workspace ` is used. 71 | 72 | :raises WorkspaceNotSet: If neither ``workspace`` nor 73 | :data:`Beaker.config.default_workspace ` are set. 74 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 
75 | :raises RequestException: Any other exception that can occur when contacting the 76 | Beaker server. 77 | """ 78 | workspace = self.resolve_workspace(workspace) 79 | return Secret.from_json( 80 | self.request( 81 | f"workspaces/{workspace.id}/secrets/{self.url_quote(name)}/value", 82 | method="PUT", 83 | data=value.encode(), 84 | ).json() 85 | ) 86 | 87 | def delete(self, secret: Union[str, Secret], workspace: Optional[Union[str, Workspace]] = None): 88 | """ 89 | Permanently delete a secret. 90 | 91 | :param secret: The secret name or object. 92 | :param workspace: The Beaker workspace ID, full name, or object. If not specified, 93 | :data:`Beaker.config.default_workspace ` is used. 94 | 95 | :raises SecretNotFound: If the secret doesn't exist. 96 | :raises WorkspaceNotSet: If neither ``workspace`` nor 97 | :data:`Beaker.config.default_workspace ` are set. 98 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 99 | :raises RequestException: Any other exception that can occur when contacting the 100 | Beaker server. 
101 | """ 102 | workspace = self.resolve_workspace(workspace) 103 | name = secret.name if isinstance(secret, Secret) else secret 104 | return self.request( 105 | f"workspaces/{workspace.id}/secrets/{self.url_quote(name)}", 106 | method="DELETE", 107 | exceptions_for_status={404: SecretNotFound(secret)}, 108 | ) 109 | -------------------------------------------------------------------------------- /tests/workspace_test.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import pytest 4 | 5 | from beaker import ( 6 | Account, 7 | Beaker, 8 | BudgetNotFound, 9 | Permission, 10 | Workspace, 11 | WorkspaceNotFound, 12 | WorkspaceWriteError, 13 | ) 14 | 15 | 16 | def test_ensure_workspace_invalid_name(client: Beaker): 17 | with pytest.raises(ValueError, match="Invalid name"): 18 | client.workspace.ensure("blah&&") 19 | 20 | 21 | def test_workspace_get(client: Beaker, workspace_name: str): 22 | workspace = client.workspace.get(workspace_name) 23 | # Now get by ID. 24 | client.workspace.get(workspace.id) 25 | # Now get by name without the org prefix. 
def test_workspace_move(
    client: Beaker, workspace_name: str, alternate_workspace_name: str, dataset_name: str
):
    """
    Moving a dataset from one workspace to another should update its workspace ref.

    The ``move`` call takes no target workspace here, so it moves the dataset into
    the client's default workspace (``workspace_name`` — see the ``client`` fixture).
    """
    # Create the dataset in the *alternate* workspace and confirm it landed there.
    dataset = client.dataset.create(dataset_name, workspace=alternate_workspace_name)
    assert dataset.workspace_ref.full_name == alternate_workspace_name
    # Move it to the default workspace, then re-fetch to verify the new location.
    client.workspace.move(dataset)
    assert client.dataset.get(dataset.id).workspace_ref.full_name == workspace_name
def test_resolve_budget(client: Beaker, budget_name: str, budget_id: str):
    """
    ``resolve_budget`` should map a known budget name to its ID and raise
    ``BudgetNotFound`` for a name that doesn't exist.
    """
    assert client.workspace.resolve_budget(budget_name) == budget_id
    # "ai2/foo-bar" is assumed not to be a real budget — TODO confirm it stays that way.
    with pytest.raises(BudgetNotFound):
        client.workspace.resolve_budget("ai2/foo-bar")
#
# Make the repository root importable so autodoc can import the `beaker` package.
sys.path.insert(0, os.path.abspath("../../"))

from beaker.version import VERSION, VERSION_SHORT  # noqa: E402

# -- Project information -----------------------------------------------------

project = "beaker-py"
copyright = f"{datetime.today().year}, Allen Institute for Artificial Intelligence"
author = "Allen Institute for Artificial Intelligence"
# Short X.Y version vs. full release string, both derived from the package itself.
version = VERSION_SHORT
release = VERSION


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.napoleon",
    "myst_parser",
    "sphinx.ext.intersphinx",
    "sphinx.ext.viewcode",
    "sphinx_copybutton",
    "sphinx_autodoc_typehints",
    "sphinx_inline_tabs",
]

# Tell myst-parser to assign header anchors for h1-h3.
myst_heading_anchors = 3

# Silence myst's warnings about non-consecutive header levels.
suppress_warnings = ["myst.header"]

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build"]

source_suffix = [".rst", ".md"]

# Cross-reference targets in external projects' docs (via sphinx.ext.intersphinx).
intersphinx_mapping = {
    "python": ("https://docs.python.org/3", None),
    "docker": ("https://docker-py.readthedocs.io/en/stable/", None),
    "requests": ("https://requests.readthedocs.io/en/stable/", None),
}

# By default, sort documented members by type within classes and modules.
69 | autodoc_member_order = "bysource" 70 | autodoc_default_options = {"show-inheritance": True, "undoc-members": True} 71 | 72 | # Include default values when documenting parameter types. 73 | typehints_defaults = "comma" 74 | 75 | copybutton_prompt_text = r">>> |\.\.\. " 76 | copybutton_prompt_is_regexp = True 77 | 78 | # -- Options for HTML output ------------------------------------------------- 79 | 80 | # The theme to use for HTML and HTML Help pages. See the documentation for 81 | # a list of builtin themes. 82 | # 83 | html_theme = "furo" 84 | 85 | html_title = f"beaker-py v{VERSION}" 86 | 87 | # Add any paths that contain custom static files (such as style sheets) here, 88 | # relative to this directory. They are copied after the builtin static files, 89 | # so a file named "default.css" will overwrite the builtin "default.css". 90 | html_static_path = ["_static"] 91 | 92 | html_css_files = ["css/custom.css"] 93 | 94 | html_favicon = "_static/favicon.ico" 95 | 96 | html_theme_options = { 97 | "light_logo": "beaker-500px-transparent.png", 98 | "dark_logo": "beaker-500px-transparent.png", 99 | "footer_icons": [ 100 | { 101 | "name": "GitHub", 102 | "url": "https://github.com/allenai/beaker-py", 103 | "html": """ 104 | 105 | 106 | 107 | """, # noqa: E501 108 | "class": "", 109 | }, 110 | ], 111 | "announcement": "Important! 
def autodoc_skip_member(app, what, name, obj, skip, options):
    """
    ``autodoc-skip-member`` hook that hides Pydantic-specific attributes.

    Returns ``True`` to skip a known Pydantic internal attribute, or ``None``
    to defer to autodoc's default behavior for everything else.
    """
    # Only the member name matters for this decision.
    del app, what, obj, skip, options
    pydantic_internals = ("model_config", "model_fields", "model_computed_fields")
    if name in pydantic_internals:
        return True
    return None
@pytest.fixture()
def client(workspace_name):
    """
    A ``Beaker`` client for the test workspace, configured from the environment
    (requires a valid Beaker token — see ``Beaker.from_env``).
    """
    beaker_client = Beaker.from_env(
        session=True, default_workspace=workspace_name, default_org="ai2"
    )
    return beaker_client
@pytest.fixture()
def archived_workspace(client: Beaker, archived_workspace_name: str) -> Workspace:
    """
    The dedicated archived test workspace, archiving it first if needed.
    """
    # `ensure` creates the workspace if it doesn't already exist.
    workspace = client.workspace.ensure(archived_workspace_name)
    if not workspace.archived:
        # Archive it and return the updated workspace object.
        return client.workspace.archive(archived_workspace_name)
    else:
        return workspace
    def test_dataset_basics(self, client: Beaker, dataset_name: str, alternate_dataset_name: str):
        """
        End-to-end check of dataset create / stream / size / rename / delete,
        both with and without an explicit budget.
        """
        # Test create dataset without budget parameter
        dataset = client.dataset.create(
            dataset_name,
            self.file_a.name,
            self.file_b.name,
            commit=True,
            description="Testing dataset",
        )
        assert dataset.name == dataset_name

        # Stream the whole thing at once.
        contents = b"".join(list(client.dataset.stream_file(dataset, Path(self.file_a.name).name)))
        assert contents == self.file_a_contents

        # Stream just part of the file.
        contents = b"".join(
            list(client.dataset.stream_file(dataset, Path(self.file_a.name).name, offset=5))
        )
        assert contents == self.file_a_contents[5:]

        # Calculate the size.
        # 20 bytes total: two 10-byte files uploaded in setup_method.
        assert client.dataset.size(dataset) == 20

        # Rename it.
        dataset = client.dataset.rename(dataset, alternate_dataset_name)
        assert dataset.name == alternate_dataset_name

        # Test with dataset creation with specified budget parameter
        dataset_with_budget = client.dataset.create(
            alternate_dataset_name + "-budget",
            self.file_a.name,
            budget="ai2/compute",
            commit=True,
            description="Testing dataset with budget",
        )
        assert dataset_with_budget.name == alternate_dataset_name + "-budget"

        contents = b"".join(
            list(client.dataset.stream_file(dataset_with_budget, Path(self.file_a.name).name))
        )
        assert contents == self.file_a_contents

        # Clean up the extra dataset created for the budget case.
        client.dataset.delete(dataset_with_budget)
class TestManyFilesDataset:
    """
    Upload a directory of many small files plus one standalone file, then verify
    listing and downloading, with and without a target directory in the dataset.
    """

    @pytest.mark.parametrize(
        "target",
        (pytest.param("target_dir", id="target dir"), pytest.param(None, id="no target dir")),
    )
    def test_many_files_dataset(
        self, client: Beaker, dataset_name: str, tmp_path: Path, target: Optional[str]
    ):
        # Create the local sources.
        dir_to_upload = tmp_path / "dataset_dir"
        dir_to_upload.mkdir()
        for i in range(100):
            (dir_to_upload / f"file{i}.txt").open("w").write(str(i))
        file_to_upload = tmp_path / "dataset_file.txt"
        file_to_upload.open("w").write("Hello, World!")

        # Create the dataset.
        dataset = client.dataset.create(dataset_name, dir_to_upload, file_to_upload, target=target)

        # List files in the dataset.
        # 100 numbered files + 1 standalone file.
        files = list(client.dataset.ls(dataset))
        assert len(files) == 101
        for file_info in files:
            if target is not None:
                assert file_info.path.startswith(target)
            assert file_info.path.endswith(".txt")

        # Download the dataset.
        download_dir = tmp_path / "download"
        client.dataset.fetch(dataset, target=download_dir)

        # When a target was given, files land under that subdirectory.
        base_dir = (download_dir / target) if target is not None else download_dir
        for i in range(100):
            downloaded = base_dir / f"file{i}.txt"
            assert downloaded.is_file()
            assert downloaded.open("r").read() == str(i)
        assert (base_dir / "dataset_file.txt").is_file()
        assert (base_dir / "dataset_file.txt").open("r").read() == "Hello, World!"
159 | -------------------------------------------------------------------------------- /beaker/services/organization.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | from ..data_model import * 4 | from ..exceptions import * 5 | from .service_client import ServiceClient 6 | 7 | 8 | class OrganizationClient(ServiceClient): 9 | """ 10 | Accessed via :data:`Beaker.organization `. 11 | """ 12 | 13 | def get(self, org: Optional[str] = None) -> Organization: 14 | """ 15 | Get information about an organization. 16 | 17 | :param org: The organization name or ID. If not specified, 18 | :data:`Beaker.config.default_org ` is used. 19 | 20 | :raises OrganizationNotFound: If the organization doesn't exist. 21 | :raises OrganizationNotSet: If neither ``org`` nor 22 | :data:`Beaker.config.default_org ` are set. 23 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 24 | :raises RequestException: Any other exception that can occur when contacting the 25 | Beaker server. 26 | """ 27 | org = org or self.config.default_org 28 | if org is None: 29 | raise OrganizationNotSet("'org' argument required since default org not set") 30 | 31 | return Organization.from_json( 32 | self.request( 33 | f"orgs/{self.url_quote(org)}", 34 | method="GET", 35 | exceptions_for_status={404: OrganizationNotFound(org)}, 36 | ).json() 37 | ) 38 | 39 | def add_member( 40 | self, account: Union[str, Account], org: Optional[Union[str, Organization]] = None 41 | ) -> OrganizationMember: 42 | """ 43 | Add an account to an organization. 44 | 45 | :param account: The account name or object. 46 | :param org: The organization name or object. If not specified, 47 | :data:`Beaker.config.default_org ` is used. 48 | 49 | :raises OrganizationNotFound: If the organization doesn't exist. 50 | :raises OrganizationNotSet: If neither ``org`` nor 51 | :data:`Beaker.config.default_org ` are set. 
52 | :raises AccountNotFound: If the account doesn't exist. 53 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 54 | :raises RequestException: Any other exception that can occur when contacting the 55 | Beaker server. 56 | """ 57 | org = self.resolve_org(org) 58 | account_name = account if isinstance(account, str) else account.name 59 | self.request( 60 | f"orgs/{self.url_quote(org.name)}/members/{account_name}", 61 | method="PUT", 62 | exceptions_for_status={404: AccountNotFound(account_name)}, 63 | ) 64 | return self.get_member(account_name, org=org) 65 | 66 | def get_member( 67 | self, account: Union[str, Account], org: Optional[Union[str, Organization]] = None 68 | ) -> OrganizationMember: 69 | """ 70 | Get information about an organization member. 71 | 72 | :param account: The account name or object. 73 | :param org: The organization name or object. If not specified, 74 | :data:`Beaker.config.default_org ` is used. 75 | 76 | :raises OrganizationNotFound: If the organization doesn't exist. 77 | :raises OrganizationNotSet: If neither ``org`` nor 78 | :data:`Beaker.config.default_org ` are set. 79 | :raises AccountNotFound: If the account doesn't exist or isn't a member of the org. 80 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 81 | :raises RequestException: Any other exception that can occur when contacting the 82 | Beaker server. 83 | """ 84 | org = self.resolve_org(org) 85 | account_name = account if isinstance(account, str) else account.name 86 | return OrganizationMember.from_json( 87 | self.request( 88 | f"orgs/{self.url_quote(org.name)}/members/{account_name}", 89 | method="GET", 90 | exceptions_for_status={404: AccountNotFound(account_name)}, 91 | ).json() 92 | ) 93 | 94 | def list_members(self, org: Optional[Union[str, Organization]] = None) -> List[Account]: 95 | """ 96 | List members of an organization. 97 | 98 | :param org: The organization name or object. 
If not specified, 99 | :data:`Beaker.config.default_org ` is used. 100 | 101 | :raises OrganizationNotFound: If the organization doesn't exist. 102 | :raises OrganizationNotSet: If neither ``org`` nor 103 | :data:`Beaker.config.default_org ` are set. 104 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 105 | :raises RequestException: Any other exception that can occur when contacting the 106 | Beaker server. 107 | """ 108 | org = self.resolve_org(org) 109 | return [ 110 | Account.from_json(d) 111 | for d in self.request( 112 | f"orgs/{self.url_quote(org.name)}/members", 113 | method="GET", 114 | exceptions_for_status={404: OrganizationNotFound(org.name)}, 115 | ).json()["data"] 116 | ] 117 | 118 | def remove_member( 119 | self, account: Union[str, Account], org: Optional[Union[str, Organization]] = None 120 | ): 121 | """ 122 | Remove a member from an organization. 123 | 124 | :param account: The account name or object. 125 | :param org: The organization name or object. If not specified, 126 | :data:`Beaker.config.default_org ` is used. 127 | 128 | :raises OrganizationNotFound: If the organization doesn't exist. 129 | :raises OrganizationNotSet: If neither ``org`` nor 130 | :data:`Beaker.config.default_org ` are set. 131 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 132 | :raises RequestException: Any other exception that can occur when contacting the 133 | Beaker server. 
134 | """ 135 | org = self.resolve_org(org) 136 | account_name = account if isinstance(account, str) else account.name 137 | self.request( 138 | f"orgs/{self.url_quote(org.name)}/members/{account_name}", 139 | method="DELETE", 140 | exceptions_for_status={404: AccountNotFound(account_name)}, 141 | ) 142 | -------------------------------------------------------------------------------- /beaker/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import warnings 5 | from dataclasses import asdict, dataclass, fields 6 | from pathlib import Path 7 | from typing import ClassVar, Optional, Set 8 | 9 | import yaml 10 | 11 | from .exceptions import ConfigurationError 12 | 13 | DEFAULT_CONFIG_LOCATION: Optional[Path] = None 14 | DEFAULT_INTERNAL_CONFIG_LOCATION: Optional[Path] = None 15 | try: 16 | DEFAULT_CONFIG_LOCATION = Path.home() / ".beaker" / "config.yml" 17 | DEFAULT_INTERNAL_CONFIG_LOCATION = Path.home() / ".beaker" / ".beaker-py.json" 18 | except RuntimeError: 19 | # Can't locate home directory. 20 | pass 21 | 22 | 23 | __all__ = ["Config"] 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | @dataclass 29 | class Config: 30 | user_token: str 31 | """ 32 | Beaker user token that can be obtained from 33 | `beaker.org `_. 34 | """ 35 | 36 | agent_address: str = "https://beaker.org" 37 | """ 38 | The address of the Beaker server. 39 | """ 40 | 41 | default_org: Optional[str] = "ai2" 42 | """ 43 | Default Beaker organization to use. 44 | """ 45 | 46 | default_workspace: Optional[str] = None 47 | """ 48 | Default Beaker workspace to use. 49 | """ 50 | 51 | default_image: Optional[str] = None 52 | """ 53 | The default image used for interactive sessions. 
54 | """ 55 | 56 | ADDRESS_KEY: ClassVar[str] = "BEAKER_ADDR" 57 | CONFIG_PATH_KEY: ClassVar[str] = "BEAKER_CONFIG" 58 | TOKEN_KEY: ClassVar[str] = "BEAKER_TOKEN" 59 | IGNORE_FIELDS: ClassVar[Set[str]] = {"updater_timestamp", "updater_message"} 60 | 61 | @property 62 | def rpc_address(self) -> str: 63 | # TODO: hard-coded for now since this isn't part of the Beaker YAML configs. 64 | return "beaker.org:443" 65 | 66 | def __str__(self) -> str: 67 | fields_str = "user_token=***, " + ", ".join( 68 | [f"{f.name}={getattr(self, f.name)}" for f in fields(self) if f.name != "user_token"] 69 | ) 70 | return f"{self.__class__.__name__}({fields_str})" 71 | 72 | @classmethod 73 | def from_env(cls, **overrides) -> "Config": 74 | """ 75 | Initialize a config from environment variables or a local config file if one 76 | can be found. 77 | 78 | .. note:: 79 | Environment variables take precedence over values in the config file. 80 | 81 | """ 82 | config: Config 83 | 84 | path = cls.find_config() 85 | if path is not None: 86 | config = cls.from_path(path) 87 | if cls.TOKEN_KEY in os.environ: 88 | config.user_token = os.environ[cls.TOKEN_KEY] 89 | elif cls.TOKEN_KEY in os.environ: 90 | config = cls( 91 | user_token=os.environ[cls.TOKEN_KEY], 92 | ) 93 | elif "user_token" in overrides: 94 | config = cls(user_token=overrides["user_token"]) 95 | else: 96 | raise ConfigurationError( 97 | f"Failed to find config file or environment variable '{cls.TOKEN_KEY}'" 98 | ) 99 | 100 | # Override with environment variables. 101 | if cls.ADDRESS_KEY in os.environ: 102 | config.agent_address = os.environ[cls.ADDRESS_KEY] 103 | 104 | # Override with any arguments passed to this method. 
105 | for name, value in overrides.items(): 106 | if hasattr(config, name): 107 | setattr(config, name, value) 108 | else: 109 | raise ConfigurationError(f"Beaker config has no attribute '{name}'") 110 | 111 | if not config.user_token: 112 | raise ConfigurationError("Invalid Beaker user token, token is empty") 113 | 114 | return config 115 | 116 | @classmethod 117 | def from_path(cls, path: Path) -> "Config": 118 | """ 119 | Initialize a config from a local config file. 120 | """ 121 | with open(path) as config_file: 122 | logger.debug("Loading beaker config from '%s'", path) 123 | field_names = {f.name for f in fields(cls)} 124 | data = yaml.load(config_file, Loader=yaml.SafeLoader) 125 | for key in list(data.keys()): 126 | if key in cls.IGNORE_FIELDS: 127 | data.pop(key) 128 | continue 129 | value = data[key] 130 | if key not in field_names: 131 | del data[key] 132 | warnings.warn( 133 | f"Unknown field '{key}' found in config '{path}'. " 134 | f"If this is a bug, please report it at https://github.com/allenai/beaker-py/issues/new/", 135 | RuntimeWarning, 136 | ) 137 | elif isinstance(value, str) and value == "": 138 | # Replace empty strings with `None` 139 | data[key] = None 140 | return cls(**data) 141 | 142 | def save(self, path: Optional[Path] = None): 143 | """ 144 | Save the config to the given path. 
145 | """ 146 | if path is None: 147 | if self.CONFIG_PATH_KEY in os.environ: 148 | path = Path(os.environ[self.CONFIG_PATH_KEY]) 149 | elif DEFAULT_CONFIG_LOCATION is not None: 150 | path = DEFAULT_CONFIG_LOCATION 151 | if path is None: 152 | raise ValueError("param 'path' is required") 153 | path.parent.mkdir(parents=True, exist_ok=True) 154 | with open(path, "w") as config_file: 155 | yaml.dump(asdict(self), config_file) 156 | 157 | @classmethod 158 | def find_config(cls) -> Optional[Path]: 159 | if cls.CONFIG_PATH_KEY in os.environ: 160 | path = Path(os.environ[cls.CONFIG_PATH_KEY]) 161 | if path.is_file(): 162 | return path 163 | elif DEFAULT_CONFIG_LOCATION is not None and DEFAULT_CONFIG_LOCATION.is_file(): 164 | return DEFAULT_CONFIG_LOCATION 165 | 166 | return None 167 | 168 | 169 | @dataclass 170 | class InternalConfig: 171 | version_checked: Optional[float] = None 172 | 173 | @classmethod 174 | def load(cls) -> Optional["InternalConfig"]: 175 | path = DEFAULT_INTERNAL_CONFIG_LOCATION 176 | if path is None: 177 | return None 178 | elif path.is_file(): 179 | with open(path, "r") as f: 180 | return cls(**json.load(f)) 181 | else: 182 | return cls() 183 | 184 | def save(self): 185 | path = DEFAULT_INTERNAL_CONFIG_LOCATION 186 | if path is None: 187 | return None 188 | else: 189 | path.parent.mkdir(exist_ok=True, parents=True) 190 | with open(path, "w") as f: 191 | json.dump(asdict(self), f) 192 | -------------------------------------------------------------------------------- /beaker/exceptions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Exceptions that can be raised by the :class:`~beaker.Beaker` client. 3 | 4 | .. tip:: 5 | All exceptions inherit from :class:`BeakerError` other than :exc:`HTTPError`, 6 | which is re-exported from :exc:`requests.exceptions.HTTPError`, 7 | and :exc:`ValidationError`, which is re-exported from `pydantic `_. 
8 | """ 9 | 10 | from __future__ import annotations 11 | 12 | from typing import TYPE_CHECKING, Optional 13 | 14 | from grpc import RpcError 15 | from pydantic import ValidationError # noqa: F401, re-imported here for convenience 16 | from requests.exceptions import ( # noqa: F401, re-imported here for convenience 17 | HTTPError, 18 | RequestException, 19 | ) 20 | 21 | if TYPE_CHECKING: 22 | from .data_model.experiment import Task 23 | from .data_model.job import Job 24 | 25 | ValidationError.__doc__ = """ 26 | Raised when data passed into a :mod:`DataModel ` is invalid. 27 | """ 28 | 29 | 30 | __all__ = [ 31 | "BeakerError", 32 | "ValidationError", 33 | "HTTPError", 34 | "RpcError", 35 | "RequestException", 36 | "BeakerPermissionsError", 37 | "NotFoundError", 38 | "AccountNotFound", 39 | "OrganizationNotFound", 40 | "OrganizationNotSet", 41 | "BudgetNotFound", 42 | "ConfigurationError", 43 | "ImageNotFound", 44 | "ImageConflict", 45 | "WorkspaceNotFound", 46 | "WorkspaceWriteError", 47 | "WorkspaceConflict", 48 | "ClusterNotFound", 49 | "ClusterConflict", 50 | "ExperimentNotFound", 51 | "ExperimentConflict", 52 | "DatasetConflict", 53 | "DatasetNotFound", 54 | "UnexpectedEOFError", 55 | "JobNotFound", 56 | "WorkspaceNotSet", 57 | "NodeNotFound", 58 | "DatasetWriteError", 59 | "DatasetReadError", 60 | "SecretNotFound", 61 | "GroupConflict", 62 | "GroupNotFound", 63 | "DuplicateJobError", 64 | "DuplicateExperimentError", 65 | "TaskNotFound", 66 | "ChecksumFailedError", 67 | "TaskStoppedError", 68 | "JobFailedError", 69 | "JobTimeoutError", 70 | "ExperimentSpecError", 71 | "ThreadCanceledError", 72 | ] 73 | 74 | 75 | class BeakerError(Exception): 76 | """ 77 | Base class for all Beaker errors other than :exc:`HTTPError`, which is re-exported 78 | from :exc:`requests.exceptions.HTTPError`, and :exc:`ValidationError`, which is 79 | re-exported from `pydantic `_. 
80 | """ 81 | 82 | 83 | class BeakerPermissionsError(BeakerError): 84 | """ 85 | Raised when a user doesn't have sufficient permissions to perform an action. 86 | """ 87 | 88 | 89 | class NotFoundError(BeakerError): 90 | """ 91 | Base class for all "not found" error types. 92 | """ 93 | 94 | 95 | class AccountNotFound(NotFoundError): 96 | pass 97 | 98 | 99 | class OrganizationNotFound(NotFoundError): 100 | """ 101 | Raised when a specified organization doesn't exist. 102 | """ 103 | 104 | 105 | class OrganizationNotSet(BeakerError): 106 | """ 107 | Raised when an identifying doesn't start with an organization name and 108 | :data:`Config.default_org ` is not set. 109 | """ 110 | 111 | 112 | class BudgetNotFound(NotFoundError): 113 | """ 114 | Raised when a specified budget doesn't exist. 115 | """ 116 | 117 | 118 | class ConfigurationError(BeakerError): 119 | """ 120 | Raised when the :class:`~beaker.Config` fails to instantiate. 121 | """ 122 | 123 | 124 | class ImageNotFound(NotFoundError): 125 | pass 126 | 127 | 128 | class ImageConflict(BeakerError): 129 | """ 130 | Raised when attempting to create/rename an image if an image by that name already exists. 131 | """ 132 | 133 | 134 | class WorkspaceNotFound(NotFoundError): 135 | pass 136 | 137 | 138 | class WorkspaceWriteError(BeakerError): 139 | """ 140 | Raised when attempting to modify or add to a workspace that's been archived. 141 | """ 142 | 143 | 144 | class WorkspaceConflict(BeakerError): 145 | """ 146 | Raised when attempting to create/rename a workspace if a workspace by that name already exists. 147 | """ 148 | 149 | 150 | class ClusterNotFound(NotFoundError): 151 | pass 152 | 153 | 154 | class ClusterConflict(BeakerError): 155 | """ 156 | Raised when attempting to create a cluster if a cluster by that name already exists. 
157 | """ 158 | 159 | 160 | class ExperimentNotFound(NotFoundError): 161 | pass 162 | 163 | 164 | class ExperimentConflict(BeakerError): 165 | """ 166 | Raised when attempting to create/rename/stop an experiment that already exists or is already stopped. 167 | """ 168 | 169 | 170 | class DatasetConflict(BeakerError): 171 | """ 172 | Raised when attempting to create/rename a dataset if a dataset by that name already exists. 173 | """ 174 | 175 | 176 | class DatasetNotFound(NotFoundError): 177 | pass 178 | 179 | 180 | class UnexpectedEOFError(BeakerError): 181 | """ 182 | Raised when creating a dataset when an empty source file is encountered. 183 | """ 184 | 185 | 186 | class JobNotFound(NotFoundError): 187 | pass 188 | 189 | 190 | class WorkspaceNotSet(BeakerError): 191 | """ 192 | Raised when workspace argument is not provided and there is no default workspace set. 193 | """ 194 | 195 | 196 | class NodeNotFound(NotFoundError): 197 | pass 198 | 199 | 200 | class DatasetWriteError(BeakerError): 201 | """ 202 | Raised when a write operation on a dataset fails because the dataset has already been committed. 203 | """ 204 | 205 | 206 | class DatasetReadError(BeakerError): 207 | """ 208 | Raised when a read operation on a dataset fails because the dataset hasn't been committed yet, 209 | or the :data:`~beaker.data_model.Dataset.storage` hasn't been set for some other reason. 210 | """ 211 | 212 | 213 | class SecretNotFound(NotFoundError): 214 | pass 215 | 216 | 217 | class GroupConflict(BeakerError): 218 | """ 219 | Raised when attempting to create/rename a group if a group by that name already exists. 220 | """ 221 | 222 | 223 | class GroupNotFound(NotFoundError): 224 | pass 225 | 226 | 227 | class DuplicateJobError(BeakerError): 228 | """ 229 | Raised when duplicate jobs are passed into a method that expects unique jobs. 
230 | """ 231 | 232 | 233 | class DuplicateExperimentError(BeakerError): 234 | """ 235 | Raised when duplicate experiments are passed into a method that expects unique experiments. 236 | """ 237 | 238 | 239 | class TaskNotFound(NotFoundError): 240 | pass 241 | 242 | 243 | class ChecksumFailedError(BeakerError): 244 | """ 245 | Raised when a downloaded file from a Beaker dataset is corrupted. 246 | """ 247 | 248 | 249 | class TaskStoppedError(BeakerError): 250 | def __init__(self, msg: Optional[str] = None, task: Optional[Task] = None): 251 | super().__init__(msg) 252 | self.task = task 253 | 254 | 255 | class JobFailedError(BeakerError): 256 | def __init__(self, msg: Optional[str] = None, job: Optional[Job] = None): 257 | super().__init__(msg) 258 | self.job = job 259 | 260 | 261 | class JobTimeoutError(BeakerError, TimeoutError): 262 | pass 263 | 264 | 265 | class ExperimentSpecError(BeakerError): 266 | pass 267 | 268 | 269 | class ThreadCanceledError(BeakerError): 270 | pass 271 | -------------------------------------------------------------------------------- /beaker/util.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import re 3 | import time 4 | import warnings 5 | from collections import OrderedDict 6 | from datetime import datetime, timedelta, timezone 7 | from functools import wraps 8 | from pathlib import Path 9 | from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TypeVar, Union 10 | 11 | from .aliases import PathOrStr 12 | from .exceptions import RequestException 13 | 14 | BUG_REPORT_URL = ( 15 | "https://github.com/allenai/beaker-py/issues/new?assignees=&labels=bug&template=bug_report.yml" 16 | ) 17 | 18 | _VALIDATION_WARNINGS_ISSUED: Set[Tuple[str, str]] = set() 19 | 20 | 21 | def issue_data_model_warning(cls: Type, key: str, value: Any): 22 | warn_about = (cls.__name__, key) 23 | if warn_about not in _VALIDATION_WARNINGS_ISSUED: 24 | _VALIDATION_WARNINGS_ISSUED.add(warn_about) 
def issue_data_model_warning(cls: Type, key: str, value: Any):
    """
    Emit a one-time ``RuntimeWarning`` about an unknown field on a data model.

    Each ``(class name, field name)`` pair is only warned about once per process,
    tracked via the module-level ``_VALIDATION_WARNINGS_ISSUED`` set.
    """
    warn_about = (cls.__name__, key)
    if warn_about not in _VALIDATION_WARNINGS_ISSUED:
        _VALIDATION_WARNINGS_ISSUED.add(warn_about)
        warnings.warn(
            f"Found unknown field '{key}: {value}' for data model '{cls.__name__}'. "
            "This may be a newly added field that hasn't been defined in beaker-py yet. "
            "Please submit an issue report about this here:\n"
            f"{BUG_REPORT_URL}",
            RuntimeWarning,
        )


def to_lower_camel(s: str) -> str:
    """
    Convert a snake-case string into lower camel case.

    The substring ``Id`` is normalized to ``ID`` to match the Beaker API's
    JSON field naming (e.g. ``task_id`` -> ``taskID``).
    """
    parts = s.split("_")
    out = parts[0] + "".join([p.title() for p in parts[1:]])
    out = re.sub(r"(^|[a-z0-9])Id($|[A-Z0-9])", r"\g<1>ID\g<2>", out)
    return out


def to_snake_case(s: str) -> str:
    """
    Convert a lower camel case string into snake case.

    This is the inverse of :func:`to_lower_camel`, including the ``ID`` -> ``id``
    normalization.
    """
    if s.islower():
        # Fast path: already snake case (or at least contains no uppercase).
        return s
    s = re.sub(r"(^|[a-z0-9])ID", r"\g<1>Id", s)
    parts = []
    for c in s:
        if c.isupper():
            parts.append("_")
        parts.append(c.lower())
    return "".join(parts)


def path_is_relative_to(path: Path, other: "PathOrStr") -> bool:
    """
    This is copied from :meth:`pathlib.PurePath.is_relative_to` to support older Python
    versions (before 3.9, when this method was introduced).

    .. note::
        The ``PathOrStr`` annotation is a forward reference (PEP 484) so this
        function doesn't require the alias to be resolvable at definition time.
    """
    try:
        path.relative_to(other)
        return True
    except ValueError:
        return False


T = TypeVar("T")

# Small process-wide TTL cache used by `cached_property`, keyed on
# (property qualname, repr of the client config). Bounded FIFO eviction.
_property_cache: "OrderedDict[Tuple[str, str], Tuple[float, Any]]" = OrderedDict()
_property_cache_max_size = 50


def cached_property(ttl: float = 60):
    """
    This is used to create a cached property on a :class:`~beaker.services.service_client.ServiceClient`
    subclass.

    :param ttl: The time-to-live in seconds. The cached value will be evicted from the cache
        after this many seconds to ensure it stays fresh.

    See :meth:`~beaker.services.account.AccountClient.name`, for example.
    """

    def ttl_cached_property(prop) -> property:
        @property  # type: ignore[misc]
        def prop_with_cache(self):
            key = (prop.__qualname__, repr(self.config))
            cached = _property_cache.get(key)
            if cached is not None:
                time_cached, value = cached
                if time.monotonic() - time_cached <= ttl:
                    return value
            value = prop(self)
            _property_cache[key] = (time.monotonic(), value)
            # Evict oldest entries to keep the cache bounded.
            while len(_property_cache) > _property_cache_max_size:
                _property_cache.popitem(last=False)
            return value

        return prop_with_cache  # type: ignore[return-value]

    return ttl_cached_property


def format_since(since: Union[datetime, timedelta, str]) -> str:
    """
    Format a "since" argument as the UTC timestamp string the Beaker API expects.

    A ``timedelta`` is interpreted as "that long ago from now"; a plain string is
    passed through unchanged.

    .. note::
        NOTE(review): a *naive* ``datetime`` is formatted with a ``Z`` suffix
        without conversion, i.e. it is assumed to already be UTC — confirm callers
        always pass aware or UTC-naive datetimes.
    """
    if isinstance(since, datetime):
        if since.tzinfo is not None:
            # Convert to UTC.
            since = since.astimezone(timezone.utc)
        return since.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
    elif isinstance(since, timedelta):
        return format_since(datetime.now(tz=timezone.utc) - abs(since))
    else:
        return since


def parse_duration(dur: str) -> int:
    """
    Parse a duration string (e.g. ``"500ms"``, ``"1.5m"``, ``"2h"``) into nanoseconds.

    A bare number is interpreted as seconds. Both the ASCII ``"us"`` and the
    ``"µs"`` spellings of microseconds are accepted.

    :raises ValueError: If the string can't be parsed as a duration.
    """
    dur_normalized = dur.replace(" ", "").lower()
    # NOTE: 'µ' must be included in the unit character class explicitly — it is
    # not matched by '[a-z]', which previously made the microseconds branch
    # unreachable ("3µs" raised ValueError).
    match = re.match(r"^([0-9.e-]+)([a-zµ]*)$", dur_normalized)
    if not match:
        raise ValueError(f"invalid duration string '{dur}'")

    value_str, unit = match.group(1), match.group(2)
    try:
        value = float(value_str)
    except ValueError:
        raise ValueError(f"invalid duration string '{dur}'")

    if not unit:
        # assume seconds
        unit = "s"

    if unit in ("ns", "nanosecond", "nanoseconds"):
        # nanoseconds
        return int(value)
    elif unit in ("us", "µs", "microsecond", "microseconds"):
        # microseconds
        return int(value * 1_000)
    elif unit in ("ms", "millisecond", "milliseconds"):
        # milliseconds
        return int(value * 1_000_000)
    elif unit in ("s", "sec", "second", "seconds"):
        # seconds
        return int(value * 1_000_000_000)
    elif unit in ("m", "min", "minute", "minutes"):
        # minutes
        return int(value * 60_000_000_000)
    elif unit in ("h", "hr", "hour", "hours"):
        # hours
        return int(value * 3_600_000_000_000)
    else:
        raise ValueError(f"invalid duration string '{dur}'")


# Matches an RFC-3339-style UTC timestamp prefix on a log line. The '.' before
# the fractional seconds is escaped so arbitrary separator characters are no
# longer accepted.
TIMESTAMP_RE = re.compile(rb"^([0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}\.[0-9]+Z)(.*)$")


def split_timestamp(s: bytes) -> Optional[str]:
    """
    Return the leading UTC timestamp of a log line as a string, or ``None`` if
    the line doesn't start with one.
    """
    match = TIMESTAMP_RE.match(s)
    if match is not None:
        return match.group(1).decode()
    else:
        return None


def log_and_wait(retries_so_far: int, err: Exception) -> None:
    """
    Log a failed request and sleep with exponential backoff (capped at
    ``Beaker.BACKOFF_MAX``) before the next retry.
    """
    from .client import Beaker

    retry_in = min(Beaker.BACKOFF_FACTOR * (2**retries_so_far), Beaker.BACKOFF_MAX)
    Beaker.logger.debug("Request failed with: %s\nRetrying in %d seconds...", err, retry_in)
    time.sleep(retry_in)
def retriable(
    on_failure: Optional[Callable[..., None]] = None,
    recoverable_errors: Tuple[Type[Exception], ...] = (RequestException,),
):
    """
    Use to make a service client method more robust by allowing retries.
    """

    def parametrize_decorator(func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        def retriable_method(*args, **kwargs) -> T:
            from .client import Beaker

            attempt = 0
            while True:
                try:
                    return func(*args, **kwargs)
                except recoverable_errors as err:
                    if attempt >= Beaker.MAX_RETRIES:
                        raise
                    if on_failure is not None:
                        on_failure()
                    log_and_wait(attempt, err)
                    attempt += 1

        return retriable_method

    return parametrize_decorator


def format_cursor(cursor: int) -> str:
    """
    Encode a non-negative integer cursor as the URL-safe base64 form of its
    8-byte little-endian representation, as expected by paginated API endpoints.
    """
    if cursor < 0:
        raise ValueError("cursor must be >= 0")

    raw = cursor.to_bytes(8, "little")
    return base64.urlsafe_b64encode(raw).decode()


def protobuf_to_json_dict(data) -> Dict[str, Any]:
    """
    Convert a protobuf message into a plain JSON-compatible dictionary.
    """
    from google.protobuf.json_format import MessageToDict

    return MessageToDict(data)


# ---------------------------------------------------------------------------
# tests/experiment_test.py
# ---------------------------------------------------------------------------
import pytest

from beaker import (
    Beaker,
    ClusterNotFound,
    CurrentJobStatus,
    DataMount,
    DatasetNotFound,
    DataSource,
    ExperimentSpec,
    ImageNotFound,
    ImageSource,
    ResultSpec,
    SecretNotFound,
    TaskContext,
    TaskNotFound,
    TaskSpec,
)
docker_image="hello-world") 30 | ) 31 | assert workspace is None 32 | assert name is None 33 | assert spec is not None 34 | 35 | spec, name, workspace = client.experiment._parse_create_args( 36 | ExperimentSpec.new(budget="ai2/allennlp", docker_image="hello-world"), 37 | name="my-experiment", 38 | workspace="ai2/petew", 39 | ) 40 | assert workspace == "ai2/petew" 41 | assert name == "my-experiment" 42 | assert spec is not None 43 | 44 | spec, name, workspace = client.experiment._parse_create_args( 45 | name="my-experiment", 46 | spec=ExperimentSpec.new(budget="ai2/allennlp", docker_image="hello-world"), 47 | workspace="ai2/petew", 48 | ) 49 | assert workspace == "ai2/petew" 50 | assert name == "my-experiment" 51 | assert spec is not None 52 | 53 | spec, name, workspace = client.experiment._parse_create_args( 54 | "my-experiment", 55 | ExperimentSpec.new(budget="ai2/allennlp", docker_image="hello-world"), 56 | "ai2/petew", 57 | ) 58 | assert workspace == "ai2/petew" 59 | assert name == "my-experiment" 60 | assert spec is not None 61 | 62 | 63 | def test_experiment_get(client: Beaker, hello_world_experiment_id: str): 64 | exp = client.experiment.get(hello_world_experiment_id) 65 | assert exp.id == hello_world_experiment_id 66 | assert exp.jobs 67 | assert exp.jobs[0].status.current == CurrentJobStatus.finalized 68 | # Get with name. 69 | assert exp.name is not None 70 | client.experiment.get(exp.name) 71 | # Get with full name. 
72 | assert exp.full_name is not None 73 | client.experiment.get(exp.full_name) 74 | 75 | 76 | def test_experiment_tasks(client: Beaker, hello_world_experiment_id: str): 77 | tasks = client.experiment.tasks(hello_world_experiment_id) 78 | assert len(tasks) == 1 79 | 80 | 81 | def test_experiment_metrics_none(client: Beaker, hello_world_experiment_id: str): 82 | metrics = client.experiment.metrics(hello_world_experiment_id) 83 | assert metrics is None 84 | 85 | 86 | def test_experiment_metrics(client: Beaker, experiment_id_with_metrics: str): 87 | metrics = client.experiment.metrics(experiment_id_with_metrics) 88 | assert metrics is not None 89 | 90 | 91 | # experiment was deleted 92 | # def test_experiment_results(client, experiment_id_with_results: str): 93 | # results = client.experiment.results(experiment_id_with_results) 94 | # assert results is not None 95 | # assert client.dataset.size(results) > 0 96 | 97 | 98 | def test_experiment_empty_results(client: Beaker, hello_world_experiment_id: str): 99 | results = client.experiment.results(hello_world_experiment_id) 100 | assert results is None or (client.dataset.size(results) == 0) 101 | 102 | 103 | def test_experiment_spec(client: Beaker, hello_world_experiment_id: str): 104 | spec = client.experiment.spec(hello_world_experiment_id) 105 | assert isinstance(spec, ExperimentSpec) 106 | 107 | 108 | def test_create_experiment_image_not_found( 109 | client: Beaker, 110 | experiment_name: str, 111 | beaker_cluster_name: str, 112 | ): 113 | spec = ExperimentSpec( 114 | budget="ai2/allennlp", 115 | tasks=[ 116 | TaskSpec( 117 | name="main", 118 | image=ImageSource(beaker="does-not-exist"), 119 | context=TaskContext(cluster=beaker_cluster_name), 120 | result=ResultSpec(path="/unused"), 121 | ), 122 | ], 123 | ) 124 | with pytest.raises(ImageNotFound): 125 | client.experiment.create(experiment_name, spec) 126 | 127 | 128 | def test_create_experiment_dataset_not_found( 129 | client: Beaker, 130 | experiment_name: str, 131 
| beaker_cluster_name: str, 132 | ): 133 | spec = ExperimentSpec( 134 | budget="ai2/allennlp", 135 | tasks=[ 136 | TaskSpec( 137 | name="main", 138 | image=ImageSource(docker="hello-world"), 139 | context=TaskContext(cluster=beaker_cluster_name), 140 | result=ResultSpec(path="/unused"), 141 | datasets=[ 142 | DataMount(source=DataSource(beaker="does-not-exist"), mount_path="/data") 143 | ], 144 | ), 145 | ], 146 | ) 147 | with pytest.raises(DatasetNotFound): 148 | client.experiment.create(experiment_name, spec) 149 | 150 | 151 | def test_create_experiment_secret_not_found( 152 | client: Beaker, 153 | experiment_name: str, 154 | beaker_cluster_name: str, 155 | ): 156 | spec = ExperimentSpec( 157 | budget="ai2/allennlp", 158 | tasks=[ 159 | TaskSpec( 160 | name="main", 161 | image=ImageSource(docker="hello-world"), 162 | context=TaskContext(cluster=beaker_cluster_name), 163 | result=ResultSpec(path="/unused"), 164 | datasets=[ 165 | DataMount(source=DataSource(secret="does-not-exist"), mount_path="/data") 166 | ], 167 | ), 168 | ], 169 | ) 170 | with pytest.raises(SecretNotFound): 171 | client.experiment.create(experiment_name, spec) 172 | 173 | 174 | def test_create_experiment_result_not_found( 175 | client: Beaker, 176 | experiment_name: str, 177 | beaker_cluster_name: str, 178 | ): 179 | spec = ExperimentSpec( 180 | budget="ai2/allennlp", 181 | tasks=[ 182 | TaskSpec( 183 | name="main", 184 | image=ImageSource(docker="hello-world"), 185 | context=TaskContext(cluster=beaker_cluster_name), 186 | result=ResultSpec(path="/unused"), 187 | datasets=[ 188 | DataMount(source=DataSource(result="does-not-exist"), mount_path="/data") 189 | ], 190 | ), 191 | ], 192 | ) 193 | with pytest.raises(ValueError, match="does-not-exist"): 194 | client.experiment.create(experiment_name, spec) 195 | 196 | 197 | def test_create_experiment_cluster_not_found( 198 | client: Beaker, 199 | experiment_name: str, 200 | ): 201 | spec = ExperimentSpec( 202 | budget="ai2/allennlp", 203 | tasks=[ 
204 | TaskSpec( 205 | name="main", 206 | image=ImageSource(docker="hello-world"), 207 | context=TaskContext(cluster="does-not-exist"), 208 | result=ResultSpec(path="/unused"), 209 | ), 210 | ], 211 | ) 212 | with pytest.raises(ClusterNotFound): 213 | client.experiment.create(experiment_name, spec) 214 | 215 | 216 | def test_experiment_url(client: Beaker, hello_world_experiment_id: str): 217 | assert ( 218 | client.experiment.url(hello_world_experiment_id) 219 | == f"https://beaker.org/ex/{hello_world_experiment_id}" 220 | ) 221 | assert ( 222 | client.experiment.url(hello_world_experiment_id, "main") 223 | == f"https://beaker.org/ex/{hello_world_experiment_id}/tasks/01GRYY999VAT2QY75G89A826YS" 224 | ) 225 | with pytest.raises(TaskNotFound, match="No task"): 226 | client.experiment.url(hello_world_experiment_id, "foo") 227 | -------------------------------------------------------------------------------- /beaker/data_model/dataset.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Optional, Tuple, Union 3 | from urllib.parse import urlparse 4 | 5 | from .account import Account 6 | from .base import BaseModel, BasePage, StrEnum, field_validator 7 | from .workspace import WorkspaceRef 8 | 9 | __all__ = [ 10 | "DatasetStorage", 11 | "DatasetSize", 12 | "Dataset", 13 | "DatasetInfo", 14 | "DatasetInfoPage", 15 | "Digest", 16 | "DigestHashAlgorithm", 17 | "FileInfo", 18 | "DatasetsPage", 19 | "DatasetSpec", 20 | "DatasetPatch", 21 | "DatasetSort", 22 | ] 23 | 24 | 25 | class DatasetStorage(BaseModel): 26 | id: str 27 | token: str 28 | token_expires: datetime 29 | address: Optional[str] = None 30 | url: Optional[str] = None 31 | urlv2: Optional[str] = None 32 | total_size: Optional[int] = None 33 | num_files: Optional[int] = None 34 | 35 | @field_validator("address") 36 | def _validate_address(cls, v: Optional[str]) -> Optional[str]: 37 | if v is not None and v.startswith("fh://"): 
class DatasetStorage(BaseModel):
    id: str
    token: str
    token_expires: datetime
    address: Optional[str] = None
    url: Optional[str] = None
    urlv2: Optional[str] = None
    total_size: Optional[int] = None
    num_files: Optional[int] = None

    @field_validator("address")
    def _validate_address(cls, v: Optional[str]) -> Optional[str]:
        if v is not None and v.startswith("fh://"):
            # HACK: fix prior to https://github.com/allenai/beaker/pull/2962
            return v.replace("fh://", "https://", 1)
        else:
            return v

    @property
    def scheme(self) -> Optional[str]:
        # Legacy fileheap storage has no 'urlv2'.
        return "fh" if self.urlv2 is None else urlparse(self.urlv2).scheme

    @property
    def base_url(self) -> str:
        if self.address is not None:
            return self.address
        elif self.urlv2 is not None:
            return f"https://{urlparse(self.urlv2).netloc}"
        else:
            raise ValueError("Missing field 'urlv2' or 'address'")


class DatasetSize(BaseModel):
    files: int
    bytes: int
    final: Optional[bool] = None
    bytes_human: Optional[str] = None


class Dataset(BaseModel):
    id: str
    name: Optional[str] = None
    full_name: Optional[str] = None
    description: Optional[str] = None
    author: Account
    created: datetime
    committed: Optional[datetime] = None
    workspace_ref: WorkspaceRef
    source_execution: Optional[str] = None
    storage: Optional[DatasetStorage] = None
    budget: Optional[str] = None

    @property
    def display_name(self) -> str:
        # Prefer the human-readable name; fall back to the opaque ID.
        return self.name if self.name is not None else self.id

    @property
    def workspace(self) -> WorkspaceRef:
        return self.workspace_ref

    @field_validator("committed")
    def _validate_datetime(cls, v: Optional[datetime]) -> Optional[datetime]:
        # The API encodes "not committed" as the zero datetime (year 1).
        if v is not None and v.year == 1:
            return None
        return v


class DigestHashAlgorithm(StrEnum):
    """
    Supported hash algorithms for file :class:`Digest`.
    """

    SHA256 = "SHA256"

    SHA512 = "SHA512"

    MD5 = "MD5"

    def hasher(self):
        """
        Get a :mod:`hasher <hashlib>` object for the given algorithm.

        .. seealso::
            :meth:`Digest.new_hasher()`.
        """
        import hashlib

        factories = {
            DigestHashAlgorithm.SHA256: hashlib.sha256,
            DigestHashAlgorithm.SHA512: hashlib.sha512,
            DigestHashAlgorithm.MD5: hashlib.md5,
        }
        factory = factories.get(self)
        if factory is None:
            raise NotImplementedError(f"hasher() not yet implemented for {str(self)}")
        return factory()


class Digest(BaseModel):
    """
    A digest is a checksum / hash of a files contents. These are used to verify
    the integrity of files downloaded from Beaker datasets.
    """

    value: str
    """
    The hex-encoded value of the digest.
    """

    algorithm: DigestHashAlgorithm
    """
    The algorithm used to create and verify the digest.
    """

    def __init__(self, *args, **kwargs):
        # Support positional construction from an encoded string
        # ("ALGO base64"), a bare hex string, or raw digest bytes, in addition
        # to the usual keyword form.
        if len(args) == 1 and "value" not in kwargs:
            arg = args[0]
            if isinstance(arg, str) and "algorithm" not in kwargs:
                # Assume 'arg' is the string-encoded form of a digest.
                kwargs = Digest.from_encoded(arg).model_dump()
            elif isinstance(arg, str):
                # Assume 'arg' is the hex-encoded hash.
                kwargs["value"] = arg
            elif isinstance(arg, bytes):
                # Assume 'arg' is raw bytes of the hash.
                kwargs = Digest.from_decoded(arg, **kwargs).model_dump()
        super().__init__(**kwargs)

    @field_validator("algorithm")
    def _validate_algorithm(cls, v: Union[str, DigestHashAlgorithm]) -> DigestHashAlgorithm:
        return DigestHashAlgorithm(v)

    def __str__(self) -> str:
        return self.encode()

    def __hash__(self):
        return hash(self.encode())

    @classmethod
    def from_encoded(cls, encoded: str) -> "Digest":
        """
        Initialize a digest from a string encoding of the form ``{ALGORITHM} {ENCODED_STRING}``,
        e.g. ``SHA256 iA02Sx8UNLYvMi49fDwdGjyy5ssU+ttuN1L4L3/JvZA=``.

        :param encoded: The string encoding of the digest.
        """
        import base64
        import binascii

        algorithm, value_b64 = encoded.split(" ", 1)
        value_bytes = base64.standard_b64decode(value_b64)
        value = binascii.hexlify(value_bytes).decode()
        return cls(value=value, algorithm=DigestHashAlgorithm(algorithm))

    @classmethod
    def from_decoded(cls, decoded: bytes, algorithm: Union[str, DigestHashAlgorithm]) -> "Digest":
        """
        Initialize a digest from raw bytes.

        :param decoded: The raw bytes of the digest.
        :param algorithm: The algorithm used to produce the bytes of the digest
            from the contents of the corresponding file.
        """
        import binascii

        value = binascii.hexlify(decoded).decode()
        return Digest(value=value, algorithm=DigestHashAlgorithm(algorithm))

    def encode(self) -> str:
        """
        Encode the digest into its string form.

        This is the inverse of :meth:`.from_encoded()`.
        """
        import base64
        import binascii

        value_bytes = binascii.unhexlify(self.value)
        value_b64 = base64.standard_b64encode(value_bytes).decode()

        return f"{str(self.algorithm)} {value_b64}"

    def decode(self) -> bytes:
        """
        Decode a digest into its raw bytes form.

        This is the inverse of :meth:`.from_decoded()`.
        """
        import binascii

        return binascii.unhexlify(self.value)

    def new_hasher(self):
        """
        Get a fresh :mod:`hasher <hashlib>` object for the given algorithm.

        .. seealso::
            :meth:`DigestHashAlgorithm.hasher()`.
        """
        return DigestHashAlgorithm(self.algorithm).hasher()


class FileInfo(BaseModel, arbitrary_types_allowed=True):
    path: str
    """
    The path of the file within the dataset.
    """

    updated: datetime
    """
    The time that the file was last updated.
    """

    digest: Optional[Digest] = None
    """
    The digest of the contents of the file.
    """

    size: Optional[int] = None
    """
    The size of the file in bytes, if known.
    """

    IGNORE_FIELDS = {"url"}

    @field_validator("digest", mode="before")
    def _validate_digest(cls, v: Union[str, Digest, None]) -> Optional[Digest]:
        if isinstance(v, Digest):
            return v
        elif isinstance(v, str):
            return Digest.from_encoded(v)
        elif isinstance(v, dict):
            return Digest(**v)
        else:
            raise ValueError(f"Unexpected value for 'digest': {v}")


class DatasetsPage(BasePage[Dataset]):
    data: Tuple[Dataset, ...]


class DatasetInfoPage(BasePage[FileInfo]):
    data: Tuple[FileInfo, ...]


class DatasetInfo(BaseModel):
    page: DatasetInfoPage
    size: DatasetSize


class DatasetSpec(BaseModel):
    workspace: Optional[str] = None
    description: Optional[str] = None
    budget: Optional[str] = None


class DatasetPatch(BaseModel):
    name: Optional[str] = None
    description: Optional[str] = None
    commit: Optional[bool] = None


class DatasetSort(StrEnum):
    created = "created"
    author = "author"
    dataset_name = "name"
    dataset_name_or_description = "nameOrDescription"
8 | 9 | First, do [a quick search](https://github.com/allenai/beaker-py/issues) to see whether your issue has already been reported. 10 | If your issue has already been reported, please comment on the existing issue. 11 | 12 | Otherwise, open [a new GitHub issue](https://github.com/allenai/beaker-py/issues). Be sure to include a clear title 13 | and description. The description should include as much relevant information as possible. The description should 14 | explain how to reproduce the erroneous behavior as well as the behavior you expect to see. Ideally you would include a 15 | code sample or an executable test case demonstrating the expected behavior. 16 | 17 | ### Do you have a suggestion for an enhancement or new feature? 18 | 19 | We use GitHub issues to track feature requests. Before you create an feature request: 20 | 21 | * Make sure you have a clear idea of the enhancement you would like. If you have a vague idea, consider discussing 22 | it first on a GitHub issue. 23 | * Check the documentation to make sure your feature does not already exist. 24 | * Do [a quick search](https://github.com/allenai/beaker-py/issues) to see whether your feature has already been suggested. 25 | 26 | When creating your request, please: 27 | 28 | * Provide a clear title and description. 29 | * Explain why the enhancement would be useful. It may be helpful to highlight the feature in other libraries. 30 | * Include code examples to demonstrate how the enhancement would be used. 31 | 32 | ## Making a pull request 33 | 34 | When you're ready to contribute code to address an open issue, please follow these guidelines to help us be able to review your pull request (PR) quickly. 35 | 36 | 1. **Initial setup** (only do this once) 37 | 38 |
Expand details 👇
39 | 40 | If you haven't already done so, please [fork](https://help.github.com/en/enterprise/2.13/user/articles/fork-a-repo) this repository on GitHub. 41 | 42 | Then clone your fork locally with 43 | 44 | git clone https://github.com/USERNAME/beaker-py.git 45 | 46 | or 47 | 48 | git clone git@github.com:USERNAME/beaker-py.git 49 | 50 | At this point the local clone of your fork only knows that it came from *your* repo, github.com/USERNAME/beaker-py.git, but doesn't know anything the *main* repo, [https://github.com/allenai/beaker-py.git](https://github.com/allenai/beaker-py). You can see this by running 51 | 52 | git remote -v 53 | 54 | which will output something like this: 55 | 56 | origin https://github.com/USERNAME/beaker-py.git (fetch) 57 | origin https://github.com/USERNAME/beaker-py.git (push) 58 | 59 | This means that your local clone can only track changes from your fork, but not from the main repo, and so you won't be able to keep your fork up-to-date with the main repo over time. Therefore you'll need to add another "remote" to your clone that points to [https://github.com/allenai/beaker-py.git](https://github.com/allenai/beaker-py). To do this, run the following: 60 | 61 | git remote add upstream https://github.com/allenai/beaker-py.git 62 | 63 | Now if you do `git remote -v` again, you'll see 64 | 65 | origin https://github.com/USERNAME/beaker-py.git (fetch) 66 | origin https://github.com/USERNAME/beaker-py.git (push) 67 | upstream https://github.com/allenai/beaker-py.git (fetch) 68 | upstream https://github.com/allenai/beaker-py.git (push) 69 | 70 | Finally, you'll need to create a Python 3 virtual environment suitable for working on this project. There a number of tools out there that making working with virtual environments easier. 
71 | The most direct way is with the [`venv` module](https://docs.python.org/3.7/library/venv.html) in the standard library, but if you're new to Python or you don't already have a recent Python 3 version installed on your machine, 72 | we recommend [Miniconda](https://docs.conda.io/en/latest/miniconda.html). 73 | 74 | On Mac, for example, you can install Miniconda with [Homebrew](https://brew.sh/): 75 | 76 | brew install miniconda 77 | 78 | Then you can create and activate a new Python environment by running: 79 | 80 | conda create -n beaker-py python=3.9 81 | conda activate beaker-py 82 | 83 | Once your virtual environment is activated, you can install your local clone in "editable mode" with 84 | 85 | pip install -U pip setuptools wheel 86 | pip install -e .[dev] 87 | 88 | The "editable mode" comes from the `-e` argument to `pip`, and essential just creates a symbolic link from the site-packages directory of your virtual environment to the source code in your local clone. That way any changes you make will be immediately reflected in your virtual environment. 89 | 90 |
91 | 92 | 2. **Ensure your fork is up-to-date** 93 | 94 |
Expand details 👇
95 | 96 | Once you've added an "upstream" remote pointing to [https://github.com/allenai/beaker-py.git](https://github.com/allenai/beaker-py), keeping your fork up-to-date is easy: 97 | 98 | git checkout main # if not already on main 99 | git pull --rebase upstream main 100 | git push 101 | 102 |
103 | 104 | 3. **Create a new branch to work on your fix or enhancement** 105 | 106 |
Expand details 👇
Committing directly to the main branch of your fork is not recommended. It will be easier to keep your fork clean if you work on a separate branch for each contribution you intend to make.
117 | 118 | 4. **Test your changes** 119 | 120 |
Expand details 👇
121 | 122 | Our continuous integration (CI) testing runs [a number of checks](https://github.com/allenai/beaker-py/actions) for each pull request on [GitHub Actions](https://github.com/features/actions). You can run most of these tests locally, which is something you should do *before* opening a PR to help speed up the review process and make it easier for us. 123 | 124 | First, you should run [`isort`](https://github.com/PyCQA/isort) and [`black`](https://github.com/psf/black) to make sure you code is formatted consistently. 125 | Many IDEs support code formatters as plugins, so you may be able to setup isort and black to run automatically everytime you save. 126 | For example, [`black.vim`](https://github.com/psf/black/tree/master/plugin) will give you this functionality in Vim. But both `isort` and `black` are also easy to run directly from the command line. 127 | Just run this from the root of your clone: 128 | 129 | isort . 130 | black . 131 | 132 | Our CI also uses [`flake8`](https://github.com/allenai/beaker-py/tree/main/tests) to lint the code base and [`mypy`](http://mypy-lang.org/) for type-checking. You should run both of these next with 133 | 134 | flake8 . 135 | 136 | and 137 | 138 | mypy . 139 | 140 | We also strive to maintain high test coverage, so most contributions should include additions to [the unit tests](https://github.com/allenai/beaker-py/tree/main/tests). These tests are run with [`pytest`](https://docs.pytest.org/en/latest/), which you can use to locally run any test modules that you've added or changed. 141 | 142 | For example, if you've fixed a bug in `beaker/a/b.py`, you can run the tests specific to that module with 143 | 144 | pytest -v tests/a/b_test.py 145 | 146 | Our CI will automatically check that test coverage stays above a certain threshold (around 90%). 
To check the coverage locally in this example, you could run 147 | 148 | pytest -v --cov beaker.a.b tests/a/b_test.py 149 | 150 | If your contribution involves additions to any public part of the API, we require that you write docstrings 151 | for each function, method, class, or module that you add. 152 | See the [Writing docstrings](#writing-docstrings) section below for details on the syntax. 153 | You should test to make sure the API documentation can build without errors by running 154 | 155 | make docs 156 | 157 | If the build fails, it's most likely due to small formatting issues. If the error message isn't clear, feel free to comment on this in your pull request. 158 | 159 | And finally, please update the [CHANGELOG](https://github.com/allenai/beaker-py/blob/main/CHANGELOG.md) with notes on your contribution in the "Unreleased" section at the top. 160 | 161 | After all of the above checks have passed, you can now open [a new GitHub pull request](https://github.com/allenai/beaker-py/pulls). 162 | Make sure you have a clear description of the problem and the solution, and include a link to relevant issues. 163 | 164 | We look forward to reviewing your PR! 165 | 166 |
167 | 168 | ### Writing docstrings 169 | 170 | We use [Sphinx](https://www.sphinx-doc.org/en/master/index.html) to build our API docs, which automatically parses all docstrings 171 | of public classes and methods using the [autodoc](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html) extension. 172 | Please refer to autoc's documentation to learn about the docstring syntax. 173 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Main 2 | 3 | concurrency: 4 | group: ${{ github.workflow }}-${{ github.ref }} 5 | cancel-in-progress: true 6 | 7 | on: 8 | pull_request: 9 | branches: 10 | - main 11 | push: 12 | branches: 13 | - main 14 | tags: 15 | - 'v*.*.*' 16 | 17 | env: 18 | CACHE_PREFIX: v0 # Change this to invalidate existing cache. 19 | DEFAULT_PYTHON: 3.8 20 | BEAKER_WORKSPACE: ai2/petew-testing 21 | 22 | jobs: 23 | compatibility: 24 | name: Compatibility 25 | runs-on: ubuntu-latest 26 | timeout-minutes: 15 27 | steps: 28 | - uses: actions/checkout@v3 29 | 30 | - name: Setup Python environment 31 | uses: ./.github/actions/setup-venv 32 | with: 33 | python-version: ${{ env.DEFAULT_PYTHON }} 34 | cache-prefix: ${{ env.CACHE_PREFIX }} 35 | 36 | - name: Setup Beaker 37 | uses: allenai/setup-beaker@v2 38 | with: 39 | token: ${{ secrets.BEAKER_TOKEN }} 40 | workspace: ${{ env.BEAKER_WORKSPACE }} 41 | 42 | - name: Check config compatibility 43 | shell: bash 44 | run: | 45 | . .venv/bin/activate 46 | python -c 'from beaker import Beaker; print(Beaker.from_env().account.name)' 47 | 48 | - name: Clean up 49 | if: always() 50 | shell: bash 51 | run: | 52 | . 
.venv/bin/activate 53 | pip uninstall -y beaker-py 54 | 55 | pydantic_v1: 56 | name: Pydantic V1 57 | runs-on: ubuntu-latest 58 | timeout-minutes: 15 59 | steps: 60 | - uses: actions/checkout@v3 61 | 62 | - name: Setup Python environment 63 | uses: ./.github/actions/setup-venv 64 | with: 65 | python-version: ${{ env.DEFAULT_PYTHON }} 66 | cache-prefix: pydantic-v1-${{ env.CACHE_PREFIX }} 67 | packages: pydantic<2.0 68 | 69 | - name: Check Pydantic V1 compatibility 70 | shell: bash 71 | run: | 72 | . .venv/bin/activate 73 | pytest -v tests/data_model_test.py 74 | 75 | - name: Clean up 76 | if: always() 77 | shell: bash 78 | run: | 79 | . .venv/bin/activate 80 | pip uninstall -y beaker-py 81 | 82 | checks: 83 | name: ${{ matrix.os }} - python ${{ matrix.python }} - ${{ matrix.task.name }} 84 | runs-on: ${{ matrix.os }} 85 | timeout-minutes: 15 86 | env: 87 | BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }} 88 | strategy: 89 | fail-fast: false 90 | matrix: 91 | os: [ubuntu-latest] 92 | python: [3.8] 93 | task: 94 | - name: Lint 95 | run: | 96 | ruff check . 97 | 98 | - name: Type check 99 | run: | 100 | mypy --check-untyped-defs . 101 | 102 | - name: Unit tests 103 | run: | 104 | pytest -v --color=yes --durations=10 tests/ 105 | 106 | # - name: Doc tests 107 | # run: | 108 | # pytest -v --color=yes --durations=10 --doctest-modules -k 'beaker/__init__.py or beaker.client' beaker/ 109 | 110 | # - name: Doc tests (B) 111 | # run: | 112 | # pytest -v --color=yes --durations=10 --doctest-modules -k 'not beaker/__init__.py and not beaker.client' beaker/ 113 | 114 | - name: Images 115 | run: | 116 | # Clean up local cache. 117 | docker system prune --all --force 118 | # Build test image for uploading. 119 | cd test_fixtures/docker \ 120 | && docker build --build-arg "COMMIT_SHA=$COMMIT_SHA" -t beaker-py-test . 
\ 121 | && cd - 122 | pytest -rP -v --color=yes integration_tests/images_test.py 123 | 124 | - name: Jobs 125 | run: | 126 | pytest -rP -v --color=yes integration_tests/jobs_test.py 127 | 128 | - name: Experiments 129 | run: | 130 | pytest -rP -v --color=yes integration_tests/experiments_test.py 131 | 132 | - name: Datasets 133 | run: | 134 | pytest -rP -v --color=yes integration_tests/datasets_test.py 135 | 136 | - name: Sweep example 137 | run: | 138 | cd examples/sweep 139 | # NOTE: anytime you change something here, make sure the run instructions 140 | # in 'examples/sweep/README.md' are still up-to-date. 141 | docker build -t sweep . 142 | python run.py "sweep" "ai2/beaker-py-sweep-example" 143 | 144 | - name: Build 145 | run: | 146 | python -m build 147 | 148 | - name: Style 149 | run: | 150 | isort --check . 151 | black --check . 152 | 153 | - name: Docs 154 | run: | 155 | cd docs && make html 156 | 157 | steps: 158 | - uses: actions/checkout@v3 159 | 160 | - name: Determine current commit SHA (pull request) 161 | if: github.event_name == 'pull_request' 162 | run: | 163 | echo "COMMIT_SHA=${{ github.event.pull_request.head.sha }}" >> $GITHUB_ENV 164 | 165 | - name: Determine current commit SHA (push) 166 | if: github.event_name != 'pull_request' 167 | run: | 168 | echo "COMMIT_SHA=$GITHUB_SHA" >> $GITHUB_ENV 169 | 170 | - name: Setup Python environment 171 | uses: ./.github/actions/setup-venv 172 | with: 173 | python-version: ${{ matrix.python }} 174 | cache-prefix: ${{ env.CACHE_PREFIX }} 175 | 176 | - name: ${{ matrix.task.name }} 177 | shell: bash 178 | run: | 179 | set -euo pipefail 180 | . .venv/bin/activate 181 | ${{ matrix.task.run }} 182 | 183 | - name: Upload package distribution files 184 | if: matrix.task.name == 'Build' 185 | uses: actions/upload-artifact@v4 186 | with: 187 | name: package 188 | path: dist 189 | 190 | - name: Clean up 191 | if: always() 192 | shell: bash 193 | run: | 194 | . 
.venv/bin/activate 195 | pip uninstall -y beaker-py 196 | 197 | docker: 198 | name: Docker 199 | runs-on: ubuntu-latest 200 | env: 201 | image: ghcr.io/allenai/beaker-py 202 | steps: 203 | - uses: actions/checkout@v3 204 | 205 | - name: Log in to ghcr.io 206 | run: | 207 | echo ${{ secrets.GHCR_TOKEN }} | docker login ghcr.io -u ${{ secrets.GHCR_USER }} --password-stdin 208 | 209 | - name: Build image 210 | run: | 211 | docker build -t "${image}" . 212 | 213 | - name: Test image 214 | run: | 215 | docker run \ 216 | --rm \ 217 | --entrypoint python \ 218 | -e BEAKER_TOKEN=${{ secrets.BEAKER_TOKEN }} \ 219 | "${image}" \ 220 | -c "from beaker import Beaker; beaker = Beaker.from_env(); print(beaker.account.whoami())" 221 | 222 | - name: Publish image to container registry 223 | if: startsWith(github.ref, 'refs/tags/') 224 | shell: bash 225 | run: | 226 | TAG_MAJOR_MINOR_PATCH=${GITHUB_REF#refs/tags/} 227 | TAG_MAJOR_MINOR=${TAG_MAJOR_MINOR_PATCH%.*} 228 | TAG_MAJOR=${TAG_MAJOR_MINOR_PATCH%.*.*} 229 | TAG_LATEST="latest" 230 | 231 | for TAG in $TAG_MAJOR_MINOR_PATCH $TAG_MAJOR_MINOR $TAG_MAJOR $TAG_LATEST 232 | do 233 | echo "Pushing ${image}:${TAG}" 234 | docker tag "${image}" "${image}:${TAG}" 235 | docker push "${image}:${TAG}" 236 | done 237 | 238 | release: 239 | name: Release 240 | runs-on: ubuntu-latest 241 | needs: [checks, docker] 242 | if: startsWith(github.ref, 'refs/tags/') 243 | steps: 244 | - uses: actions/checkout@v1 # needs v1 for now 245 | 246 | - name: Log in to ghcr.io 247 | run: | 248 | echo ${{ secrets.GHCR_TOKEN }} | docker login ghcr.io -u ${{ secrets.GHCR_USER }} --password-stdin 249 | 250 | - name: Setup Python 251 | uses: actions/setup-python@v4 252 | with: 253 | python-version: ${{ env.DEFAULT_PYTHON }} 254 | 255 | - name: Install requirements 256 | run: | 257 | pip install --upgrade pip setuptools wheel build 258 | pip install -e .[dev] 259 | 260 | - name: Prepare environment 261 | run: | 262 | echo 
"RELEASE_VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_ENV 263 | echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV 264 | 265 | - name: Download package distribution files 266 | uses: actions/download-artifact@v4 267 | with: 268 | name: package 269 | path: dist 270 | 271 | - name: Generate release notes 272 | run: | 273 | python scripts/release_notes.py > ${{ github.workspace }}-RELEASE_NOTES.md 274 | 275 | - name: Publish package to PyPI 276 | run: | 277 | twine upload -u __token__ -p ${{ secrets.PYPI_PASSWORD }} dist/* 278 | 279 | - name: Publish GitHub release 280 | uses: softprops/action-gh-release@v1 281 | env: 282 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 283 | with: 284 | body_path: ${{ github.workspace }}-RELEASE_NOTES.md 285 | prerelease: ${{ contains(env.TAG, 'rc') }} 286 | files: | 287 | dist/* 288 | 289 | - name: Add PR comments on release 290 | env: 291 | GH_TOKEN: ${{ github.token }} 292 | run: | 293 | ./scripts/add_pr_comments_on_release.sh 294 | -------------------------------------------------------------------------------- /beaker/data_model/job.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Any, Dict, List, Optional, Tuple, Union 3 | 4 | from pydantic import Field 5 | 6 | from .account import Account 7 | from .base import BaseModel, IntEnum, StrEnum, field_validator 8 | from .experiment_spec import ( 9 | DataMount, 10 | EnvVar, 11 | ImageSource, 12 | Priority, 13 | ResultSpec, 14 | TaskSpec, 15 | ) 16 | 17 | __all__ = [ 18 | "CurrentJobStatus", 19 | "CanceledCode", 20 | "JobStatus", 21 | "ExecutionResult", 22 | "JobRequests", 23 | "JobLimits", 24 | "JobExecution", 25 | "JobKind", 26 | "Job", 27 | "Jobs", 28 | "JobStatusUpdate", 29 | "JobPatch", 30 | "Session", 31 | "SummarizedJobEvent", 32 | "JobLog", 33 | ] 34 | 35 | 36 | class CurrentJobStatus(StrEnum): 37 | """ 38 | The status of a job. 
39 | """ 40 | 41 | created = "created" 42 | scheduled = "scheduled" 43 | running = "running" 44 | idle = "idle" 45 | exited = "exited" 46 | failed = "failed" 47 | finalized = "finalized" 48 | canceled = "canceled" 49 | preempted = "preempted" 50 | 51 | 52 | class CanceledCode(IntEnum): 53 | not_set = 0 54 | system_preemption = 1 55 | user_preemption = 2 56 | idle = 3 57 | manual_cancellation = 4 58 | 59 | 60 | class JobStatus(BaseModel): 61 | created: datetime 62 | scheduled: Optional[datetime] = None 63 | started: Optional[datetime] = None 64 | exited: Optional[datetime] = None 65 | failed: Optional[datetime] = None 66 | finalized: Optional[datetime] = None 67 | canceled: Optional[datetime] = None 68 | canceled_for: Optional[str] = None 69 | canceled_code: Optional[Union[CanceledCode, int]] = None 70 | idle_since: Optional[datetime] = None 71 | ready: Optional[datetime] = None 72 | exit_code: Optional[int] = None 73 | message: Optional[str] = None 74 | 75 | @field_validator( 76 | "created", "scheduled", "started", "exited", "failed", "finalized", "canceled", "idle_since" 77 | ) 78 | def _validate_datetime(cls, v: Optional[datetime]) -> Optional[datetime]: 79 | if v is not None and v.year == 1: 80 | return None 81 | return v 82 | 83 | @property 84 | def current(self) -> CurrentJobStatus: 85 | """ 86 | Get the :class:`CurrentJobStatus`. 87 | 88 | :raises ValueError: If status can't be determined. 
89 | """ 90 | if self.finalized is not None: 91 | return CurrentJobStatus.finalized 92 | elif self.failed is not None: 93 | return CurrentJobStatus.failed 94 | elif self.exited is not None: 95 | return CurrentJobStatus.exited 96 | elif self.canceled is not None: 97 | if self.canceled_code in {CanceledCode.system_preemption, CanceledCode.user_preemption}: 98 | return CurrentJobStatus.preempted 99 | else: 100 | return CurrentJobStatus.canceled 101 | elif self.idle_since is not None: 102 | return CurrentJobStatus.idle 103 | elif self.started is not None: 104 | return CurrentJobStatus.running 105 | elif self.scheduled is not None: 106 | return CurrentJobStatus.scheduled 107 | elif self.created is not None: 108 | return CurrentJobStatus.created 109 | else: 110 | raise ValueError(f"Invalid status {self}") 111 | 112 | 113 | class ExecutionResult(BaseModel): 114 | beaker: Optional[str] = None 115 | 116 | 117 | class JobRequests(BaseModel): 118 | gpu_count: Optional[int] = None 119 | cpu_count: Optional[float] = None 120 | memory: Optional[str] = None 121 | shared_memory: Optional[str] = None 122 | 123 | 124 | class JobLimits(BaseModel): 125 | cpu_count: Optional[float] = None 126 | memory: Optional[str] = None 127 | gpus: Tuple[str, ...] = Field(default_factory=tuple) 128 | 129 | 130 | class JobExecution(BaseModel): 131 | task: str 132 | experiment: str 133 | spec: TaskSpec 134 | result: ExecutionResult 135 | workspace: Optional[str] = None 136 | replica_rank: Optional[int] = None 137 | replica_group_id: Optional[str] = None 138 | retry_ancestor: Optional[str] = None 139 | 140 | 141 | class JobKind(StrEnum): 142 | """ 143 | The kind of job. 
144 | """ 145 | 146 | execution = "execution" 147 | session = "session" 148 | 149 | 150 | class Session(BaseModel): 151 | command: Optional[Tuple[str, ...]] = None 152 | env_vars: Optional[Tuple[EnvVar, ...]] = None 153 | datasets: Optional[Tuple[DataMount, ...]] = None 154 | image: Optional[ImageSource] = None 155 | save_image: bool = False 156 | ports: Optional[Tuple[int, ...]] = None 157 | ports_v2: Optional[Tuple[Tuple[int, int], ...]] = None 158 | priority: Optional[Priority] = None 159 | work_dir: Optional[str] = None 160 | identity: Optional[str] = None 161 | constraints: Optional[Dict[str, List[str]]] = None 162 | result: Optional[ResultSpec] = None 163 | 164 | 165 | class Job(BaseModel): 166 | """ 167 | A :class:`Job` is an execution of a :class:`Task`. 168 | 169 | .. tip:: 170 | You can check a job's exit code with :data:`job.status.exit_code `. 171 | """ 172 | 173 | id: str 174 | kind: JobKind 175 | author: Account 176 | workspace: str 177 | status: JobStatus 178 | name: Optional[str] = None 179 | cluster: Optional[str] = None 180 | budget: Optional[str] = None 181 | execution: Optional[JobExecution] = None 182 | execution_results: Optional[Dict[str, Any]] = None 183 | node: Optional[str] = None 184 | node_has_gpus: Optional[bool] = None 185 | requests: Optional[JobRequests] = None 186 | limits: Optional[JobLimits] = None 187 | session: Optional[Session] = None 188 | host_networking: bool = False 189 | port_mappings: Optional[Dict[str, int]] = None 190 | result: Optional[ExecutionResult] = None 191 | preemptible: Optional[bool] = None 192 | 193 | @property 194 | def display_name(self) -> str: 195 | return self.name if self.name is not None else self.id 196 | 197 | @property 198 | def is_finalized(self) -> bool: 199 | return self.status.current == CurrentJobStatus.finalized 200 | 201 | @property 202 | def is_done(self) -> bool: 203 | """ 204 | Same as :meth:`is_finalized()`, kept for backwards compatibility. 
205 | """ 206 | return self.status.current == CurrentJobStatus.finalized 207 | 208 | @property 209 | def is_running(self) -> bool: 210 | return self.status.current in (CurrentJobStatus.running, CurrentJobStatus.idle) 211 | 212 | @property 213 | def is_queued(self) -> bool: 214 | return self.status.current == CurrentJobStatus.created 215 | 216 | @property 217 | def was_preempted(self) -> bool: 218 | return self.status.canceled is not None and self.status.canceled_code in { 219 | CanceledCode.system_preemption, 220 | CanceledCode.user_preemption, 221 | } 222 | 223 | @property 224 | def is_preemptible(self) -> bool: 225 | return self.preemptible or (self.priority == Priority.preemptible) 226 | 227 | @property 228 | def priority(self) -> Optional[Priority]: 229 | """ 230 | Get the priority of the job. 231 | """ 232 | if self.session is not None: 233 | return self.session.priority 234 | elif self.execution is not None: 235 | return self.execution.spec.context.priority 236 | else: 237 | return None 238 | 239 | def check(self): 240 | """ 241 | :raises JobFailedError: If the job failed or was canceled. 
242 | """ 243 | from ..exceptions import JobFailedError 244 | 245 | if self.status.exit_code is not None and self.status.exit_code > 0: 246 | raise JobFailedError( 247 | f"Job '{self.id}' exited with non-zero exit code ({self.status.exit_code})", 248 | job=self, 249 | ) 250 | elif self.status.canceled is not None: 251 | raise JobFailedError(f"Job '{self.id}' was canceled", job=self) 252 | elif self.status.failed is not None: 253 | raise JobFailedError(f"Job '{self.id}' failed", job=self) 254 | 255 | 256 | class Jobs(BaseModel): 257 | data: Optional[Tuple[Job, ...]] = None 258 | next: Optional[str] = None 259 | next_cursor: Optional[str] = None 260 | 261 | 262 | class JobStatusUpdate(BaseModel): 263 | scheduled: Optional[bool] = None 264 | started: Optional[bool] = None 265 | exit_code: Optional[int] = None 266 | failed: Optional[bool] = None 267 | finalized: Optional[bool] = None 268 | canceled: Optional[bool] = None 269 | canceled_for: Optional[str] = None 270 | canceled_code: Optional[Union[CanceledCode, int]] = None 271 | idle: Optional[bool] = None 272 | message: Optional[str] = None 273 | 274 | 275 | class JobPatch(BaseModel): 276 | status: Optional[JobStatusUpdate] = None 277 | limits: Optional[JobLimits] = None 278 | priority: Optional[Priority] = None 279 | 280 | 281 | class SummarizedJobEvent(BaseModel): 282 | job_id: str 283 | status: str 284 | occurrences: int 285 | earliest_occurrence: datetime 286 | latest_occurrence: datetime 287 | latest_message: str 288 | 289 | 290 | class JobLog(BaseModel): 291 | timestamp: datetime 292 | """ 293 | The time that the log line was recorded. 294 | """ 295 | message: str = "" 296 | """ 297 | The contents of the log line. 
298 | """ 299 | -------------------------------------------------------------------------------- /tests/data_model_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | from beaker.data_model import * 6 | from beaker.data_model.base import MappedSequence 7 | from beaker.exceptions import ValidationError 8 | 9 | 10 | def test_data_source_validation(): 11 | with pytest.raises(ValidationError, match="Exactly one"): 12 | DataSource() 13 | 14 | with pytest.raises(ValidationError, match="Exactly one"): 15 | DataSource(beaker="foo", host_path="bar") 16 | 17 | with pytest.raises(ValidationError, match="Exactly one"): 18 | DataSource(beaker="foo", hostPath="bar") # type: ignore 19 | 20 | assert DataSource(host_path="bar").host_path == "bar" 21 | 22 | 23 | def test_experiment_spec_from_and_to_json_and_file(beaker_cluster_name: str, tmp_path: Path): 24 | json_spec = { 25 | "version": "v2", 26 | "budget": "ai2/allennlp", 27 | "tasks": [ 28 | { 29 | "name": "main", 30 | "image": {"docker": "hello-world"}, 31 | "context": {"cluster": beaker_cluster_name}, 32 | "result": {"path": "/unused"}, 33 | "resources": {"memory": "512m", "sharedMemory": "512m"}, 34 | "hostNetworking": False, 35 | "leaderSelection": False, 36 | }, 37 | ], 38 | } 39 | 40 | spec = ExperimentSpec.from_json(json_spec) 41 | assert spec.to_json() == json_spec 42 | 43 | spec_path = tmp_path / "spec.yml" 44 | spec.to_file(spec_path) 45 | assert ExperimentSpec.from_file(spec_path) == spec 46 | 47 | 48 | def test_experiment_spec_from_with_timeout(beaker_cluster_name: str): 49 | json_spec = { 50 | "version": "v2", 51 | "budget": "ai2/allennlp", 52 | "tasks": [ 53 | { 54 | "name": "main", 55 | "image": {"docker": "hello-world"}, 56 | "context": {"cluster": beaker_cluster_name}, 57 | "result": {"path": "/unused"}, 58 | "resources": {"memory": "512m", "sharedMemory": "512m"}, 59 | "hostNetworking": False, 60 | 
"leaderSelection": False, 61 | "timeout": "10m", 62 | }, 63 | ], 64 | } 65 | 66 | spec = ExperimentSpec.from_json(json_spec) 67 | assert spec.tasks[0].timeout == 600000000000 68 | 69 | json_spec = { 70 | "version": "v2", 71 | "budget": "ai2/allennlp", 72 | "tasks": [ 73 | { 74 | "name": "main", 75 | "image": {"docker": "hello-world"}, 76 | "context": {"cluster": beaker_cluster_name}, 77 | "result": {"path": "/unused"}, 78 | "resources": {"memory": "512m", "sharedMemory": "512m"}, 79 | "hostNetworking": False, 80 | "leaderSelection": False, 81 | "timeout": None, 82 | }, 83 | ], 84 | } 85 | 86 | spec = ExperimentSpec.from_json(json_spec) 87 | assert spec.tasks[0].timeout is None 88 | 89 | json_spec = { 90 | "version": "v2", 91 | "budget": "ai2/allennlp", 92 | "tasks": [ 93 | { 94 | "name": "main", 95 | "image": {"docker": "hello-world"}, 96 | "context": {"cluster": beaker_cluster_name}, 97 | "result": {"path": "/unused"}, 98 | "resources": {"memory": "512m", "sharedMemory": "512m"}, 99 | "hostNetworking": False, 100 | "leaderSelection": False, 101 | "timeout": 600000000000.0, 102 | }, 103 | ], 104 | } 105 | 106 | spec = ExperimentSpec.from_json(json_spec) 107 | assert isinstance(spec.tasks[0].timeout, int) 108 | assert spec.tasks[0].timeout == 600000000000 109 | 110 | 111 | def test_experiment_spec_validation(): 112 | with pytest.raises(ValidationError, match="Duplicate task name"): 113 | ExperimentSpec.from_json( 114 | { 115 | "budget": "ai2/allennlp", 116 | "tasks": [ 117 | { 118 | "name": "main", 119 | "image": {"docker": "hello-world"}, 120 | "context": {"cluster": "foo"}, 121 | "result": {"path": "/unused"}, 122 | }, 123 | { 124 | "name": "main", 125 | "image": {"docker": "hello-world"}, 126 | "context": {"cluster": "bar"}, 127 | "result": {"path": "/unused"}, 128 | }, 129 | ], 130 | } 131 | ) 132 | with pytest.raises(ValidationError, match="Duplicate task name"): 133 | ExperimentSpec( 134 | budget="ai2/allennlp", 135 | tasks=[ 136 | TaskSpec( 137 | name="main", 
138 | image={"docker": "hello-world"}, # type: ignore 139 | context={"cluster": "foo"}, # type: ignore 140 | result={"path": "/unused"}, # type: ignore 141 | ), 142 | TaskSpec( 143 | name="main", 144 | image={"docker": "hello-world"}, # type: ignore 145 | context={"cluster": "bar"}, # type: ignore 146 | result={"path": "/unused"}, # type: ignore 147 | ), 148 | ], 149 | ) 150 | spec = ExperimentSpec(budget="ai2/allennlp").with_task( 151 | TaskSpec.new("main", "foo", docker_image="hello-world") 152 | ) 153 | with pytest.raises(ValueError, match="A task with the name"): 154 | spec.with_task(TaskSpec.new("main", "bar", docker_image="hello-world")) 155 | 156 | 157 | def test_snake_case_vs_lower_camel_case(): 158 | for x in (DataSource(host_path="/tmp/foo"), DataSource(hostPath="/tmp/foo")): # type: ignore 159 | assert ( 160 | str(x) 161 | == "DataSource(beaker=None, host_path='/tmp/foo', weka=None, result=None, secret=None)" 162 | ) 163 | assert x.host_path == "/tmp/foo" 164 | x.host_path = "/tmp/bar" 165 | assert ( 166 | str(x) 167 | == "DataSource(beaker=None, host_path='/tmp/bar', weka=None, result=None, secret=None)" 168 | ) 169 | assert x.to_json() == {"hostPath": "/tmp/bar"} 170 | 171 | 172 | def test_digest_init(): 173 | # All of these are equivalent: 174 | for digest in ( 175 | # String form. 176 | Digest("SHA256 iA02Sx8UNLYvMi49fDwdGjyy5ssU+ttuN1L4L3/JvZA="), 177 | # Hex-encoded string. 178 | Digest( 179 | "880d364b1f1434b62f322e3d7c3c1d1a3cb2e6cb14fadb6e3752f82f7fc9bd90", algorithm="SHA256" 180 | ), 181 | # Raw bytes. 
182 | Digest( 183 | b"\x88\r6K\x1f\x144\xb6/2.=|<\x1d\x1a<\xb2\xe6\xcb\x14\xfa\xdbn7R\xf8/\x7f\xc9\xbd\x90", 184 | algorithm="SHA256", 185 | ), 186 | ): 187 | assert digest.value == "880d364b1f1434b62f322e3d7c3c1d1a3cb2e6cb14fadb6e3752f82f7fc9bd90" 188 | 189 | 190 | def test_digest_hashable(): 191 | digest = Digest.from_encoded("SHA256 0Q/XIPetp+QFDce6EIYNVcNTCZSlPqmEfVs1eFEMK0Y=") 192 | d = {digest: 1} 193 | assert digest in d 194 | 195 | 196 | def test_mapped_sequence(): 197 | ms = MappedSequence([1, 2, 3], {"a": 1, "b": 2, "c": 3}) 198 | assert ms["a"] == 1 199 | assert ms[0] == 1 200 | assert len(ms) == 3 201 | assert "a" in ms 202 | assert 1 in ms 203 | assert list(ms) == [1, 2, 3] 204 | assert set(ms.keys()) == {"a", "b", "c"} 205 | assert ms.get("a") == 1 206 | assert "z" not in ms 207 | 208 | 209 | @pytest.mark.parametrize( 210 | "cluster", [["ai2/jupiter-cirrascale-2", "ai2/saturn-cirrascale"], "ai2/jupiter-cirrascale-2"] 211 | ) 212 | def test_experiment_spec_new_with_cluster(cluster): 213 | spec = ExperimentSpec.new("ai2/allennlp", cluster=cluster) 214 | assert spec.tasks[0].context.cluster is None 215 | assert spec.tasks[0].constraints is not None 216 | assert isinstance(spec.tasks[0].constraints.cluster, list) 217 | 218 | 219 | def test_task_spec_with_constraint(): 220 | task_spec = TaskSpec.new("main", constraints=Constraints(cluster=["ai2/saturn-cirrascale"])) 221 | new_task_spec = task_spec.with_constraint(cluster=["ai2/jupiter-cirrascale-2"]) 222 | assert new_task_spec.constraints is not None 223 | assert new_task_spec.constraints.cluster == ["ai2/jupiter-cirrascale-2"] 224 | # Shouldn't modify the original. 225 | assert task_spec.constraints is not None 226 | assert task_spec.constraints.cluster == ["ai2/saturn-cirrascale"] 227 | 228 | # These methods should all be equivalent. 
229 | for task_spec in ( 230 | TaskSpec.new("main", constraints={"cluster": ["ai2/general-cirrascale"]}), 231 | TaskSpec.new("main", cluster="ai2/general-cirrascale"), 232 | TaskSpec.new("main", cluster=["ai2/general-cirrascale"]), 233 | ): 234 | assert task_spec.constraints is not None 235 | assert task_spec.constraints.cluster == ["ai2/general-cirrascale"] 236 | 237 | 238 | def test_constraints_behave_like_dictionaries(): 239 | c = Constraints() 240 | c["cluster"] = ["ai2/general-cirrascale"] 241 | assert c.cluster == ["ai2/general-cirrascale"] 242 | 243 | 244 | def test_constraints_extra_fields(): 245 | c = Constraints(cluster=["ai2/general-cirrascale"], gpus=["A100"]) # type: ignore 246 | assert hasattr(c, "gpus") 247 | 248 | 249 | def test_job_status_with_canceled_code(): 250 | from datetime import datetime 251 | 252 | status = JobStatus(created=datetime.utcnow(), canceled_code=0) 253 | assert status.canceled_code == CanceledCode.not_set 254 | 255 | status = JobStatus(created=datetime.utcnow(), canceled_code=6) 256 | assert status.canceled_code == 6 257 | -------------------------------------------------------------------------------- /beaker/services/group.py: -------------------------------------------------------------------------------- 1 | from typing import Generator, List, Optional, Union 2 | 3 | from ..data_model import * 4 | from ..exceptions import * 5 | from .service_client import ServiceClient 6 | 7 | 8 | class GroupClient(ServiceClient): 9 | """ 10 | Accessed via :data:`Beaker.group `. 11 | """ 12 | 13 | def get(self, group: str) -> Group: 14 | """ 15 | Get info about a group. 16 | 17 | :param group: The group ID or name. 18 | 19 | :raises GroupNotFound: If the group can't be found. 20 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 21 | :raises RequestException: Any other exception that can occur when contacting the 22 | Beaker server. 
23 | """ 24 | 25 | def _get(id: str) -> Group: 26 | return Group.from_json( 27 | self.request( 28 | f"groups/{self.url_quote(id)}", 29 | exceptions_for_status={404: GroupNotFound(self._not_found_err_msg(id))}, 30 | ).json() 31 | ) 32 | 33 | try: 34 | # Could be an ID or full name, so we try that first. 35 | return _get(group) 36 | except GroupNotFound: 37 | if "/" not in group: 38 | # Try with adding the account name. 39 | try: 40 | return _get(f"{self.beaker.account.name}/{group}") 41 | except GroupNotFound: 42 | pass 43 | raise 44 | 45 | def create( 46 | self, 47 | name: str, 48 | *experiments: Union[str, Experiment], 49 | description: Optional[str] = None, 50 | workspace: Optional[Union[Workspace, str]] = None, 51 | ) -> Group: 52 | """ 53 | :param name: The name to assign the group. 54 | :param experiments: Experiments to add to the group. 55 | :param description: Group description. 56 | :param workspace: The workspace to create the group under. If not specified, 57 | :data:`Beaker.config.default_workspace ` is used. 58 | 59 | :raises ValueError: If the name is invalid. 60 | :raises GroupConflict: If a group with the given name already exists. 61 | :raises ExperimentNotFound: If any of the given experiments don't exist. 62 | :raises WorkspaceNotSet: If neither ``workspace`` nor 63 | :data:`Beaker.config.default_workspace ` are set. 64 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 65 | :raises RequestException: Any other exception that can occur when contacting the 66 | Beaker server. 
67 | """ 68 | self.validate_beaker_name(name) 69 | workspace = self.resolve_workspace(workspace) 70 | exp_ids: List[str] = list( 71 | set([self.resolve_experiment(experiment).id for experiment in experiments]) 72 | ) 73 | group_data = self.request( 74 | "groups", 75 | method="POST", 76 | data=GroupSpec( 77 | name=name, 78 | description=description, 79 | workspace=workspace.full_name, 80 | experiments=exp_ids, 81 | ), 82 | exceptions_for_status={409: GroupConflict(name)}, 83 | ).json() 84 | return self.get(group_data["id"]) 85 | 86 | def delete(self, group: Union[str, Group]): 87 | """ 88 | Delete a group. 89 | 90 | :param group: The group ID, name, or object. 91 | 92 | :raises GroupNotFound: If the group can't be found. 93 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 94 | :raises RequestException: Any other exception that can occur when contacting the 95 | Beaker server. 96 | """ 97 | group_id = self.resolve_group(group).id 98 | self.request( 99 | f"groups/{self.url_quote(group_id)}", 100 | method="DELETE", 101 | exceptions_for_status={404: GroupNotFound(self._not_found_err_msg(group))}, 102 | ) 103 | 104 | def rename(self, group: Union[str, Group], name: str) -> Group: 105 | """ 106 | Rename a group. 107 | 108 | :param group: The group ID, name, or object. 109 | :param name: The new name for the group. 110 | 111 | :raises ValueError: If the new name is invalid. 112 | :raises GroupNotFound: If the group can't be found. 113 | :raises GroupConflict: If a group by that name already exists. 114 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 115 | :raises RequestException: Any other exception that can occur when contacting the 116 | Beaker server. 
117 | """ 118 | self.validate_beaker_name(name) 119 | group_id = self.resolve_group(group).id 120 | return Group.from_json( 121 | self.request( 122 | f"groups/{self.url_quote(group_id)}", 123 | method="PATCH", 124 | data=GroupPatch(name=name), 125 | exceptions_for_status={ 126 | 404: GroupNotFound(self._not_found_err_msg(group)), 127 | 409: GroupConflict(name), 128 | }, 129 | ).json() 130 | ) 131 | 132 | def add_experiments(self, group: Union[str, Group], *experiments: Union[str, Experiment]): 133 | """ 134 | Add experiments to a group. 135 | 136 | :param group: The group ID, name, or object. 137 | :param experiments: Experiments to add to the group. 138 | 139 | :raises GroupNotFound: If the group can't be found. 140 | :raises ExperimentNotFound: If any of the given experiments don't exist. 141 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 142 | :raises RequestException: Any other exception that can occur when contacting the 143 | Beaker server. 144 | """ 145 | group_id = self.resolve_group(group).id 146 | exp_ids: List[str] = list( 147 | set([self.resolve_experiment(experiment).id for experiment in experiments]) 148 | ) 149 | self.request( 150 | f"groups/{self.url_quote(group_id)}", 151 | method="PATCH", 152 | data=GroupPatch(add_experiments=exp_ids), 153 | exceptions_for_status={404: GroupNotFound(self._not_found_err_msg(group))}, 154 | ) 155 | 156 | def remove_experiments(self, group: Union[str, Group], *experiments: Union[str, Experiment]): 157 | """ 158 | Remove experiments from a group. 159 | 160 | :param group: The group ID, name, or object. 161 | :param experiments: Experiments to remove from the group. 162 | 163 | :raises GroupNotFound: If the group can't be found. 164 | :raises ExperimentNotFound: If any of the given experiments don't exist. 165 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 
166 | :raises RequestException: Any other exception that can occur when contacting the 167 | Beaker server. 168 | """ 169 | group_id = self.resolve_group(group).id 170 | exp_ids: List[str] = list( 171 | set([self.resolve_experiment(experiment).id for experiment in experiments]) 172 | ) 173 | self.request( 174 | f"groups/{self.url_quote(group_id)}", 175 | method="PATCH", 176 | data=GroupPatch(remove_experiments=exp_ids), 177 | exceptions_for_status={404: GroupNotFound(self._not_found_err_msg(group))}, 178 | ) 179 | 180 | def list_experiments(self, group: Union[str, Group]) -> List[Experiment]: 181 | """ 182 | List experiments in a group. 183 | 184 | :param group: The group ID, name, or object. 185 | 186 | :raises GroupNotFound: If the group can't be found. 187 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 188 | :raises RequestException: Any other exception that can occur when contacting the 189 | Beaker server. 190 | """ 191 | group_id = self.resolve_group(group).id 192 | exp_ids = self.request( 193 | f"groups/{self.url_quote(group_id)}/experiments", 194 | method="GET", 195 | exceptions_for_status={404: GroupNotFound(self._not_found_err_msg(group))}, 196 | ).json() 197 | # TODO: make these requests concurrently. 198 | return [self.beaker.experiment.get(exp_id) for exp_id in exp_ids or []] 199 | 200 | def export_experiments( 201 | self, group: Union[str, Group], quiet: bool = False 202 | ) -> Generator[bytes, None, None]: 203 | """ 204 | Export all experiments and metrics in a group as a CSV. 205 | 206 | Returns a generator that should be exhausted to get the complete file. 207 | 208 | :param group: The group ID, name, or object. 209 | :param quiet: If ``True``, progress won't be displayed. 210 | 211 | :raises GroupNotFound: If the group can't be found. 212 | :raises BeakerError: Any other :class:`~beaker.exceptions.BeakerError` type that can occur. 
213 | :raises RequestException: Any other exception that can occur when contacting the 214 | Beaker server. 215 | """ 216 | group_id = self.resolve_group(group).id 217 | resp = self.request( 218 | f"groups/{self.url_quote(group_id)}/export.csv", 219 | method="GET", 220 | exceptions_for_status={404: GroupNotFound(self._not_found_err_msg(group))}, 221 | stream=True, 222 | ).iter_content(chunk_size=1024) 223 | 224 | from ..progress import get_group_experiments_progress 225 | 226 | with get_group_experiments_progress(quiet) as progress: 227 | task_id = progress.add_task("Downloading:") 228 | total = 0 229 | for chunk in resp: 230 | if chunk: 231 | advance = len(chunk) 232 | total += advance 233 | progress.update(task_id, total=total + 1, advance=advance) 234 | yield chunk 235 | 236 | def url(self, group: Union[str, Group]) -> str: 237 | """ 238 | Get the URL for a group. 239 | 240 | :param group: The group ID, name, or object. 241 | 242 | :raises GroupNotFound: If the group can't be found. 243 | """ 244 | group_id = self.resolve_group(group).id 245 | return f"{self.config.agent_address}/gr/{self.url_quote(group_id)}/compare" 246 | 247 | def _not_found_err_msg(self, group: Union[str, Group]) -> str: 248 | group = group if isinstance(group, str) else group.id 249 | return ( 250 | f"'{group}': Make sure you're using a valid Beaker group ID or the " 251 | f"*full* name of the group (with the account prefix, e.g. 'username/group_name')" 252 | ) 253 | --------------------------------------------------------------------------------