├── bin ├── hermit.hcl ├── pip ├── .python3@3.9.pkg ├── pip3 ├── pip3.9 ├── pydoc3 ├── pydoc3.9 ├── python ├── python3 ├── python3.9 ├── python3-config ├── python3.9-config ├── README.hermit.md ├── activate-hermit └── hermit ├── tests ├── __init__.py ├── resource_fixtures.py ├── test_remote_no_prefect.py ├── test_prefect_v3_environment.py ├── test_remote.py ├── test_local_executor.py ├── test_vertex_job.py ├── test_resources.py ├── test_prefect_v2_environment.py ├── executors │ └── databricks │ │ └── resource │ │ └── test_python_library.py ├── test_vertex_executor.py ├── test_databricks_executor.py ├── test_config.py └── test_torch.py ├── block_cascade ├── executors │ ├── databricks │ │ ├── __init__.py │ │ ├── run.py │ │ ├── job.py │ │ ├── filesystem.py │ │ ├── resource.py │ │ └── executor.py │ ├── vertex │ │ ├── distributed │ │ │ ├── __init__.py │ │ │ ├── torchrun_target.py │ │ │ ├── torch_job.py │ │ │ └── distributed_job.py │ │ ├── __init__.py │ │ ├── tune.py │ │ ├── run.py │ │ ├── resource.py │ │ ├── job.py │ │ └── executor.py │ ├── local │ │ ├── __init__.py │ │ └── executor.py │ ├── __init__.py │ └── executor.py ├── consts.py ├── __init__.py ├── prefect │ ├── v1 │ │ ├── __init__.py │ │ └── environment.py │ ├── __init__.py │ ├── v3 │ │ ├── environment.py │ │ └── __init__.py │ └── v2 │ │ ├── environment.py │ │ └── __init__.py ├── concurrency │ └── __init__.py ├── gcp │ ├── __init__.py │ └── monitoring.py ├── utils.py ├── config.py └── decorators.py ├── assets └── Cash_App.png ├── LICENSE.txt ├── .github └── workflows │ ├── python-publish.yml │ └── python-app.yml ├── pyproject.toml ├── .gitignore └── README.md /bin/hermit.hcl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bin/pip: -------------------------------------------------------------------------------- 1 | .python3@3.9.pkg -------------------------------------------------------------------------------- /bin/.python3@3.9.pkg: -------------------------------------------------------------------------------- 1 | hermit -------------------------------------------------------------------------------- /bin/pip3: -------------------------------------------------------------------------------- 1 | .python3@3.9.pkg -------------------------------------------------------------------------------- /bin/pip3.9: -------------------------------------------------------------------------------- 1 | .python3@3.9.pkg -------------------------------------------------------------------------------- /bin/pydoc3: -------------------------------------------------------------------------------- 1 | .python3@3.9.pkg -------------------------------------------------------------------------------- /bin/pydoc3.9: -------------------------------------------------------------------------------- 1 | .python3@3.9.pkg -------------------------------------------------------------------------------- /bin/python: -------------------------------------------------------------------------------- 1 | .python3@3.9.pkg -------------------------------------------------------------------------------- /bin/python3: -------------------------------------------------------------------------------- 1 | .python3@3.9.pkg 
-------------------------------------------------------------------------------- /bin/python3.9: -------------------------------------------------------------------------------- 1 | .python3@3.9.pkg -------------------------------------------------------------------------------- /bin/python3-config: -------------------------------------------------------------------------------- 1 | .python3@3.9.pkg -------------------------------------------------------------------------------- /bin/python3.9-config: -------------------------------------------------------------------------------- 1 | .python3@3.9.pkg -------------------------------------------------------------------------------- /block_cascade/executors/databricks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /block_cascade/executors/vertex/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /block_cascade/consts.py: -------------------------------------------------------------------------------- 1 | SERVICE = "aiplatform.googleapis.com" 2 | -------------------------------------------------------------------------------- /assets/Cash_App.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/square/cascade/main/assets/Cash_App.png -------------------------------------------------------------------------------- /block_cascade/executors/local/__init__.py: -------------------------------------------------------------------------------- 1 | from block_cascade.executors.local.executor import LocalExecutor # noqa: F401 2 | -------------------------------------------------------------------------------- /bin/README.hermit.md: -------------------------------------------------------------------------------- 1 | # Hermit environment 2 | 3 | This is a [Hermit](https://github.com/cashapp/hermit) bin directory. 4 | 5 | The symlinks in this directory are managed by Hermit and will automatically 6 | download and install Hermit itself as well as packages. These packages are 7 | local to this environment. 8 | -------------------------------------------------------------------------------- /block_cascade/executors/__init__.py: -------------------------------------------------------------------------------- 1 | from block_cascade.executors.databricks.executor import DatabricksExecutor # noqa: F401 2 | from block_cascade.executors.executor import Executor # noqa: F401 3 | from block_cascade.executors.local.executor import LocalExecutor # noqa: F401 4 | from block_cascade.executors.vertex.executor import VertexExecutor # noqa: F401 5 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2022 Square Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and 13 | limitations under the License. -------------------------------------------------------------------------------- /block_cascade/__init__.py: -------------------------------------------------------------------------------- 1 | from block_cascade.executors.databricks.resource import ( # noqa: F401 2 | DatabricksAutoscaleConfig, 3 | DatabricksResource, 4 | ) 5 | from block_cascade.executors.vertex.resource import ( # noqa: F401 6 | GcpAcceleratorConfig, 7 | GcpEnvironmentConfig, 8 | GcpMachineConfig, 9 | GcpResource, 10 | ) 11 | from block_cascade.decorators import remote # noqa: F401 12 | 13 | __all__ = [ 14 | "DatabricksAutoscaleConfig", 15 | "DatabricksResource", 16 | "GcpAcceleratorConfig", 17 | "GcpEnvironmentConfig", 18 | "GcpMachineConfig", 19 | "GcpResource", 20 | "remote", 21 | ] 22 | -------------------------------------------------------------------------------- /block_cascade/executors/vertex/__init__.py: -------------------------------------------------------------------------------- 1 | from block_cascade.executors.vertex.executor import VertexExecutor # noqa: F401 2 | from block_cascade.executors.vertex.job import VertexJob # noqa: F401 3 | from block_cascade.executors.vertex.resource import ( 4 | GcpAcceleratorConfig, # noqa: F401 5 | GcpEnvironmentConfig, # noqa: F401 6 | GcpMachineConfig, # noqa: F401 7 | GcpResource, # noqa: F401 8 | ) 9 | from block_cascade.executors.vertex.tune import ( 10 | ParamCategorical, # noqa: F401 11 | ParamDiscrete, # noqa: F401 12 | ParamDouble, # noqa: F401 13 | ParamInteger, # noqa: F401 14 | Tune, # noqa: F401 15 | TuneResult, # noqa: F401 16 | ) 17 | -------------------------------------------------------------------------------- /bin/activate-hermit: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This file must be used with "source bin/activate-hermit" from bash or zsh. 
3 | # You cannot run it directly 4 | # 5 | # THIS FILE IS GENERATED; DO NOT MODIFY 6 | 7 | if [ "${BASH_SOURCE-}" = "$0" ]; then 8 | echo "You must source this script: \$ source $0" >&2 9 | exit 33 10 | fi 11 | 12 | BIN_DIR="$(dirname "${BASH_SOURCE[0]:-${(%):-%x}}")" 13 | if "${BIN_DIR}/hermit" noop > /dev/null; then 14 | eval "$("${BIN_DIR}/hermit" activate "${BIN_DIR}/..")" 15 | 16 | if [ -n "${BASH-}" ] || [ -n "${ZSH_VERSION-}" ]; then 17 | hash -r 2>/dev/null 18 | fi 19 | 20 | echo "Hermit environment $("${HERMIT_ENV}"/bin/hermit env HERMIT_ENV) activated" 21 | fi 22 | -------------------------------------------------------------------------------- /tests/resource_fixtures.py: -------------------------------------------------------------------------------- 1 | from block_cascade.executors.vertex.resource import ( 2 | GcpEnvironmentConfig, 3 | GcpMachineConfig, 4 | GcpResource, 5 | ) 6 | 7 | GCP_PROJECT = "test-project" 8 | TEST_BUCKET = GCP_PROJECT 9 | REGION = "us-west1" 10 | 11 | chief_machine = GcpMachineConfig( 12 | type="n1-standard-4", 13 | count=1 14 | ) 15 | gcp_environment = GcpEnvironmentConfig( 16 | project=GCP_PROJECT, 17 | storage_location="gs://bucket/path/to/file", 18 | region="us-west1", 19 | service_account=f"{GCP_PROJECT}@{GCP_PROJECT}.iam.gserviceaccount.com", 20 | image="cascade", 21 | ) 22 | gcp_resource = GcpResource( 23 | chief=chief_machine, 24 | environment=gcp_environment, 25 | ) 26 | -------------------------------------------------------------------------------- /block_cascade/prefect/v1/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Union 3 | 4 | import prefect 5 | from prefect.backend.flow import FlowView 6 | 7 | 8 | def get_from_prefect_context(attr: str, default: str = "") -> str: 9 | return prefect.context.get(attr, default) 10 | 11 | 12 | def is_prefect_cloud_deployment() -> bool: 13 | flow_id = get_from_prefect_context("flow_id") 14 | if not flow_id: 15 | return False 16 | 17 | try: 18 | FlowView.from_flow_id(flow_id) 19 | return True 20 | except prefect.exceptions.ClientError: 21 | return False 22 | 23 | 24 | def get_prefect_logger(name: str = "") -> Union[logging.LoggerAdapter, logging.Logger]: 25 | return prefect.context.get("logger") 26 | -------------------------------------------------------------------------------- /block_cascade/prefect/__init__.py: -------------------------------------------------------------------------------- 1 | from block_cascade.utils import PREFECT_VERSION 2 | 3 | if PREFECT_VERSION == 1: 4 | from .v1 import ( 5 | get_from_prefect_context, 6 | get_prefect_logger, 7 | is_prefect_cloud_deployment, 8 | ) 9 | from .v1.environment import PrefectEnvironmentClient 10 | elif PREFECT_VERSION == 2: 11 | from .v2 import ( 12 | get_from_prefect_context, 13 | get_prefect_logger, 14 | is_prefect_cloud_deployment, 15 | ) 16 | from .v2.environment import PrefectEnvironmentClient 17 | else: 18 | from .v3 import ( 19 | get_from_prefect_context, 20 | get_prefect_logger, 21 | is_prefect_cloud_deployment, 22 | ) 23 | from .v3.environment import PrefectEnvironmentClient 24 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | deploy: 10 | 11 | runs-on: ubuntu-latest 12 | 13 | permissions: 14 | id-token: write 
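      # "id-token: write" lets pypa/gh-action-pypi-publish request an OIDC token for
      # PyPI trusted publishing, so no long-lived PyPI API token secret is needed
      # (this assumes the PyPI project is configured with this repository as a trusted publisher).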
15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: '3.x' 22 | - name: Install poetry 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install poetry 26 | - name: Install dependencies 27 | run: | 28 | poetry install 29 | - name: Build 30 | run: | 31 | poetry build 32 | - name: Publish package to PyPI 33 | uses: pypa/gh-action-pypi-publish@release/v1 -------------------------------------------------------------------------------- /block_cascade/concurrency/__init__.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Any, Coroutine, TypeVar 3 | 4 | T = TypeVar("T") 5 | 6 | 7 | def run_async(async_fn: Coroutine[Any, Any, T]) -> T: 8 | """ 9 | Executes a coroutine and blocks until the result is returned. 10 | 11 | This minimal implementation currently supports the case 12 | where there is a event loop already created in the current thread 13 | or no event loop exists at all in the current thread. 14 | """ 15 | try: 16 | running_loop = asyncio.get_running_loop() 17 | except RuntimeError: 18 | running_loop = None 19 | 20 | if running_loop: 21 | task = running_loop.create_task(async_fn) 22 | return running_loop.run_until_complete(task) 23 | else: 24 | return asyncio.run(async_fn) 25 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | name: Python application 2 | 3 | on: 4 | push: 5 | branches: '*' 6 | pull_request: 7 | branches: '*' 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: ["3.9", "3.10", "3.11", "3.12"] 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up ${{ matrix.python-version }} 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Install poetry 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install poetry 27 | - name: Install dependencies 28 | run: | 29 | poetry install --without torch 30 | - name: Test 31 | run: | 32 | poetry run pytest -------------------------------------------------------------------------------- /tests/test_remote_no_prefect.py: -------------------------------------------------------------------------------- 1 | from block_cascade import GcpEnvironmentConfig, GcpMachineConfig, GcpResource 2 | from block_cascade import remote 3 | 4 | GCP_PROJECT = "test-project" 5 | GCP_STORAGE_LOCATION = f"gs://{GCP_PROJECT}-cascade/" 6 | 7 | 8 | def test_remote_local_override(): 9 | """Test the remote decorator with the local executor using the no_resource_on_local flag.""" 10 | machine_config = GcpMachineConfig(type="n1-standard-4") 11 | gcp_resource = GcpResource( 12 | chief=machine_config, 13 | environment=GcpEnvironmentConfig( 14 | storage_location=GCP_STORAGE_LOCATION, project=GCP_PROJECT 15 | ), 16 | ) 17 | 18 | @remote(resource=gcp_resource, remote_resource_on_local=False) 19 | def multiply(a: int, b: int) -> int: 20 | return a * b 21 | 22 | result = multiply(1, 2) 23 | 24 | assert result == 2 25 | -------------------------------------------------------------------------------- /block_cascade/executors/vertex/tune.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum, auto 3 | from typing import List, 
Optional, Union 4 | 5 | 6 | class Scale(Enum): 7 | UNIT_LINEAR_SCALE = auto() 8 | UNIT_LOG_SCALE = auto() 9 | UNIT_REVERSE_LOG_SCALE = auto() 10 | 11 | 12 | @dataclass 13 | class ParamDiscrete: 14 | name: str 15 | values: List[float] 16 | 17 | 18 | @dataclass 19 | class ParamCategorical: 20 | name: str 21 | values: List[str] 22 | 23 | 24 | @dataclass 25 | class ParamDouble: 26 | name: str 27 | min: float 28 | max: float 29 | scale: Optional[Scale] = None 30 | 31 | 32 | @dataclass 33 | class ParamInteger: 34 | name: str 35 | min: float 36 | max: float 37 | scale: Optional[Scale] = None 38 | 39 | 40 | @dataclass 41 | class Tune: 42 | metric: str 43 | params: List[Union[ParamDiscrete, ParamCategorical, ParamInteger, ParamDouble]] 44 | goal: str = "MAXIMIZE" 45 | trials: int = 1 46 | parallel: int = 1 47 | resume_previous_job_id: Optional[str] = None 48 | algorithm: Optional[str] = None 49 | 50 | 51 | @dataclass 52 | class TuneResult: 53 | # TODO because AIP tune only reports the "main" metric, the trials df 54 | # only contains that metric as well. Whenever we start to want to inspect 55 | # these ping @baxen to come back to have this actually go find them all 56 | metric: float 57 | hyperparameters: dict 58 | trials: List[dict] 59 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "block-cascade" 3 | packages = [ 4 | {include = "block_cascade"} 5 | ] 6 | version = "3.1.0" 7 | description = "Library for model training in multi-cloud environment." 8 | readme = "README.md" 9 | authors = ["Block"] 10 | 11 | [tool.poetry.dependencies] 12 | python = ">=3.9,<3.13" 13 | cloudml-hypertune = "==0.1.0.dev6" 14 | cloudpickle = "^2.0" 15 | databricks-cli = ">=0.17.7" 16 | gcsfs = ">=2024" 17 | google-auth = "^2.23.2" 18 | google-cloud-aiplatform = "^1.39.0" 19 | google-cloud-monitoring = "^2.16.0" 20 | google-cloud-resource-manager = "^1.10.4" 21 | importlib_resources = {version="*", python="<3.9"} 22 | prefect = ">=2.0,<4.0.0" 23 | pydantic = ">=2.0.0,<3.0.0" 24 | s3fs = ">=2024" 25 | 26 | [tool.poetry.group.torch.dependencies] 27 | torch = ">=1.13.1" 28 | torchvision = ">=0.14.1" 29 | 30 | [tool.poetry.group.dev.dependencies] 31 | pytest = ">=7.3.1" 32 | pytest-env = "^0.8.1" 33 | pytest-mock = "^3.11.1" 34 | dask = {extras = ["distributed"], version = ">=2022"} 35 | pyfakefs = "<5.3" 36 | 37 | [[source]] 38 | name = "pypi" 39 | url = "https://pypi.org/simple" 40 | verify_ssl = true 41 | 42 | [build-system] 43 | requires = ["poetry-core>=1.0.0"] 44 | build-backend = "poetry.core.masonry.api" 45 | 46 | [tool.pytest.ini_options] 47 | log_cli = true 48 | log_cli_level = "INFO" 49 | log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" 50 | log_cli_date_format = "%Y-%m-%d %H:%M:%S" 51 | 52 | [tool.poetry.scripts] 53 | cascade = "block_cascade.cli:cli" 54 | -------------------------------------------------------------------------------- /block_cascade/executors/local/executor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from typing import Callable 4 | 5 | from fsspec.implementations.local import LocalFileSystem 6 | 7 | from block_cascade.executors.executor import Executor 8 | 9 | 10 | class LocalExecutor(Executor): 11 | """Submits or runs tasks in a local process""" 12 | 13 | def __init__(self, func: Callable): 14 | super().__init__(func=func) 
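        # auto_mkdir=True lets fsspec create the local staging directory on first
        # write, e.g. ~/cascade-storage/<uuid>/function.pkl for the pickled function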
15 | self._fs = LocalFileSystem(auto_mkdir=True) 16 | 17 | @property 18 | def storage_location(self): 19 | """ 20 | Returns the path to the local storage location for staging pickled functions 21 | and returning results. The directory is in the 22 | format /Users//cascade-storage/. 23 | """ 24 | return f"{os.path.expanduser('~')}/cascade-storage/" 25 | 26 | def _start(self): 27 | """ 28 | Starts a task to the local process by calling _execute. 29 | """ 30 | return self._result() 31 | 32 | def _run(self): 33 | """ 34 | Runs the task locally by calling _start(), called from Executor.run() 35 | """ 36 | return self._start() 37 | 38 | def _result(self): 39 | """ 40 | Executes a task by calling task.function() or tune(task) if 41 | task.tune is not None. 42 | """ 43 | function = self.func 44 | result = function() 45 | with self.fs.open(self.output_filepath, "wb") as f: 46 | pickle.dump(result, f) 47 | return result 48 | -------------------------------------------------------------------------------- /bin/hermit: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # THIS FILE IS GENERATED; DO NOT MODIFY 4 | 5 | set -eo pipefail 6 | 7 | export HERMIT_USER_HOME=~ 8 | 9 | if [ -z "${HERMIT_STATE_DIR}" ]; then 10 | case "$(uname -s)" in 11 | Darwin) 12 | export HERMIT_STATE_DIR="${HERMIT_USER_HOME}/Library/Caches/hermit" 13 | ;; 14 | Linux) 15 | export HERMIT_STATE_DIR="${XDG_CACHE_HOME:-${HERMIT_USER_HOME}/.cache}/hermit" 16 | ;; 17 | esac 18 | fi 19 | 20 | export HERMIT_DIST_URL="${HERMIT_DIST_URL:-https://github.com/cashapp/hermit/releases/download/stable}" 21 | HERMIT_CHANNEL="$(basename "${HERMIT_DIST_URL}")" 22 | export HERMIT_CHANNEL 23 | export HERMIT_EXE=${HERMIT_EXE:-${HERMIT_STATE_DIR}/pkg/hermit@${HERMIT_CHANNEL}/hermit} 24 | 25 | if [ ! 
-x "${HERMIT_EXE}" ]; then 26 | echo "Bootstrapping ${HERMIT_EXE} from ${HERMIT_DIST_URL}" 1>&2 27 | INSTALL_SCRIPT="$(mktemp)" 28 | # This value must match that of the install script 29 | INSTALL_SCRIPT_SHA256="180e997dd837f839a3072a5e2f558619b6d12555cd5452d3ab19d87720704e38" 30 | if [ "${INSTALL_SCRIPT_SHA256}" = "BYPASS" ]; then 31 | curl -fsSL "${HERMIT_DIST_URL}/install.sh" -o "${INSTALL_SCRIPT}" 32 | else 33 | # Install script is versioned by its sha256sum value 34 | curl -fsSL "${HERMIT_DIST_URL}/install-${INSTALL_SCRIPT_SHA256}.sh" -o "${INSTALL_SCRIPT}" 35 | # Verify install script's sha256sum 36 | openssl dgst -sha256 "${INSTALL_SCRIPT}" | \ 37 | awk -v EXPECTED="$INSTALL_SCRIPT_SHA256" \ 38 | '$2!=EXPECTED {print "Install script sha256 " $2 " does not match " EXPECTED; exit 1}' 39 | fi 40 | /bin/bash "${INSTALL_SCRIPT}" 1>&2 41 | fi 42 | 43 | exec "${HERMIT_EXE}" --level=fatal exec "$0" -- "$@" 44 | -------------------------------------------------------------------------------- /tests/test_prefect_v3_environment.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import Mock, patch 3 | 4 | from prefect.client.schemas.responses import DeploymentResponse 5 | 6 | @pytest.fixture 7 | def prefect_environment_client(mock_deployment_response): 8 | with patch("block_cascade.utils.PREFECT_VERSION", 3), \ 9 | patch("block_cascade.prefect.v3.environment.runtime.deployment.id", "mock-deployment-id"), \ 10 | patch("block_cascade.prefect.v3.environment._fetch_deployment", return_value=mock_deployment_response): 11 | from block_cascade.prefect.v3.environment import PrefectEnvironmentClient 12 | client = PrefectEnvironmentClient() 13 | yield client 14 | 15 | 16 | @pytest.fixture 17 | def mock_job_variables(): 18 | return { 19 | "image": "job_image", 20 | "network": "job_network", 21 | "credentials": {"project": "job_project"}, 22 | "region": "job_region", 23 | "service_account_name": "job_service_account" 24 | } 25 | 26 | @pytest.fixture 27 | def mock_deployment_response(mock_job_variables): 28 | mock_deployment = Mock(spec=DeploymentResponse) 29 | mock_deployment.job_variables = mock_job_variables 30 | mock_deployment.infrastructure_document_id = "mock_infrastructure_id" 31 | return mock_deployment 32 | 33 | def test_get_container_image(prefect_environment_client): 34 | assert prefect_environment_client.get_container_image() == "job_image" 35 | 36 | def test_get_network(prefect_environment_client): 37 | assert prefect_environment_client.get_network() == "job_network" 38 | 39 | def test_get_project(prefect_environment_client): 40 | assert prefect_environment_client.get_project() == "job_project" 41 | 42 | def test_get_region(prefect_environment_client): 43 | assert prefect_environment_client.get_region() == "job_region" 44 | 45 | def test_get_service_account(prefect_environment_client): 46 | assert prefect_environment_client.get_service_account() == "job_service_account" 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | 
*.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | *.ipynb 75 | .ipynb_checkpoints 76 | 77 | # IPython 78 | profile_default/ 79 | ipython_config.py 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # hermit 100 | .hermit/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | .dmypy.json 115 | dmypy.json 116 | 117 | # Pyre type checker 118 | .pyre/ 119 | 120 | # Datafiles 121 | *.csv 122 | *.gz 123 | *.h5 124 | *.pkl 125 | *.pk 126 | *.db 127 | *.db-journal 128 | 129 | # Configuration 130 | .vscode/ 131 | .idea/ 132 | -------------------------------------------------------------------------------- /tests/test_remote.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from unittest.mock import Mock 3 | 4 | import pytest 5 | 6 | from block_cascade import GcpEnvironmentConfig, GcpMachineConfig, GcpResource 7 | from block_cascade import remote 8 | from block_cascade.utils import PREFECT_VERSION 9 | 10 | if PREFECT_VERSION == 2: 11 | from prefect.context import FlowRunContext, TaskRunContext 12 | 13 | GCP_PROJECT = "test-project" 14 | GCP_STORAGE_LOCATION = f"gs://{GCP_PROJECT}-cascade/" 15 | 16 | 17 | @pytest.fixture(autouse=True) 18 | def patch_prefect_apis(mocker): 19 | if PREFECT_VERSION == 2: 20 | mocker.patch( 21 | "block_cascade.prefect.v2.get_run_logger", 22 | return_value=Mock(spec=logging.Logger), 23 | ) 24 | mocker.patch( 25 | "block_cascade.prefect.v2.FlowRunContext", 26 | return_value=Mock(spec=FlowRunContext), 27 | ) 28 | mocker.patch( 29 | "block_cascade.prefect.v2.TaskRunContext", 30 | return_value=Mock(spec=TaskRunContext), 31 | ) 32 | mocker.patch( 33 | "block_cascade.prefect.v2.get_current_deployment", return_value=None 34 | ) 35 | 36 | 37 | def test_remote(): 38 | """Test the remote decorator with the local executor.""" 39 | 40 | @remote 41 | def addition(a: int, b: int) -> int: 42 | return a + b 43 | 44 | result = addition(1, 2) 45 | 46 | assert result == 3 47 | 48 | 49 | def test_remote_no_sugar(): 50 | """Test using the decorator with syntactic sugar.""" 51 | 52 | def addition(a: int, b: int) -> int: 53 | return a + b 54 | 55 | addition_remote = remote(func=addition) 56 | result = addition_remote(1, 2) 57 | assert result == 3 58 | 59 | 60 | def test_exception_when_environment_cannot_be_inferred(): 61 | machine_config = 
GcpMachineConfig(type="n1-standard-4") 62 | remote_resource = GcpResource( 63 | chief=machine_config, 64 | environment=GcpEnvironmentConfig( 65 | storage_location=GCP_STORAGE_LOCATION, project=GCP_PROJECT 66 | ), 67 | ) 68 | 69 | @remote 70 | def addition(a: int, b: int) -> int: 71 | return a + b 72 | 73 | with pytest.raises(RuntimeError): 74 | addition( 75 | 1, 76 | 2, 77 | remote_resource=remote_resource, 78 | ) 79 | -------------------------------------------------------------------------------- /tests/test_local_executor.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest.mock import MagicMock, patch 3 | 4 | import cloudpickle 5 | from fsspec.implementations.local import LocalFileSystem 6 | 7 | from block_cascade.executors import LocalExecutor 8 | from block_cascade.utils import wrapped_partial 9 | 10 | 11 | def addition(a: int, b: int) -> int: 12 | """Adds two numbers together.""" 13 | return a + b 14 | 15 | 16 | # create a mock of addition and set its name to "addition" 17 | mocked_addition = MagicMock(return_value=3) 18 | mocked_addition.__name__ = "addition" 19 | 20 | prepared_addition = wrapped_partial(addition, 1, 2) 21 | 22 | 23 | def test_local_executor(): 24 | """Test the local executor run method.""" 25 | 26 | executor = LocalExecutor(func=prepared_addition) 27 | result = executor.run() 28 | assert result == 3 29 | 30 | 31 | def test_run_twice(): 32 | """Tests that if the executor is run twice 33 | the second run executes the function again and stores it in a unique file. 34 | """ 35 | 36 | executor = LocalExecutor(func=mocked_addition) 37 | 38 | result1 = executor.run() 39 | result2 = executor.run() 40 | 41 | assert mocked_addition.call_count == 2 42 | assert result1 == result2 43 | 44 | 45 | def test_new_executor(): 46 | """ 47 | Tests generating a new executor from an existing one. 48 | """ 49 | mocked_addition.call_count = 0 50 | 51 | executor1 = LocalExecutor(func=mocked_addition) 52 | result1 = executor1.run() 53 | 54 | executor2 = executor1.with_() 55 | result2 = executor2.run() 56 | 57 | assert mocked_addition.call_count == 2 58 | assert executor1 != executor2 59 | assert result1 == result2 60 | 61 | 62 | @patch("block_cascade.executors.executor.uuid4", return_value="12345") 63 | def test_result(mock_uuid4): 64 | """ 65 | Tests that a file containing a pickled function can be opened, the function run 66 | and the results written to a local filepath. 
67 | """ 68 | fs = LocalFileSystem(auto_mkdir=True) 69 | executor = LocalExecutor(func=mocked_addition) 70 | 71 | path_root = os.path.expanduser("~") 72 | 73 | # test that the staged_filepath was created correctly 74 | assert executor.staged_filepath == f"{path_root}/cascade-storage/12345/function.pkl" 75 | 76 | # stage the pickled function to the staged_filedpath 77 | with fs.open(executor.staged_filepath, "wb") as f: 78 | cloudpickle.dump(wrapped_partial, f) 79 | 80 | result = executor._result() 81 | assert result == 3 82 | -------------------------------------------------------------------------------- /tests/test_vertex_job.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from block_cascade.executors.vertex.job import VertexJob 4 | from block_cascade.executors.vertex.tune import ParamDouble, ParamInteger, Tune 5 | from tests.resource_fixtures import gcp_resource, GCP_PROJECT 6 | 7 | CONTAINER_SPEC = { 8 | "image_uri": "us.gcr.io/test_project/cascade", 9 | "command": [ 10 | "python", 11 | "-m", 12 | "cascade.executors.vertex.run", 13 | "gs://bucket/path/to/file", 14 | "gs://bucket/path/to/output", 15 | ], 16 | "args": [], 17 | } 18 | 19 | 20 | def test_vertex_job(): 21 | job = VertexJob( 22 | display_name="test_job", 23 | resource=gcp_resource, 24 | storage_path="gs://bucket/path/to/file", 25 | labels={"hello": "WORLD"}, 26 | ) 27 | payload = job.create_payload() 28 | assert payload["display_name"] == "test_job" 29 | assert payload["labels"]["hello"] == "world" 30 | 31 | job_spec = payload["job_spec"] 32 | service_account = job_spec["service_account"] 33 | 34 | assert service_account == f"{GCP_PROJECT}@{GCP_PROJECT}.iam.gserviceaccount.com" 35 | 36 | 37 | @pytest.mark.parametrize( 38 | "key,val", 39 | [ 40 | ("1key", "val"), 41 | ("", "val"), 42 | ], 43 | ids=[ 44 | "key starts with number", 45 | "empty key", 46 | ], 47 | ) 48 | def test_invalid_labels_for_vertex_job(key, val): 49 | with pytest.raises(RuntimeError): 50 | job = VertexJob( 51 | display_name="test_job", 52 | resource=gcp_resource, 53 | storage_path="gs://bucket/path/to/file", 54 | labels={key: val}, 55 | ) 56 | job.create_payload() 57 | 58 | 59 | def test_vertex_tune_job(): 60 | """Test that a tuning job can be successfully created.""" 61 | 62 | tune_obj = Tune( 63 | metric="sum", 64 | trials=4, 65 | parallel=2, 66 | params=[ParamDouble("a", 0, 9.3), ParamInteger("b", 0, 4)], 67 | ) 68 | 69 | tune_job = VertexJob( 70 | display_name="test_tune_job", 71 | resource=gcp_resource, 72 | storage_path="gs://bucket/path/to/file", 73 | tune=tune_obj, 74 | ) 75 | 76 | payload = tune_job.create_payload() 77 | assert payload["display_name"] == "test_tune_job" 78 | 79 | assert payload.keys() == { 80 | "display_name", 81 | "trial_job_spec", 82 | "max_trial_count", 83 | "parallel_trial_count", 84 | "study_spec", 85 | "labels", 86 | } 87 | -------------------------------------------------------------------------------- /tests/test_resources.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | 3 | from block_cascade.executors.vertex.resource import GcpEnvironmentConfig 4 | from tests.resource_fixtures import ( 5 | gcp_environment, 6 | gcp_resource, 7 | TEST_BUCKET, 8 | GCP_PROJECT, 9 | ) 10 | 11 | BASIC_MACHINE = "n1-standard-4" 12 | STORAGE_LOCATION = "gs://bucket/path/to/file" 13 | 14 | 15 | def test_gcp_resource(): 16 | """Tests that a GCP resource can be instantiated from valid parameters.""" 17 | 18 | # Test that 
the chief node is correctly configured 19 | assert gcp_resource.chief.type == BASIC_MACHINE 20 | assert gcp_resource.chief.count == 1 21 | 22 | # Test that the worker node is absent as expected 23 | assert gcp_resource.workers is None 24 | 25 | # Test that the environment is correctly configured 26 | assert gcp_resource.environment.project == "test-project" 27 | assert gcp_resource.environment.region == "us-west1" 28 | assert ( 29 | gcp_resource.environment.service_account 30 | == "test-project@test-project.iam.gserviceaccount.com" 31 | ) 32 | assert gcp_resource.environment.image == f"us.gcr.io/{TEST_BUCKET}/cascade" 33 | 34 | 35 | def test_gcp_environment(): 36 | """ 37 | Tests that a GCP environment can be instantiated from valid parameters. 38 | And that its "is_complete" method works as expected. 39 | """ 40 | gcp_environment_complete = gcp_environment 41 | 42 | assert gcp_environment_complete.is_complete is True 43 | 44 | gcp_environment_incomplete = copy(gcp_environment_complete) 45 | gcp_environment_incomplete.image = None 46 | 47 | assert gcp_environment_incomplete.is_complete is False 48 | 49 | 50 | def test_gcp_resource_image(): 51 | """Tests that an image can be overriden on a GcpEnvironment object.""" 52 | environment_config = gcp_environment 53 | 54 | assert environment_config.image == f"us.gcr.io/{TEST_BUCKET}/cascade" 55 | 56 | # update the image value to use a different GCR path 57 | environment_config.image = f"gcr.io/{TEST_BUCKET}/rapids" 58 | assert environment_config.image == f"gcr.io/{TEST_BUCKET}/rapids" 59 | 60 | # update the image value to a different name with a tag 61 | environment_config.image = "cascade:latest" 62 | assert environment_config.image == f"us.gcr.io/{TEST_BUCKET}/cascade:latest" 63 | 64 | # create an object with no image value 65 | environment_config_no_image = GcpEnvironmentConfig( 66 | project=GCP_PROJECT, storage_location=STORAGE_LOCATION 67 | ) 68 | assert environment_config_no_image.image is None 69 | -------------------------------------------------------------------------------- /block_cascade/prefect/v3/environment.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | from prefect import runtime 4 | from prefect.client.schemas.responses import DeploymentResponse 5 | 6 | from block_cascade.concurrency import run_async 7 | from block_cascade.gcp import VertexAIEnvironmentInfoProvider 8 | from block_cascade.prefect.v3 import _fetch_deployment 9 | 10 | 11 | class PrefectEnvironmentClient(VertexAIEnvironmentInfoProvider): 12 | """ 13 | A client for fetching Deployment related 14 | metadata from a Prefect 3 Flow. 
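    Job variables are read from the deployment resolved via runtime.deployment.id
    and cached after the first successful lookup.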
15 | """ 16 | 17 | def __init__(self): 18 | self._current_deployment = None 19 | self._current_job_variables = None 20 | self._current_infrastructure = None 21 | 22 | def get_container_image(self) -> Optional[str]: 23 | job_variables = self._get_job_variables() 24 | if job_variables: 25 | return job_variables.get("image") 26 | 27 | return None 28 | 29 | def get_network(self) -> Optional[str]: 30 | job_variables = self._get_job_variables() 31 | if job_variables: 32 | return job_variables.get("network") 33 | 34 | return None 35 | 36 | def get_project(self) -> Optional[str]: 37 | job_variables = self._get_job_variables() 38 | if job_variables: 39 | return job_variables.get("credentials", {}).get("project") 40 | 41 | return None 42 | 43 | def get_region(self) -> Optional[str]: 44 | job_variables = self._get_job_variables() 45 | if job_variables: 46 | return job_variables.get("region") 47 | 48 | return None 49 | 50 | def get_service_account(self) -> Optional[str]: 51 | job_variables = self._get_job_variables() 52 | if job_variables: 53 | return job_variables.get("service_account_name") 54 | 55 | return None 56 | 57 | def _get_job_variables(self) -> Optional[Dict]: 58 | current_deployment = self._get_current_deployment() 59 | if not current_deployment: 60 | return None 61 | 62 | if not self._current_job_variables: 63 | self._current_job_variables = current_deployment.job_variables 64 | return self._current_job_variables 65 | 66 | def _get_current_deployment(self) -> Optional[DeploymentResponse]: 67 | deployment_id = runtime.deployment.id 68 | if not deployment_id: 69 | return None 70 | 71 | if not self._current_deployment: 72 | self._current_deployment = run_async( 73 | _fetch_deployment(deployment_id) 74 | ) 75 | return self._current_deployment 76 | -------------------------------------------------------------------------------- /block_cascade/gcp/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional 3 | 4 | import requests 5 | 6 | 7 | class VertexAIEnvironmentInfoProvider(ABC): 8 | """ 9 | Abstract Base Class for obtaining values 10 | necessary for a Vertex AI Training Custom Job. 11 | """ 12 | 13 | @abstractmethod 14 | def get_container_image(self) -> Optional[str]: 15 | pass 16 | 17 | @abstractmethod 18 | def get_network(self) -> Optional[str]: 19 | pass 20 | 21 | @abstractmethod 22 | def get_project(self) -> Optional[str]: 23 | pass 24 | 25 | @abstractmethod 26 | def get_region(self) -> Optional[str]: 27 | pass 28 | 29 | @abstractmethod 30 | def get_service_account(self) -> Optional[str]: 31 | pass 32 | 33 | 34 | class VMMetadataServerClient(VertexAIEnvironmentInfoProvider): 35 | """ 36 | A client for interacting with the metadata server 37 | for a GCP virtual machine. 
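    Each lookup is an HTTP GET against http://metadata.google.internal with the
    "Metadata-Flavor: Google" header required by the metadata server.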
38 | """ 39 | 40 | def __init__(self): 41 | self._session = requests.Session() 42 | self._session.headers.update({"Metadata-Flavor": "Google"}) 43 | 44 | def get_container_image(self) -> Optional[str]: 45 | response = self._session.get( 46 | "http://metadata.google.internal/computeMetadata/v1/instance/attributes/", 47 | params={"recursive": True}, 48 | ) 49 | response.raise_for_status() 50 | 51 | instance_attributes = response.json() 52 | return instance_attributes["container"] 53 | 54 | def get_network(self) -> Optional[str]: 55 | response = self._session.get( 56 | "http://metadata.google.internal/computeMetadata/v1/instance/network-interfaces/0/network" 57 | ) 58 | response.raise_for_status() 59 | _, project, _, network = response.text.split("/") 60 | return f"projects/{project}/global/networks/{network}" 61 | 62 | def get_project(self) -> Optional[str]: 63 | response = self._session.get( 64 | "http://metadata.google.internal/computeMetadata/v1/project/project-id" 65 | ) 66 | response.raise_for_status() 67 | return response.text 68 | 69 | def get_region(self) -> Optional[str]: 70 | response = self._session.get( 71 | "http://metadata.google.internal/computeMetadata/v1/instance/zone" 72 | ) 73 | response.raise_for_status() 74 | # Response is in format projects/{project_id}/zones/{region}-{zone} 75 | return response.text.split("/").pop().rsplit("-", maxsplit=1)[0] 76 | 77 | def get_service_account(self) -> Optional[str]: 78 | response = self._session.get( 79 | "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/email" 80 | ) 81 | response.raise_for_status() 82 | return response.text 83 | -------------------------------------------------------------------------------- /tests/test_prefect_v2_environment.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import Mock, patch 3 | 4 | from prefect.client.schemas.responses import DeploymentResponse 5 | 6 | @pytest.fixture 7 | def prefect_environment_client(mock_infrastructure_block, mock_deployment_response): 8 | with patch("block_cascade.utils.PREFECT_VERSION", 2), \ 9 | patch("block_cascade.utils.PREFECT_SUBVERSION", 8), \ 10 | patch("prefect.runtime.deployment.id", "mock-deployment-id"), \ 11 | patch("block_cascade.prefect.v2.environment._fetch_block", return_value=mock_infrastructure_block), \ 12 | patch("block_cascade.prefect.v2.environment._fetch_deployment", return_value=mock_deployment_response): 13 | from block_cascade.prefect.v2.environment import PrefectEnvironmentClient 14 | client = PrefectEnvironmentClient() 15 | yield client 16 | 17 | @pytest.fixture 18 | def mock_infrastructure_block(): 19 | infra_block = Mock() 20 | infra_block.data = { 21 | "image": "infra_image", 22 | "network": "infra_network", 23 | "gcp_credentials": {"project": "infra_project"}, 24 | "region": "infra_region", 25 | "service_account": "infra_service_account" 26 | } 27 | return infra_block 28 | 29 | @pytest.fixture 30 | def mock_job_variables(): 31 | return { 32 | "image": "job_image", 33 | "network": "job_network", 34 | "credentials": {"project": "job_project"}, 35 | "region": "job_region", 36 | "service_account_name": "job_service_account" 37 | } 38 | 39 | @pytest.fixture 40 | def mock_deployment_response(mock_job_variables): 41 | mock_deployment = Mock(spec=DeploymentResponse) 42 | mock_deployment.job_variables = mock_job_variables 43 | mock_deployment.infrastructure_document_id = "mock_infrastructure_id" 44 | return mock_deployment 45 | 46 | def 
test_get_container_image(prefect_environment_client): 47 | assert prefect_environment_client.get_container_image() == "job_image" 48 | 49 | def test_get_network(prefect_environment_client): 50 | assert prefect_environment_client.get_network() == "job_network" 51 | 52 | def test_get_project(prefect_environment_client): 53 | assert prefect_environment_client.get_project() == "job_project" 54 | 55 | def test_get_region(prefect_environment_client): 56 | assert prefect_environment_client.get_region() == "job_region" 57 | 58 | def test_get_service_account(prefect_environment_client): 59 | assert prefect_environment_client.get_service_account() == "job_service_account" 60 | 61 | def test_fallback_to_infrastructure(prefect_environment_client, mock_deployment_response): 62 | mock_deployment_response.job_variables = None 63 | 64 | assert prefect_environment_client.get_container_image() == "infra_image" 65 | assert prefect_environment_client.get_network() == "infra_network" 66 | assert prefect_environment_client.get_project() == "infra_project" 67 | assert prefect_environment_client.get_region() == "infra_region" 68 | assert prefect_environment_client.get_service_account() == "infra_service_account" -------------------------------------------------------------------------------- /block_cascade/executors/vertex/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | import sys 4 | 5 | import cloudpickle 6 | import gcsfs 7 | from hypertune import HyperTune 8 | 9 | from block_cascade.utils import parse_hyperparameters 10 | 11 | INPUT_FILENAME = "function.pkl" 12 | DISTRIBUTED_JOB_FILENAME = "distributed_job.pkl" 13 | OUTPUT_FILENAME = "output.pkl" 14 | 15 | # Clear the Prefect Handler until that 16 | # dependency gets removed. 17 | logging.getLogger().handlers.clear() 18 | logging.basicConfig(format="%(asctime)s %(levelname)s: %(message)s", level=logging.INFO) 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | def run(): 24 | """ 25 | Entrypoint to run a staged pickled function in VertexAI. 26 | This module is invoked as a Python script and the values for 27 | "path" and "hyperparameters" are passed as arguments. 
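    Positionally, sys.argv carries the storage path prefix, a distributed-job flag
    ("True" for Dask/Torch jobs), an optional code path to fetch, and then any
    hyperparameter overrides.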
28 | 29 | See cascade.executors.vertex.job.VertexJob._create_container_spec 30 | """ 31 | 32 | path_prefix, distributed_job, code_path, hyperparameters = ( 33 | sys.argv[1], 34 | sys.argv[2], 35 | sys.argv[3], 36 | sys.argv[4:], 37 | ) 38 | 39 | staged_path = f"{path_prefix}/{INPUT_FILENAME}" 40 | distributed_job_path = f"{path_prefix}/{DISTRIBUTED_JOB_FILENAME}" 41 | output_path = f"{path_prefix}/{OUTPUT_FILENAME}" 42 | 43 | fs = gcsfs.GCSFileSystem() 44 | if code_path: 45 | logger.info(f"Fetching {code_path} and added to sys.path.") 46 | fs.get(code_path, ".", recursive=True) 47 | sys.path.insert(0, ".") 48 | 49 | with fs.open(staged_path, "rb") as f: 50 | func = cloudpickle.load(f) 51 | 52 | hyperparameters = parse_hyperparameters(hyperparameters) 53 | # If we received hyperparameters as args, we infer we are doing a tune 54 | # the result of the function is the metric and we report that to Vertex 55 | if hyperparameters: 56 | logger.info(f"Starting execution with hyperparameters: {hyperparameters}") 57 | metrics = func(hyperparameters=hyperparameters) 58 | logger.info(f"Reporting metrics to hyperparameter tune: {metrics}") 59 | htune = HyperTune() 60 | for k, v in metrics.items(): 61 | htune.report_hyperparameter_tuning_metric(k, v) 62 | 63 | # If the job is a distributed job (including Dask and Torch) 64 | elif distributed_job == "True": 65 | logger.info("Starting execution of distributed job") 66 | with fs.open(distributed_job_path, "rb") as f: 67 | distributed_job = cloudpickle.load(f) 68 | distributed_job.run(func=func, storage_path=path_prefix) 69 | 70 | # If neither of the above are true, we assume this is a regular 71 | # single node job and the result of the function is saved 72 | # to storage to be read by the submitting call 73 | else: 74 | logger.info("Starting execution") 75 | result = func() 76 | logger.info(f"Saving output of task to {output_path}") 77 | with fs.open(output_path, "wb") as f: 78 | pickle.dump(result, f) 79 | 80 | 81 | if __name__ == "__main__": 82 | # This module is run as main in order to execute a task on a worker 83 | run() 84 | -------------------------------------------------------------------------------- /block_cascade/utils.py: -------------------------------------------------------------------------------- 1 | import ast 2 | from functools import partial 3 | import inspect 4 | from inspect import signature 5 | import itertools 6 | import logging 7 | from typing import List 8 | import subprocess 9 | import json 10 | 11 | import prefect 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def detect_prefect_version(n) -> int: 17 | return int(prefect.__version__.split(".")[n]) 18 | 19 | 20 | PREFECT_VERSION, PREFECT_SUBVERSION = detect_prefect_version(0), detect_prefect_version( 21 | 1 22 | ) 23 | 24 | INPUT_FILENAME = "function.pkl" 25 | DISTRIBUTED_JOB_FILENAME = "distributed_job.pkl" 26 | OUTPUT_FILENAME = "output.pkl" 27 | 28 | 29 | def get_gcloud_config() -> dict: 30 | """Get the current gcloud config if available in a user's environment.""" 31 | try: 32 | # Run the gcloud config list command and parse the output as JSON 33 | result = subprocess.run( 34 | ["gcloud", "config", "list", "--format", "json"], 35 | capture_output=True, 36 | text=True, 37 | ) 38 | # Check if the command was executed successfully 39 | if result.returncode == 0: 40 | config = json.loads(result.stdout) 41 | else: 42 | logger.error("Error listing gcloud configuration:", result.stderr) 43 | config = dict() 44 | except Exception as e: 45 | logger.error("Error 
listing gcloud configuration:", e) 46 | config = dict() 47 | 48 | return config 49 | 50 | 51 | def get_args(obj): 52 | return list(signature(obj).parameters.keys()) 53 | 54 | 55 | def maybe_convert(arg: str) -> str: 56 | """Convenience function for parsing hyperparameters""" 57 | try: 58 | return ast.literal_eval(arg) 59 | except ValueError: 60 | return arg 61 | 62 | 63 | def parse_hyperparameters(args: List[str]) -> dict: 64 | # support both --a=3 and --a 3 syntax 65 | args = list(itertools.chain(*(a.split("=") for a in args))) 66 | return { 67 | args[i].lstrip("-"): maybe_convert(args[i + 1]) for i in range(0, len(args), 2) 68 | } 69 | 70 | 71 | def _get_object_args(obj: object): 72 | return list(signature(obj).parameters.keys()) 73 | 74 | 75 | def _infer_base_module(func): 76 | """ 77 | Inspects the function to find the base module in which it was defined. 78 | Args: 79 | func (Callable): a function 80 | Returns: 81 | str: the name of the base module in which the function was defined 82 | """ 83 | func_module = inspect.getmodule(func) 84 | try: 85 | base_name, *_ = func_module.__name__.partition(".") 86 | return base_name 87 | except AttributeError: 88 | return None 89 | 90 | 91 | def wrapped_partial(func, *args, **kwargs): 92 | """ 93 | Return a partial function, while keeping the original function's name. 94 | Note: not using functools.update_wrapper here due to incompatability with Prefect 2 95 | for more details see: https://github.com/squareup/cascade/pull/188 96 | """ 97 | partial_func = partial(func, *args, **kwargs) 98 | setattr(partial_func, "__name__", func.__name__) 99 | return partial_func 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cascade 2 | 3 | Cascade is a library for submitting and managing jobs across multiple cloud environments. It is designed to integrate seamlessly into existing Prefect workflows or can be used as a standalone library. 4 | 5 | ## Getting Started 6 | 7 | ### Installation 8 | 9 | ```bash 10 | poetry add block-cascade 11 | ``` 12 | or 13 | ``` 14 | pip install block-cascade 15 | ``` 16 | 17 | ### Example Usage 18 | 19 | ```python 20 | from block_cascade import remote 21 | from block_cascade import GcpEnvironmentConfig, GcpMachineConfig, GcpResource 22 | 23 | machine_config = GcpMachineConfig("n2-standard-4", 1) 24 | environment_config = GcpEnvironmentConfig( 25 | project="example-project", 26 | region="us-west1", 27 | service_account=f"example-project@vertex.iam.gserviceaccount.com", 28 | image="us.gcr.io/example-project/cascade/cascade-test", 29 | network="projects/123456789123/global/networks/shared-vpc" 30 | ) 31 | gcp_resource = GcpResource( 32 | chief=machine_config, 33 | environment=environment_config, 34 | ) 35 | 36 | @remote(resource=gcp_resource) 37 | def addition(a: int, b: int) -> int: 38 | return a + b 39 | 40 | result = addition(1, 2) 41 | assert result == 3 42 | ``` 43 | 44 | ### Configuration 45 | Cascade supports defining different resource requirements via a configuration file titled either cascade.yaml or cascade.yml. This configuration file must be located in the working directory of the code execution to be discovered at runtime. 46 | 47 | ```yaml 48 | calculate: 49 | type: GcpResource 50 | chief: 51 | type: n1-standard-1 52 | You can even define a default configuration that can be overridden by specific tasks to eliminate redundant definitions. 
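# Keys in a named job entry take precedence: they are deep-merged over this
# default block (see _merge in block_cascade/config.py).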
53 | 54 | default: 55 | GcpResource: 56 | environment: 57 | project: example-project 58 | service_account: example-project@vertex.iam.gserviceaccount.com 59 | region: us-central1 60 | chief: 61 | type: n1-standard-4 62 | ``` 63 | 64 | ### Authorization 65 | Cascade requires authorization both to submit jobs to either GCP or Databricks and to stage pickled code to a cloud storage bucket. In the GCP example below, an authorization token is obtained via IAM by running the following command: 66 | 67 | ```bash 68 | gcloud auth login --update-adc 69 | ``` 70 | No additional configuration is required in your application's code to use this token. 71 | 72 | However, for authenticating to Databricks and AWS you will need to provide a token and secret key respectively. These can be passed directly to the `DatabricksResource` object or set as environment variables. The following example shows how to provide these values in the configuration file. 73 | 74 | ## For Developers 75 | 76 | ### Using hermit for managing Python 77 | When developing cascade, you can optionally use [hermit](https://cashapp.github.io/hermit/usage/get-started/) to manage the Python executable used by cascade. Together with using poetry to manage dependencies, this will ensure that your development environment is identical to other contributors. Follow the linked instructions for installing hermit and then you can create a virtualenv with Python@3.9 by running: 78 | 79 | `. ./bin/activate-hermit` 80 | 81 | Then, install the dependencies with poetry: 82 | `poetry install` -------------------------------------------------------------------------------- /block_cascade/executors/vertex/distributed/torchrun_target.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pickle 5 | import sys 6 | 7 | import cloudpickle 8 | import gcsfs 9 | import torch 10 | import torch.distributed as dist 11 | from torch.distributed.elastic.multiprocessing.errors import record 12 | 13 | 14 | class InvalidReturnDictionaryError(Exception): 15 | """User returned dictionaries must contain the key MODEL_STATE. Raise this exception 16 | if the user returned dict is not compliant 17 | """ 18 | 19 | def __init__(self): 20 | message = ( 21 | "User-defined training function must return a dictionary with the " 22 | "key `MODEL_STATE` that contains a Pytorch state_dict" 23 | ) 24 | super().__init__(message) 25 | 26 | 27 | class DistributedSetup: 28 | """Context Manager to setup Pytorch distributed environment to enable 29 | distributed training 30 | 31 | If the job fails, tear down the process group. 32 | """ 33 | 34 | def __init__(self) -> None: 35 | cluster_spec = json.loads(os.environ["CLUSTER_SPEC"]) 36 | master_hostname, master_port = cluster_spec["cluster"]["workerpool0"][0].split( 37 | ":" 38 | ) 39 | 40 | os.environ["MASTER_ADDR"] = str(master_hostname) 41 | os.environ["MASTER_PORT"] = str(master_port) 42 | 43 | torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) 44 | 45 | def __enter__(self) -> None: 46 | dist.init_process_group("nccl") 47 | 48 | def __exit__(self, *args, **kwargs) -> None: 49 | dist.destroy_process_group() 50 | 51 | 52 | @record 53 | def run_job_function() -> None: 54 | """Runs the wrapped, pickled, user-defined function.
Wraps the user-defined function 55 | in the necessary boilerplate to enable distributed training 56 | """ 57 | fs = gcsfs.GCSFileSystem() 58 | 59 | input_filepath = sys.argv[1] 60 | output_filepath = sys.argv[2] 61 | 62 | with fs.open(input_filepath, "rb") as f: 63 | job = cloudpickle.load(f) 64 | 65 | with DistributedSetup(): 66 | logging.info("Starting user code execution") 67 | snapshot = job.func() 68 | 69 | # This is a "Look before you leap" check to make sure the user-returned value is 70 | # something we expect: a dictionary that at least contains the key "MODEL_STATE" 71 | if not isinstance(snapshot, dict) or "MODEL_STATE" not in snapshot: 72 | raise InvalidReturnDictionaryError() 73 | 74 | # Save the model state to job.output. All processes will have a consistent model 75 | # state so we only need to save the model once, from the Rank 0 process 76 | if os.environ["RANK"] == "0": 77 | model_state_dict = snapshot["MODEL_STATE"] 78 | 79 | # "Moving" all of the model state tensors to the CPU allows the state dict 80 | # to be accessed as expected downstream. This does introduce an abstraction leak 81 | # since it requires the user to pass a specifically formatted return value 82 | model_state_cpu = {k: v.cpu() for k, v in model_state_dict.items()} 83 | snapshot["MODEL_STATE"] = model_state_cpu 84 | 85 | with fs.open(output_filepath, "wb") as f: 86 | pickle.dump(snapshot, f) 87 | print(f"Result object saved at {output_filepath}") 88 | 89 | 90 | if __name__ == "__main__": 91 | run_job_function() 92 | -------------------------------------------------------------------------------- /block_cascade/prefect/v1/environment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Optional 4 | 5 | import google.auth 6 | from google.cloud import resourcemanager_v3 7 | from prefect.backend.flow import FlowView 8 | 9 | from block_cascade.gcp import VertexAIEnvironmentInfoProvider 10 | from block_cascade.prefect.v1 import get_from_prefect_context 11 | 12 | 13 | class PrefectEnvironmentClient(VertexAIEnvironmentInfoProvider): 14 | """ 15 | A client for fetching Deployment related 16 | metadata from a Prefect 1 Flow. 
17 | """ 18 | 19 | def __init__(self): 20 | super().__init__() 21 | self._flow_id = get_from_prefect_context("flow_id") 22 | 23 | def get_container_image(self) -> Optional[str]: 24 | if not self._flow_id: 25 | return None 26 | 27 | flow_view = FlowView.from_flow_id(self._flow_id) 28 | registry_url = flow_view.storage.registry_url 29 | image_tag = flow_view.storage.image_tag 30 | image_name = flow_view.storage.image_name 31 | image_url = f"{registry_url}/{image_name}:{image_tag}" 32 | return image_url 33 | 34 | def get_network(self) -> Optional[str]: 35 | if not self._flow_id: 36 | return None 37 | 38 | flow_view = FlowView.from_flow_id(self._flow_id) 39 | if flow_view: 40 | return getattr(flow_view.run_config, "network", None) 41 | return None 42 | 43 | def get_project(self) -> Optional[str]: 44 | # Set manually in deployed Docker image 45 | project_name = os.environ.get("GCP_PROJECT") 46 | # Set by Vertex training job 47 | project_number = os.environ.get("CLOUD_ML_PROJECT_ID") 48 | 49 | if project_name: 50 | return project_name 51 | 52 | elif project_number: 53 | client = resourcemanager_v3.ProjectsClient() 54 | request = resourcemanager_v3.GetProjectRequest( 55 | name=f"projects/{project_number}" 56 | ) 57 | response = client.get_project(request=request) 58 | return response.project_id 59 | 60 | elif sys.platform == "darwin": # Only works locally - not on Vertex 61 | _, project = google.auth.default() 62 | return project 63 | 64 | else: 65 | # protect against any edge cases where project cannot be determined 66 | # via other methods 67 | raise google.auth.exceptions.DefaultCredentialsError( 68 | "Project could not be determined from environment" 69 | ) 70 | 71 | def get_region(self) -> Optional[str]: 72 | if not self._flow_id: 73 | return None 74 | flow_view = FlowView.from_flow_id(self._flow_id) 75 | if flow_view: 76 | return getattr(flow_view.run_config, "region", None) 77 | return None 78 | 79 | def get_service_account(self) -> Optional[str]: 80 | if not self._flow_id: 81 | return 82 | 83 | flow_view = FlowView.from_flow_id(self._flow_id) 84 | run_config = getattr(flow_view, "run_config", None) 85 | service_account = getattr(run_config, "service_account", None) 86 | 87 | # Check environment for service account. 88 | # It will be injected by Google later. 89 | service_account = ( 90 | service_account 91 | or os.getenv("SERVICE_ACCOUNT") 92 | or os.getenv("CLOUD_ML_JOB_SA") 93 | ) 94 | return service_account 95 | -------------------------------------------------------------------------------- /block_cascade/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Dict, Optional, Union 4 | 5 | import yaml 6 | 7 | try: 8 | from yaml import CSafeLoader as SafeLoader 9 | except ImportError: 10 | from yaml import SafeLoader 11 | 12 | from block_cascade.executors.databricks.resource import DatabricksResource 13 | from block_cascade.executors.vertex.resource import GcpResource 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | GCP_RESOURCE = GcpResource.__name__ 19 | DATABRICKS_RESOURCE = DatabricksResource.__name__ 20 | SUPPORTED_FILENAMES = ("cascade.yaml", "cascade.yml") 21 | ACCEPTED_TYPES = (GCP_RESOURCE, DATABRICKS_RESOURCE) 22 | 23 | 24 | def _merge(a: dict, b: dict) -> dict: 25 | """ 26 | Deep merges two dictionaries with values in b 27 | overriding values in a (if present). 28 | 29 | e.g.
_merge({"hello": "world}, {"hello", "goodbye"}) -> {"hello": "goodbye"} 30 | 31 | This function is focused primarily at merging the values of nested dictionaries 32 | with all other types being simply overriden if keys are found in both input 33 | dictionaries. 34 | """ 35 | merged = {} 36 | 37 | for key, val in a.items(): 38 | if key not in b: 39 | merged[key] = val 40 | 41 | for key, val in b.items(): 42 | if key not in a: 43 | merged[key] = val 44 | else: 45 | val_2 = a[key] 46 | if isinstance(val, dict) and isinstance(val_2, dict): 47 | merged[key] = _merge(val_2, val) 48 | else: 49 | merged[key] = val 50 | return merged 51 | 52 | 53 | def find_default_configuration( 54 | root: str = ".", 55 | ) -> Optional[Dict[str, Union[GcpResource, DatabricksResource]]]: 56 | for filename in SUPPORTED_FILENAMES: 57 | potential_configuration_path = os.path.join(root, filename) 58 | if not os.path.exists(potential_configuration_path): 59 | continue 60 | 61 | logger.info(f"Found cascade configuration at {potential_configuration_path}") 62 | with open(potential_configuration_path) as f: 63 | configuration = yaml.load(f, SafeLoader) 64 | 65 | job_configurations = {} 66 | for job_name, resource_definition in configuration.items(): 67 | if job_name == "default": 68 | continue 69 | default_resource_definition = configuration.get("default", {}) 70 | resource_type = resource_definition.pop("type") 71 | if resource_type not in ACCEPTED_TYPES: 72 | raise ValueError( 73 | f"Only types: {','.join(ACCEPTED_TYPES)} are supported for resource definitions." # noqa: E501 74 | ) 75 | elif resource_type == GCP_RESOURCE: 76 | merged_resource_definition = _merge( 77 | default_resource_definition.get(GCP_RESOURCE, {}), 78 | resource_definition, 79 | ) 80 | job_configurations[job_name] = GcpResource(**merged_resource_definition) 81 | else: 82 | merged_resource_definition = _merge( 83 | default_resource_definition.get(DATABRICKS_RESOURCE, {}), 84 | resource_definition, 85 | ) 86 | job_configurations[job_name] = DatabricksResource( 87 | **merged_resource_definition 88 | ) 89 | return job_configurations 90 | return None 91 | -------------------------------------------------------------------------------- /block_cascade/prefect/v3/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional, Union 3 | 4 | import prefect 5 | from prefect import get_run_logger 6 | from prefect import runtime 7 | from prefect.client.orchestration import get_client 8 | from prefect.client.schemas.responses import DeploymentResponse 9 | from prefect.context import FlowRunContext, TaskRunContext 10 | from prefect.server.schemas.core import BlockDocument 11 | 12 | from block_cascade.concurrency import run_async 13 | 14 | 15 | _CACHED_DEPLOYMENT: Optional[DeploymentResponse] = None 16 | _CACHED_STORAGE: Optional[BlockDocument] = None 17 | 18 | 19 | async def _fetch_deployment(deployment_id: str) -> BlockDocument: 20 | async with get_client() as client: 21 | return await client.read_deployment(deployment_id) 22 | 23 | async def _fetch_block_by_name(block_name: str, block_type_slug: str = "gcs-bucket") -> Optional[BlockDocument]: 24 | async with get_client() as client: 25 | return await client.read_block_document_by_name( 26 | name=block_name, 27 | block_type_slug=block_type_slug, 28 | ) 29 | 30 | def get_from_prefect_context(attr: str, default: str = "") -> str: 31 | flow_context = FlowRunContext.get() 32 | task_context = TaskRunContext.get() 33 | if not flow_context or 
not task_context: 34 | return default 35 | 36 | if attr == "flow_name" or attr == "flow_run_name": # noqa: PLR1714 37 | return str(getattr(flow_context.flow_run, "name", default)) 38 | if attr == "flow_id" or attr == "flow_run_id": # noqa: PLR1714 39 | return str(getattr(flow_context.flow_run, "id", default)) 40 | if attr == "task_run" or attr == "task_full_name": # noqa: PLR1714 41 | return str(getattr(task_context.task_run, "name", default)) 42 | if attr == "task_run_id": 43 | return str(getattr(task_context.task_run, "id", default)) 44 | raise RuntimeError(f"Unsupported attribute: {attr}.") 45 | 46 | 47 | def get_current_deployment() -> Optional[DeploymentResponse]: 48 | deployment_id = runtime.deployment.id 49 | if not deployment_id: 50 | return None 51 | 52 | global _CACHED_DEPLOYMENT # noqa: PLW0603 53 | if not _CACHED_DEPLOYMENT: 54 | _CACHED_DEPLOYMENT = run_async( 55 | _fetch_deployment(deployment_id) 56 | ) 57 | return _CACHED_DEPLOYMENT 58 | 59 | 60 | def get_storage_block() -> Optional[BlockDocument]: 61 | current_deployment = get_current_deployment() 62 | if not current_deployment: 63 | return None 64 | 65 | global _CACHED_STORAGE # noqa: PLW0603 66 | if not _CACHED_STORAGE: 67 | _CACHED_STORAGE = run_async( 68 | _fetch_block_by_name(block_name=current_deployment.pull_steps[0]["prefect.deployments.steps.pull_with_block"]["block_document_name"]) 69 | ) 70 | return _CACHED_STORAGE 71 | 72 | 73 | def get_prefect_logger(name: str = "") -> Union[logging.LoggerAdapter, logging.Logger]: 74 | """ 75 | Tries to get the prefect run logger, and if it is not available, 76 | gets the root logger. 77 | """ 78 | try: 79 | return get_run_logger() 80 | except prefect.exceptions.MissingContextError: 81 | # if empty string is passed, 82 | # obtains access to the root logger 83 | logger = logging.getLogger(name) 84 | logger.setLevel(logging.INFO) 85 | return logger 86 | 87 | 88 | def is_prefect_cloud_deployment() -> bool: 89 | return runtime.deployment.id is not None 90 | -------------------------------------------------------------------------------- /block_cascade/executors/databricks/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | import sys 4 | import os 5 | 6 | try: 7 | import cloudpickle 8 | except ImportError: 9 | import pickle as cloudpickle # Databricks Runtime 11+ renames cloudpickle to pickle... # noqa: E501 10 | 11 | INPUT_FILENAME = "function.pkl" 12 | OUTPUT_FILENAME = "output.pkl" 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def run(): 19 | storage_location, _ = sys.argv[1], sys.argv[2] 20 | 21 | # Detect storage type: S3 (cluster compute) or Unity Catalog Volumes (serverless compute) 22 | if storage_location.startswith("s3://"): 23 | _run_s3(storage_location) 24 | elif storage_location.startswith("/Volumes/"): 25 | _run_volumes(storage_location) 26 | else: 27 | raise ValueError( 28 | f"Unsupported storage location: {storage_location}. " 29 | "Must be /Volumes/ (serverless) or s3:// (cluster compute)."
30 | ) 31 | 32 | 33 | def _run_s3(bucket_location: str): 34 | """Run with S3 storage backend.""" 35 | import boto3 36 | 37 | s3_bucket, object_path = bucket_location.replace("s3://", "").split("/", 1) 38 | 39 | try: 40 | s3 = boto3.resource("s3") 41 | func = cloudpickle.loads( 42 | s3.Bucket(s3_bucket) 43 | .Object(f"{object_path}/{INPUT_FILENAME}") 44 | .get()["Body"] 45 | .read() 46 | ) 47 | logger.info("Starting execution") 48 | 49 | result = func() 50 | 51 | logger.info(f"Saving output of task to {bucket_location}/{OUTPUT_FILENAME}") 52 | try: 53 | s3.Bucket(s3_bucket).Object(f"{object_path}/{OUTPUT_FILENAME}").put( 54 | Body=pickle.dumps(result) 55 | ) 56 | except RuntimeError as e: 57 | logger.error( 58 | "Failed to serialize user function return value. Be sure not to return " 59 | "Spark objects from user functions. For example, you should convert " 60 | "Spark dataframes to Pandas dataframes before returning." 61 | ) 62 | raise e 63 | except RuntimeError as e: 64 | logger.error("Failed to execute user function") 65 | raise e 66 | 67 | 68 | def _run_volumes(storage_location: str): 69 | """Run with Unity Catalog Volumes storage backend. 70 | 71 | Unity Catalog Volumes are accessed as regular filesystem paths from within 72 | Databricks clusters. They provide proper security and permissions through 73 | Unity Catalog governance. 74 | """ 75 | input_path = os.path.join(storage_location, INPUT_FILENAME) 76 | output_path = os.path.join(storage_location, OUTPUT_FILENAME) 77 | 78 | try: 79 | # Read pickled function directly from filesystem 80 | with open(input_path, "rb") as f: 81 | func = cloudpickle.load(f) 82 | 83 | logger.info("Starting execution") 84 | result = func() 85 | 86 | logger.info(f"Saving output of task to {output_path}") 87 | try: 88 | with open(output_path, "wb") as f: 89 | pickle.dump(result, f) 90 | except RuntimeError as e: 91 | logger.error( 92 | "Failed to serialize user function return value. Be sure not to return " 93 | "Spark objects from user functions. For example, you should convert " 94 | "Spark dataframes to Pandas dataframes before returning." 
95 | ) 96 | raise e 97 | except RuntimeError as e: 98 | logger.error("Failed to execute user function") 99 | raise e 100 | 101 | 102 | if __name__ == "__main__": 103 | # This module is run as main in order to execute a task on a worker 104 | run() 105 | -------------------------------------------------------------------------------- /tests/executors/databricks/resource/test_python_library.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | from block_cascade.executors.databricks.resource import DatabricksPythonLibrary 3 | 4 | 5 | def test_python_library_model_dump(): 6 | library = DatabricksPythonLibrary(name="example-package", version="1.2.3", repo="https://example.com/pypi") 7 | expected_output = { 8 | "pypi": { 9 | "package": "example-package==1.2.3", 10 | "repo": "https://example.com/pypi" 11 | } 12 | } 13 | assert library.model_dump() == expected_output 14 | 15 | library = DatabricksPythonLibrary(name="example-package", infer_version=False) 16 | expected_output = { 17 | "pypi": { 18 | "package": "example-package" 19 | } 20 | } 21 | assert library.model_dump() == expected_output 22 | 23 | def test_python_library_infer_version(): 24 | with patch("block_cascade.executors.databricks.resource.importlib.metadata.version") as mock_version: 25 | mock_version.return_value = "4.5.6" 26 | library = DatabricksPythonLibrary(name="example-package", infer_version=True) 27 | assert library.version == "4.5.6" 28 | 29 | 30 | def test_databricks_resource_string_library_conversion(): 31 | from block_cascade.executors.databricks.resource import DatabricksResource 32 | 33 | resource = DatabricksResource( 34 | storage_location="s3://test-bucket/cascade", 35 | spark_version="11.3.x-scala2.12", 36 | python_libraries=["test-package"] 37 | ) 38 | 39 | assert len(resource.python_libraries) == 1 40 | assert isinstance(resource.python_libraries[0], DatabricksPythonLibrary) 41 | assert resource.python_libraries[0].name == "test-package" 42 | 43 | resource = DatabricksResource( 44 | storage_location="s3://test-bucket/cascade", 45 | spark_version="11.3.x-scala2.12", 46 | python_libraries=[ 47 | "package1", 48 | DatabricksPythonLibrary(name="package2", version="1.0.0") 49 | ] 50 | ) 51 | 52 | assert len(resource.python_libraries) == 2 53 | assert isinstance(resource.python_libraries[0], DatabricksPythonLibrary) 54 | assert isinstance(resource.python_libraries[1], DatabricksPythonLibrary) 55 | assert resource.python_libraries[0].name == "package1" 56 | assert resource.python_libraries[1].name == "package2" 57 | assert resource.python_libraries[1].version == "1.0.0" 58 | 59 | 60 | def test_databricks_resource_string_with_version_conversion(): 61 | from block_cascade.executors.databricks.resource import DatabricksResource 62 | 63 | resource = DatabricksResource( 64 | storage_location="s3://test-bucket/cascade", 65 | spark_version="11.3.x-scala2.12", 66 | python_libraries=["cloudpickle==0.10.0"] 67 | ) 68 | 69 | assert len(resource.python_libraries) == 1 70 | assert isinstance(resource.python_libraries[0], DatabricksPythonLibrary) 71 | assert resource.python_libraries[0].name == "cloudpickle" 72 | assert resource.python_libraries[0].version == "0.10.0" 73 | 74 | resource = DatabricksResource( 75 | storage_location="s3://test-bucket/cascade", 76 | spark_version="11.3.x-scala2.12", 77 | python_libraries=[ 78 | "numpy==1.22.4", 79 | "pandas==2.0.0", 80 | DatabricksPythonLibrary(name="scikit-learn", version="1.2.2") 81 | ] 82 | ) 83 | 84 | assert 
len(resource.python_libraries) == 3 85 | assert isinstance(resource.python_libraries[0], DatabricksPythonLibrary) 86 | assert isinstance(resource.python_libraries[1], DatabricksPythonLibrary) 87 | assert isinstance(resource.python_libraries[2], DatabricksPythonLibrary) 88 | assert resource.python_libraries[0].name == "numpy" 89 | assert resource.python_libraries[0].version == "1.22.4" 90 | assert resource.python_libraries[1].name == "pandas" 91 | assert resource.python_libraries[1].version == "2.0.0" 92 | assert resource.python_libraries[2].name == "scikit-learn" 93 | assert resource.python_libraries[2].version == "1.2.2" -------------------------------------------------------------------------------- /tests/test_vertex_executor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from unittest.mock import PropertyMock, patch 4 | 5 | import cloudpickle 6 | from fsspec.implementations.local import LocalFileSystem 7 | from google.cloud.aiplatform_v1beta1.types import job_state 8 | 9 | from block_cascade import GcpMachineConfig, GcpResource 10 | from block_cascade.executors.vertex.executor import ( 11 | Status, 12 | VertexCancelledError, 13 | VertexExecutor, 14 | ) 15 | from block_cascade.executors.vertex.job import VertexJob 16 | from block_cascade.utils import wrapped_partial 17 | from tests.resource_fixtures import gcp_environment 18 | 19 | 20 | # create a basic function 21 | def add(a: int, b: int) -> int: 22 | return a + b 23 | 24 | 25 | # status of job as global variables 26 | CANCELLED_STATUS = Status(job_state.JobState.JOB_STATE_CANCELLED, "test job cancelled") 27 | STAGE_METHOD = "block_cascade.executors.vertex.executor.Executor._stage" 28 | STATUS_METHOD = "block_cascade.executors.vertex.executor.VertexExecutor._get_status" 29 | START_METHOD = "block_cascade.executors.vertex.executor.VertexExecutor._start" 30 | VERTEX_PROPERTY = "block_cascade.executors.vertex.executor.VertexExecutor.vertex" 31 | STORAGE_PATH = "block_cascade.executors.vertex.executor.VertexExecutor.storage_path" 32 | FILESYSTEM = "block_cascade.executors.vertex.executor.gcsfs.GCSFileSystem" 33 | 34 | # Create a GCP resource 35 | machine_config = GcpMachineConfig(type="n1-standard-1") 36 | environment_config = gcp_environment 37 | 38 | gcp_resource = GcpResource( 39 | chief=machine_config, 40 | environment=environment_config, 41 | ) 42 | 43 | @pytest.fixture 44 | def vertex_executor_fixture(): 45 | with patch(STAGE_METHOD) as stage_mock, \ 46 | patch(START_METHOD, return_value="test_job") as start_mock, \ 47 | patch(STATUS_METHOD, return_value=CANCELLED_STATUS) as status_mock, \ 48 | patch(VERTEX_PROPERTY, return_value="dummy_api"), \ 49 | patch(FILESYSTEM, LocalFileSystem): 50 | 51 | vertex_executor = VertexExecutor( 52 | resource=gcp_resource, 53 | func=wrapped_partial(add, 1, 2), 54 | ) 55 | vertex_executor.storage_location = ( 56 | f"{os.path.expanduser('~')}/cascade-storage" 57 | ) 58 | 59 | stage_mock.reset_mock() 60 | start_mock.reset_mock() 61 | status_mock.reset_mock() 62 | 63 | yield vertex_executor, stage_mock, start_mock, status_mock 64 | 65 | def test_run(vertex_executor_fixture): 66 | """ 67 | Tests that the VertexExecutor.run() method calls the correct private methods 68 | """ 69 | # swap the fs for a local fs 70 | vertex_executor, stage_mock, start_mock, status_mock = vertex_executor_fixture 71 | 72 | with pytest.raises(VertexCancelledError): 73 | vertex_executor.run() 74 | 75 | start_mock.assert_called_once() 76 | 
status_mock.assert_called_once() 77 | stage_mock.assert_called_once() 78 | 79 | def test_create_job(vertex_executor_fixture): 80 | """ 81 | Tests that a VertexJob can be created from a VertexExecutor. 82 | """ 83 | vertex_executor = vertex_executor_fixture[0] 84 | test_job = vertex_executor.create_job() 85 | assert isinstance(test_job, VertexJob) 86 | 87 | custom_job = test_job.create_payload() 88 | assert isinstance(custom_job, dict) 89 | 90 | def test_stage(tmp_path): 91 | """ 92 | Tests that the VertexExecutor._stage() correctly stages a function 93 | """ 94 | with patch(VERTEX_PROPERTY, return_value="dummy_api"), \ 95 | patch(FILESYSTEM, LocalFileSystem): 96 | 97 | executor = VertexExecutor( 98 | resource=gcp_resource, 99 | func=wrapped_partial(add, 1, 2), 100 | ) 101 | executor._fs = LocalFileSystem(auto_mkdir=True) 102 | executor.storage_location = str(tmp_path) 103 | 104 | executor._stage() 105 | 106 | with executor.fs.open(executor.staged_filepath, "rb") as f: 107 | func = cloudpickle.load(f) 108 | 109 | assert func() == 3 110 | 111 | -------------------------------------------------------------------------------- /block_cascade/prefect/v2/environment.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | from prefect import runtime 4 | 5 | from block_cascade.concurrency import run_async 6 | from block_cascade.gcp import VertexAIEnvironmentInfoProvider 7 | from block_cascade.utils import PREFECT_SUBVERSION, PREFECT_VERSION 8 | 9 | if (PREFECT_SUBVERSION <= 7) & (PREFECT_VERSION < 3): 10 | from prefect.orion.schemas.core import BlockDocument 11 | else: 12 | from prefect.server.schemas.core import BlockDocument 13 | 14 | from prefect.client.schemas.responses import DeploymentResponse 15 | 16 | from block_cascade.prefect.v2 import _fetch_block, _fetch_deployment 17 | 18 | 19 | class PrefectEnvironmentClient(VertexAIEnvironmentInfoProvider): 20 | """ 21 | A client for fetching Deployment related 22 | metadata from a Prefect 2 Flow. 
23 | """ 24 | 25 | def __init__(self): 26 | self._current_deployment = None 27 | self._current_job_variables = None 28 | self._current_infrastructure = None 29 | 30 | def get_container_image(self) -> Optional[str]: 31 | job_variables = self._get_job_variables() 32 | if job_variables: 33 | return job_variables.get("image") 34 | 35 | infra = self._get_infrastructure_block() 36 | if infra: 37 | return infra.data.get("image") 38 | return None 39 | 40 | def get_network(self) -> Optional[str]: 41 | job_variables = self._get_job_variables() 42 | if job_variables: 43 | return job_variables.get("network") 44 | 45 | infra = self._get_infrastructure_block() 46 | if infra: 47 | return infra.data.get("network") 48 | 49 | return None 50 | 51 | def get_project(self) -> Optional[str]: 52 | job_variables = self._get_job_variables() 53 | if job_variables: 54 | return job_variables.get("credentials", {}).get("project") 55 | 56 | infra = self._get_infrastructure_block() 57 | if infra: 58 | return infra.data.get("gcp_credentials", {}).get("project") 59 | 60 | return None 61 | 62 | def get_region(self) -> Optional[str]: 63 | job_variables = self._get_job_variables() 64 | if job_variables: 65 | return job_variables.get("region") 66 | 67 | infra = self._get_infrastructure_block() 68 | if infra: 69 | return infra.data.get("region") 70 | 71 | return None 72 | 73 | def get_service_account(self) -> Optional[str]: 74 | job_variables = self._get_job_variables() 75 | if job_variables: 76 | return job_variables.get("service_account_name") 77 | 78 | infra = self._get_infrastructure_block() 79 | if infra: 80 | return infra.data.get("service_account") 81 | 82 | return None 83 | 84 | def _get_job_variables(self) -> Optional[Dict]: 85 | current_deployment = self._get_current_deployment() 86 | if not current_deployment: 87 | return None 88 | 89 | if not self._current_job_variables: 90 | self._current_job_variables = current_deployment.job_variables 91 | return self._current_job_variables 92 | 93 | def _get_infrastructure_block(self) -> Optional[BlockDocument]: 94 | current_deployment = self._get_current_deployment() 95 | if not current_deployment: 96 | return None 97 | 98 | if not self._current_infrastructure: 99 | self._current_infrastructure = run_async( 100 | _fetch_block(current_deployment.infrastructure_document_id) 101 | ) 102 | return self._current_infrastructure 103 | 104 | def _get_current_deployment(self) -> Optional[DeploymentResponse]: 105 | deployment_id = runtime.deployment.id 106 | if not deployment_id: 107 | return None 108 | 109 | if not self._current_deployment: 110 | self._current_deployment = run_async( 111 | _fetch_deployment(deployment_id) 112 | ) 113 | return self._current_deployment 114 | -------------------------------------------------------------------------------- /block_cascade/prefect/v2/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional, Union 3 | 4 | import prefect 5 | from prefect import runtime 6 | 7 | from block_cascade.concurrency import run_async 8 | from block_cascade.utils import PREFECT_SUBVERSION, PREFECT_VERSION 9 | 10 | if (PREFECT_SUBVERSION <= 7) & (PREFECT_VERSION < 3): 11 | from prefect.orion.schemas.core import BlockDocument 12 | else: 13 | from prefect.server.schemas.core import BlockDocument 14 | 15 | try: 16 | from prefect.client.orchestration import get_client 17 | except: # noqa: E722 18 | from prefect.client import get_client 19 | 20 | try: 21 | 22 | from prefect import 
get_run_logger 23 | except: # noqa: E722 24 | from prefect.logging import get_run_logger 25 | 26 | from prefect.client.schemas.responses import DeploymentResponse 27 | from prefect.context import FlowRunContext, TaskRunContext 28 | 29 | _CACHED_DEPLOYMENT: Optional[DeploymentResponse] = None 30 | _CACHED_STORAGE: Optional[BlockDocument] = None 31 | 32 | 33 | async def _fetch_deployment(deployment_id: str) -> DeploymentResponse: 34 | async with get_client() as client: 35 | return await client.read_deployment(deployment_id) 36 | 37 | 38 | async def _fetch_block(block_id: str) -> Optional[BlockDocument]: 39 | async with get_client() as client: 40 | return await client.read_block_document(block_id) 41 | 42 | async def _fetch_block_by_name(block_name: str, block_type_slug: str = "gcs-bucket") -> Optional[BlockDocument]: 43 | async with get_client() as client: 44 | return await client.read_block_document_by_name( 45 | name=block_name, 46 | block_type_slug=block_type_slug, 47 | ) 48 | 49 | def get_from_prefect_context(attr: str, default: str = "") -> str: 50 | flow_context = FlowRunContext.get() 51 | task_context = TaskRunContext.get() 52 | if not flow_context or not task_context: 53 | return default 54 | 55 | if attr == "flow_name" or attr == "flow_run_name": # noqa: PLR1714 56 | return str(getattr(flow_context.flow_run, "name", default)) 57 | if attr == "flow_id" or attr == "flow_run_id": # noqa: PLR1714 58 | return str(getattr(flow_context.flow_run, "id", default)) 59 | if attr == "task_run" or attr == "task_full_name": # noqa: PLR1714 60 | return str(getattr(task_context.task_run, "name", default)) 61 | if attr == "task_run_id": 62 | return str(getattr(task_context.task_run, "id", default)) 63 | raise RuntimeError(f"Unsupported attribute: {attr}.") 64 | 65 | 66 | def get_current_deployment() -> Optional[DeploymentResponse]: 67 | deployment_id = runtime.deployment.id 68 | if not deployment_id: 69 | return None 70 | 71 | global _CACHED_DEPLOYMENT # noqa: PLW0603 72 | if not _CACHED_DEPLOYMENT: 73 | _CACHED_DEPLOYMENT = run_async( 74 | _fetch_deployment(deployment_id) 75 | ) 76 | return _CACHED_DEPLOYMENT 77 | 78 | 79 | def get_storage_block() -> Optional[BlockDocument]: 80 | current_deployment = get_current_deployment() 81 | if not current_deployment: 82 | return None 83 | 84 | global _CACHED_STORAGE # noqa: PLW0603 85 | if not _CACHED_STORAGE: 86 | if current_deployment.pull_steps: 87 | _CACHED_STORAGE = run_async( 88 | _fetch_block_by_name(block_name=current_deployment.pull_steps[0]["prefect.deployments.steps.pull_with_block"]["block_document_name"]) 89 | ) 90 | else: 91 | _CACHED_STORAGE = run_async( 92 | _fetch_block(block_id=current_deployment.storage_document_id) 93 | ) 94 | return _CACHED_STORAGE 95 | 96 | 97 | def get_prefect_logger(name: str = "") -> Union[logging.LoggerAdapter, logging.Logger]: 98 | """ 99 | Tries to get the prefect run logger, and if it is not available, 100 | gets the root logger.
101 | """ 102 | try: 103 | return get_run_logger() 104 | except prefect.exceptions.MissingContextError: 105 | # if empty string is passed, 106 | # obtains access to the root logger 107 | logger = logging.getLogger(name) 108 | logger.setLevel(logging.INFO) 109 | return logger 110 | 111 | 112 | def is_prefect_cloud_deployment() -> bool: 113 | return runtime.deployment.id is not None 114 | -------------------------------------------------------------------------------- /block_cascade/executors/executor.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from copy import copy 3 | import os 4 | from uuid import uuid4 5 | 6 | import cloudpickle 7 | from fsspec.implementations.local import LocalFileSystem 8 | 9 | INPUT_FILENAME = "function.pkl" 10 | OUTPUT_FILENAME = "output.pkl" 11 | 12 | 13 | class Executor(abc.ABC): 14 | """ 15 | Base class for an executor, which is the interface to run a function on some arbitrary resource. 16 | The main entry point for an executor is to call the run() method. 17 | A typical lifecycle for an executor is 18 | 1. Create an instance of the executor 19 | 2. Call the run() method, which calls the _run() method, which in turn calls _stage() and _start() 20 | and returns the result of the function by loading it from the output filepath and unpickling it. 21 | """ # noqa: E501 22 | 23 | def __init__(self, func): 24 | """ 25 | Args: 26 | func: Callable 27 | The function to be run by the executor. 28 | Note: 29 | * this function must be picklable 30 | * must expect no arguments 31 | * must have a __name__ attribute 32 | use block_cascade.utils.wrapped_partial to prepare a function for execution 33 | """ 34 | self.func = func 35 | self._fs = LocalFileSystem(auto_mkdir=True) 36 | self._storage_location = None 37 | self.name = None 38 | self.storage_key = str(uuid4()) 39 | 40 | @property 41 | def fs(self): 42 | return self._fs 43 | 44 | @fs.setter 45 | def fs(self, new_fs): 46 | self._fs = new_fs 47 | 48 | @property 49 | def input_filename(self): 50 | return INPUT_FILENAME 51 | 52 | @property 53 | def output_filename(self): 54 | return OUTPUT_FILENAME 55 | 56 | @property 57 | def func_name(self): 58 | """The name of the function being executed.""" 59 | try: 60 | name = self.name or self.func.__name__ 61 | except AttributeError: 62 | name = self.name or "unnamed" 63 | return name 64 | 65 | @property 66 | def storage_location(self): 67 | return self._storage_location 68 | 69 | @storage_location.setter 70 | def storage_location(self, storage_location: str): 71 | self._storage_location = storage_location 72 | 73 | @property 74 | def storage_path(self): 75 | return os.path.join(self.storage_location, self.storage_key) 76 | 77 | @property 78 | def output_filepath(self): 79 | """The path to the output file for the pickled result from the function.""" 80 | return os.path.join(self.storage_path, self.output_filename) 81 | 82 | @property 83 | def staged_filepath(self): 84 | return os.path.join(self.storage_path, self.input_filename) 85 | 86 | def _stage(self): 87 | """ 88 | Stage the job to run but do not run it 89 | """ 90 | 91 | with self.fs.open(self.staged_filepath, "wb") as f: 92 | cloudpickle.dump(self.func, f) 93 | 94 | @abc.abstractmethod 95 | def _start(self): 96 | """ 97 | Start the function on the resource.
98 | """ 99 | raise NotImplementedError 100 | 101 | def run(self): 102 | """ 103 | Run the child executor's _run function and return the result 104 | """ 105 | return self._run() 106 | 107 | @abc.abstractmethod 108 | def _run(self): 109 | """Run the function and save the output to self.output(name)""" 110 | raise NotImplementedError 111 | 112 | def _result(self): 113 | try: 114 | with self.fs.open(self.output_filepath, "rb") as f: 115 | result = cloudpickle.load(f) 116 | self.fs.rm(self.storage_path, recursive=True) 117 | except FileNotFoundError: 118 | raise FileNotFoundError( 119 | f"Could not find output file {self.output_filepath}" 120 | ) 121 | return result 122 | 123 | def with_(self, **kwargs): 124 | """ 125 | Convenience method for creating a copy of the executor with 126 | new attributes. 127 | """ 128 | instance = copy(self) 129 | for k, v in kwargs.items(): 130 | setattr(instance, k, v) 131 | return instance 132 | -------------------------------------------------------------------------------- /tests/test_databricks_executor.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from unittest.mock import patch 3 | 4 | from block_cascade import DatabricksResource 5 | from block_cascade.executors import DatabricksExecutor 6 | from block_cascade.executors.databricks.job import DatabricksJob 7 | from block_cascade.utils import wrapped_partial 8 | 9 | # Mock paths 10 | MOCK_CLUSTER_POLICY = ( 11 | "block_cascade.executors.DatabricksExecutor.get_cluster_policy_id_from_policy_name" 12 | ) 13 | MOCK__RUN = "block_cascade.executors.DatabricksExecutor._run" 14 | MOCK_FILESYSTEM = "block_cascade.executors.DatabricksExecutor.fs" 15 | MOCK_STORAGE_PATH = "block_cascade.executors.DatabricksExecutor.storage_path" 16 | 17 | DATABRICKS_GROUP = "cascade" 18 | 19 | databricks_resource = DatabricksResource( 20 | storage_location="s3://test-bucket/cascade", 21 | group_name=DATABRICKS_GROUP, 22 | spark_version="11.3.x-scala2.12", 23 | ) 24 | 25 | 26 | def addition(a: int, b: int) -> int: 27 | return a + b 28 | 29 | 30 | addition_packed = wrapped_partial(addition, 1, 2) 31 | 32 | 33 | def test_create_executor(): 34 | """Test that a DatabricksExecutor can be created.""" 35 | _ = DatabricksExecutor( 36 | func=addition_packed, 37 | resource=databricks_resource, 38 | ) 39 | 40 | 41 | @patch(MOCK_CLUSTER_POLICY, return_value="12345") 42 | def test_create_job(mock_cluster_policy): 43 | """Test that the create_job method returns a valid DatabricksJob object.""" 44 | executor = DatabricksExecutor( 45 | func=addition_packed, 46 | resource=databricks_resource, 47 | ) 48 | databricks_job = executor.create_job() 49 | assert isinstance(databricks_job, DatabricksJob) 50 | 51 | 52 | @patch(MOCK_CLUSTER_POLICY, return_value="12345") 53 | def test_infer_name(mock_cluster_policy): 54 | """Test that if no name is provided, the name is inferred correctly.""" 55 | executor = DatabricksExecutor( 56 | func=addition_packed, 57 | resource=databricks_resource, 58 | ) 59 | assert executor.name is None 60 | _ = executor.create_job() 61 | assert executor.name == "addition" 62 | 63 | partial_func = partial(addition, 1, 2) 64 | 65 | executor_partialfunc = DatabricksExecutor( 66 | func=partial_func, 67 | resource=databricks_resource, 68 | ) 69 | 70 | assert executor_partialfunc.name is None 71 | _ = executor_partialfunc.create_job() 72 | assert executor_partialfunc.name == "unnamed" 73 | 74 | 75 | def test_serverless_job_creation(): 76 | """Test that a serverless job is created
correctly without cluster configuration.""" 77 | serverless_resource = DatabricksResource( 78 | storage_location="s3://test-bucket/cascade", 79 | group_name=DATABRICKS_GROUP, 80 | use_serverless=True, 81 | python_libraries=["pandas", "numpy"], 82 | ) 83 | 84 | executor = DatabricksExecutor( 85 | func=addition_packed, 86 | resource=serverless_resource, 87 | ) 88 | 89 | # Should not require cluster policy lookup for serverless 90 | databricks_job = executor.create_job() 91 | assert isinstance(databricks_job, DatabricksJob) 92 | assert databricks_job.cluster_policy_id is None 93 | 94 | # Verify the payload structure 95 | payload = databricks_job.create_payload() 96 | task = payload["tasks"][0] 97 | 98 | # Task should reference environment by key 99 | assert "environment_key" in task 100 | assert task["environment_key"] == "default" 101 | 102 | # Task should NOT have cluster configuration 103 | assert task.get("existing_cluster_id") is None 104 | assert task.get("new_cluster") is None 105 | 106 | # Libraries should NOT be at task level for serverless 107 | assert "libraries" not in task or len(task.get("libraries", [])) == 0 108 | 109 | # Environments should be defined at job level 110 | assert "environments" in payload 111 | assert len(payload["environments"]) == 1 112 | 113 | env = payload["environments"][0] 114 | assert env["environment_key"] == "default" 115 | assert "spec" in env 116 | assert "dependencies" in env["spec"] 117 | assert "environment_version" in env["spec"] 118 | assert env["spec"]["environment_version"] == "3" 119 | 120 | # Verify dependencies include required libraries 121 | dependencies = env["spec"]["dependencies"] 122 | assert isinstance(dependencies, list) 123 | assert any("cloudpickle" in dep for dep in dependencies) 124 | assert any("prefect" in dep for dep in dependencies) 125 | assert any("pandas" in dep for dep in dependencies) 126 | assert any("numpy" in dep for dep in dependencies) 127 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | from pyfakefs.fake_filesystem import FakeFilesystem 2 | import pytest 3 | 4 | from block_cascade.config import find_default_configuration 5 | from block_cascade.executors.databricks.resource import ( 6 | DatabricksAutoscaleConfig, 7 | DatabricksResource, 8 | ) 9 | from block_cascade.executors.vertex.resource import ( 10 | GcpEnvironmentConfig, 11 | GcpMachineConfig, 12 | GcpResource, 13 | ) 14 | 15 | GCP_PROJECT = "test-project" 16 | GCP_STORAGE_LOCATION = f"gs://{GCP_PROJECT}-cascade/" 17 | 18 | 19 | @pytest.fixture(params=["cascade.yaml", "cascade.yml"]) 20 | def configuration_filename(request): 21 | return request.param 22 | 23 | 24 | @pytest.fixture() 25 | def storage_location(): 26 | return GCP_STORAGE_LOCATION 27 | 28 | 29 | @pytest.fixture() 30 | def gcp_project(): 31 | return GCP_PROJECT 32 | 33 | 34 | @pytest.fixture() 35 | def gcp_location(): 36 | return "us-central1" 37 | 38 | 39 | @pytest.fixture() 40 | def gcp_service_account(): 41 | return "test-project@test-project.iam.gserviceaccount.com" 42 | 43 | 44 | @pytest.fixture() 45 | def gcp_machine_config(): 46 | return GcpMachineConfig(type="n1-standard-4", count=2) 47 | 48 | 49 | @pytest.fixture 50 | def gcp_environment(storage_location, gcp_project, gcp_location, gcp_service_account): 51 | return GcpEnvironmentConfig( 52 | storage_location=storage_location, 53 | project=gcp_project, 54 | service_account=gcp_service_account, 55 | 
region=gcp_location, 56 | ) 57 | 58 | 59 | @pytest.fixture() 60 | def gcp_resource(gcp_environment, gcp_machine_config): 61 | return GcpResource(chief=gcp_machine_config, environment=gcp_environment) 62 | 63 | 64 | @pytest.fixture() 65 | def databricks_resource(): 66 | return DatabricksResource( 67 | storage_location="s3://test-bucket/cascade", 68 | worker_count=DatabricksAutoscaleConfig(min_workers=5, max_workers=10), 69 | cloud_pickle_by_value=["a", "b"], 70 | spark_version="11.3.x-scala2.12", 71 | ) 72 | 73 | 74 | @pytest.fixture() 75 | def test_job_name(): 76 | return "hello-world" 77 | 78 | 79 | def test_no_configuration(): 80 | assert find_default_configuration() is None 81 | 82 | 83 | def test_invalid_type_specified(fs: FakeFilesystem, configuration_filename: str): 84 | configuration = """ 85 | addition: 86 | type: AwsResource 87 | """ 88 | fs.create_file(configuration_filename, contents=configuration) 89 | with pytest.raises(ValueError): 90 | find_default_configuration() 91 | 92 | 93 | def test_gcp_resource( 94 | fs: FakeFilesystem, 95 | configuration_filename: str, 96 | gcp_resource: GcpResource, 97 | test_job_name: str, 98 | ): 99 | configuration = f""" 100 | {test_job_name}: 101 | type: GcpResource 102 | chief: 103 | type: {gcp_resource.chief.type} 104 | count: {gcp_resource.chief.count} 105 | environment: 106 | storage_location: {gcp_resource.environment.storage_location} 107 | project: {gcp_resource.environment.project} 108 | service_account: {gcp_resource.environment.service_account} 109 | region: {gcp_resource.environment.region} 110 | """ 111 | fs.create_file(configuration_filename, contents=configuration) 112 | assert gcp_resource == find_default_configuration()[test_job_name] 113 | 114 | 115 | def test_databricks_resource( 116 | fs: FakeFilesystem, 117 | configuration_filename: str, 118 | databricks_resource: DatabricksResource, 119 | test_job_name: str, 120 | ): 121 | configuration = f""" 122 | {test_job_name}: 123 | type: DatabricksResource 124 | storage_location: {databricks_resource.storage_location} 125 | worker_count: 126 | min_workers: {databricks_resource.worker_count.min_workers} 127 | max_workers: {databricks_resource.worker_count.max_workers} 128 | cloud_pickle_by_value: 129 | - a 130 | - b 131 | spark_version: {databricks_resource.spark_version} 132 | """ 133 | fs.create_file(configuration_filename, contents=configuration) 134 | assert databricks_resource == find_default_configuration()[test_job_name] 135 | 136 | 137 | def test_merged_resources( 138 | fs: FakeFilesystem, 139 | configuration_filename: str, 140 | test_job_name: str, 141 | gcp_resource: GcpResource, 142 | ): 143 | configuration = f""" 144 | default: 145 | GcpResource: 146 | environment: 147 | storage_location: {gcp_resource.environment.storage_location} 148 | project: "ds-cash-dev" 149 | service_account: {gcp_resource.environment.service_account} 150 | region: {gcp_resource.environment.region} 151 | {test_job_name}: 152 | type: GcpResource 153 | environment: 154 | project: {gcp_resource.environment.project} 155 | chief: 156 | type: {gcp_resource.chief.type} 157 | count: {gcp_resource.chief.count} 158 | """ 159 | fs.create_file(configuration_filename, contents=configuration) 160 | assert gcp_resource == find_default_configuration()[test_job_name] 161 | -------------------------------------------------------------------------------- /block_cascade/executors/vertex/distributed/torch_job.py: -------------------------------------------------------------------------------- 1 | from dataclasses import 
dataclass 2 | import json 3 | import logging 4 | import os 5 | from subprocess import check_call 6 | import sys 7 | from typing import List 8 | 9 | from block_cascade.executors.vertex.distributed.distributed_job import ( 10 | DistributedJobBase, 11 | ) 12 | from block_cascade.utils import INPUT_FILENAME, OUTPUT_FILENAME 13 | 14 | MASTER_PORT = "3333" # Hardcode MASTER_PORT for Vertex AI proxy compatibility 15 | RDZV_ID = "123456" # Can be any random string 16 | RDZV_BACKEND = "c10d" # c10d is the Pytorch-preferred backend, more info in the Pytorch "Rendezvous" docs # noqa: E501 17 | 18 | 19 | @dataclass 20 | class TorchJob(DistributedJobBase): 21 | """ 22 | Configure and run a distributed Pytorch training Job. 23 | """ 24 | 25 | # input data folder contents from gcs will be downloaded to /app/data 26 | input_data_gcs_path: str = None 27 | 28 | def get_data_from_gcs(self, gcs_path, local_destination="/app/data") -> None: 29 | if not os.path.exists(local_destination): 30 | os.makedirs(local_destination) 31 | check_call( 32 | [ 33 | "gsutil", 34 | "-m", 35 | "cp", 36 | "-r", 37 | gcs_path, 38 | local_destination, 39 | ] 40 | ) 41 | return None 42 | 43 | def get_cli_args(self) -> List[str]: 44 | """ Get all the args to pass to torchrun. This function takes no input, all 45 | args are either hardcoded or parsed from environment variables 46 | 47 | torchrun is a console script provided by Pytorch, more information on its use 48 | and available arguments is available in the Pytorch docs 49 | 50 | torchrun \ 51 | --nproc_per_node=auto \ 52 | --rdzv_id=$RDZV_ID \ 53 | --rdzv_backend=$RDZV_BACKEND \ 54 | --rdzv_endpoint=$RDZV_ADDR:$RDZV_PORT \ 55 | --rdzv_conf=is_host=$(if [[ $RANK != "0" ]]; then echo false;else echo true;fi) \ 56 | --nnodes=$NNODES \ 57 | -m \ 58 | $SCRIPT_PATH \ 59 | $JOB_INPUT 60 | """ # noqa: E501 61 | torchrun_target_module_path = ( 62 | "block_cascade.executors.vertex.distributed.torchrun_target" 63 | ) 64 | 65 | cluster_spec = json.loads( 66 | os.environ["CLUSTER_SPEC"] 67 | ) # Env var CLUSTER_SPEC injected by vertex 68 | 69 | input_path = f"{self.storage_path}/{INPUT_FILENAME}" 70 | output_path = f"{self.storage_path}/{OUTPUT_FILENAME}" 71 | 72 | if "workerpool1" in cluster_spec["cluster"]: 73 | # If there are additional workers beyond the chief node, setup multi-node 74 | # training 75 | master_addr, _ = cluster_spec["cluster"]["workerpool0"][0].split(":") 76 | 77 | nnodes = str(len(cluster_spec["cluster"]["workerpool1"]) + 1) 78 | 79 | rank = os.environ["RANK"] # Env var RANK injected by vertex 80 | is_chief = str( 81 | rank == "0" 82 | ).lower() # This is the only change in logic between chief and workers 83 | 84 | return [ 85 | "torchrun", 86 | "--nproc_per_node=auto", 87 | f"--rdzv_id={RDZV_ID}", 88 | f"--rdzv_backend={RDZV_BACKEND}", 89 | f"--rdzv_endpoint={master_addr}:{MASTER_PORT}", 90 | f"--rdzv_conf=is_host={is_chief}", 91 | f"--nnodes={nnodes}", 92 | "-m", 93 | torchrun_target_module_path, 94 | input_path, 95 | output_path, 96 | ] 97 | else: 98 | # Handle the special case where nnodes=="1", i.e. when there is 99 | # only a master node and no additional workers. 100 | return [ 101 | "torchrun", 102 | "--standalone", 103 | "--nproc_per_node=auto", 104 | "-m", 105 | torchrun_target_module_path, 106 | input_path, 107 | output_path, 108 | ] 109 | 110 | def _run(self) -> None: 111 | """ 112 | Optionally load input data from GCS and then launch the torchrun 113 | utility to run training.
114 | 115 | There is no switch logic for chief/master and worker nodes: 116 | this function is run on each training node, then the torchrun utility 117 | distributes the torchrun target script to each accelerator and runs 118 | the training script 119 | 120 | torchrun is a console script provided by Pytorch, more information on its use 121 | here: https://pytorch.org/docs/stable/elastic/run.html 122 | """ 123 | if self.input_data_gcs_path: 124 | self.get_data_from_gcs(self.input_data_gcs_path) 125 | 126 | cli_args = self.get_cli_args() 127 | logging.info(f"Running torchrun with args {cli_args}") 128 | check_call( 129 | cli_args, 130 | stdout=sys.stdout, 131 | stderr=sys.stderr, 132 | ) 133 | -------------------------------------------------------------------------------- /block_cascade/executors/vertex/resource.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Iterable, Optional, TypeVar 3 | 4 | from pydantic import BaseModel, Field, ValidationInfo, field_validator 5 | 6 | from block_cascade.executors.vertex.distributed.distributed_job import ( 7 | DistributedJobBase, 8 | ) 9 | 10 | T = TypeVar("T", bound="GcpEnvironmentConfig") 11 | 12 | 13 | class GcpAcceleratorConfig(BaseModel): 14 | """ 15 | Description of a GPU accelerator to attach to a machine. Accelerator type and count 16 | must be compatible with the machine type. 17 | See https://cloud.google.com/vertex-ai/docs/training/configure-compute#accelerators 18 | for valid machine_type, accelerator_type and count combinations. 19 | count: int = 1 20 | type: str = 'NVIDIA_TESLA_T4' 21 | """ 22 | 23 | count: int = 1 24 | type: str = "NVIDIA_TESLA_T4" 25 | 26 | 27 | class NfsMountConfig(BaseModel): 28 | """ 29 | Description of an NFS mount to attach to a machine. 30 | """ 31 | 32 | server: str 33 | path: str 34 | mount_point: str 35 | 36 | 37 | class GcpMachineConfig(BaseModel): 38 | """ 39 | Description of a VM type that will be provisioned for a job in GCP. 40 | GcpResources are composed of one or more machines. 41 | 42 | type: str = 'n2-standard-4' 43 | VertexAI machine type, default is n2-standard-4 44 | See https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types 45 | https://cloud.google.com/compute/docs/machine-resource#recommendations_for_machine_types 46 | count: int = 1 47 | The number of machines to provision in this node pool. 48 | min_replica_count: Optional[int] = None 49 | The minimum number of replicas to provision for this node pool. Only relevant for creating 50 | persistent resources 51 | max_replica_count: Optional[int] = None 52 | The maximum number of replicas to provision for this node pool. Only relevant for creating 53 | persistent resources 54 | accelerator: Optional[GcpAcceleratorConfig] = None 55 | Description of a GPU accelerator to attach to the machine. 56 | See https://cloud.google.com/vertex-ai/docs/training/configure-compute#accelerators 57 | disk_size_gb: Optional[int] = None 58 | Size of the boot disk in GB. If None, uses default size for machine type. 59 | nfs_mounts: Optional[Iterable[NfsMountConfig]] = None 60 | List of NFS mounts to attach to the machine. Specified via NfsMountConfig objects.
61 | """ # noqa: E501 62 | 63 | type: str = "n2-standard-4" 64 | count: int = 1 65 | min_replica_count: Optional[int] = None 66 | max_replica_count: Optional[int] = None 67 | accelerator: Optional[GcpAcceleratorConfig] = None 68 | disk_size_gb: Optional[int] = None 69 | nfs_mounts: Optional[Iterable[NfsMountConfig]] = None 70 | 71 | 72 | class GcpEnvironmentConfig(BaseModel, validate_assignment=True): 73 | """ 74 | Description of the specific GCP environment in which a job will run. 75 | A valid project and service account are required. 76 | 77 | storage_location: str 78 | Path to the directory on GCS where files will be staged and output written 79 | project: Optional[str] 80 | GCP Project used to launch job. 81 | service_account: Optional[str] = None 82 | The name of the service account that will be used for the job. 83 | region: Optional[str] = None 84 | The region in which to start the job. 85 | network: Optional[str] = None 86 | The name of the virtual network in which to start the job 87 | image: Optional[str] = None 88 | The full URL of the image or just the path component following 89 | the project name in the container registry URL. 90 | """ 91 | 92 | storage_location: str 93 | project: Optional[str] = None 94 | service_account: Optional[str] = None 95 | region: Optional[str] = None 96 | network: Optional[str] = None 97 | image: Optional[str] = None 98 | 99 | @field_validator("image", mode="after") 100 | @classmethod 101 | def image_setter(cls, v: Optional[str], info: ValidationInfo) -> Optional[str]: # noqa: N805 102 | image = v 103 | if image is None: 104 | return image 105 | # Full URL 106 | elif image and ("/" in image): 107 | return image 108 | # Just the image tag 109 | else: 110 | return f"us.gcr.io/{info.data['project']}/{image}" 111 | 112 | @property 113 | def is_complete(self): 114 | """ 115 | Determines if the environment config has all required fields to launch a 116 | remote VertexAI job from cascade outside a Prefect context. 117 | Note that network is not required. 118 | """ 119 | return all([self.project, self.service_account, self.region, self.image]) 120 | 121 | 122 | class GcpResource(BaseModel): 123 | """ 124 | Description of a GCP computing resource and its environment. 125 | A resource consists of a GcpEnvironmentConfig and one or more GcpMachineConfigs 126 | 127 | chief: GcpMachineConfig 128 | A config describing the chief worker pool. 129 | workers: Optional[GcpMachineConfig] = None 130 | The machine type of the worker machines. 131 | environment: Optional[GcpEnvironmentConfig] = None 132 | The GCP environment in which to run the job. If none, the environment will be 133 | inferred from the current Prefect context. 134 | persistent_resource_id: Optional[str] = None 135 | 136 | Set accelerators for GPU training by passing a `GcpAcceleratorConfig` 137 | to the chief or worker machine config object.
138 | """ 139 | 140 | chief: GcpMachineConfig = Field(default_factory=GcpMachineConfig) 141 | workers: Optional[GcpMachineConfig] = None 142 | environment: Optional[GcpEnvironmentConfig] = None 143 | distributed_job: Optional[DistributedJobBase] = None 144 | persistent_resource_id: Optional[str] = None 145 | -------------------------------------------------------------------------------- /tests/test_torch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import pytest 5 | 6 | from block_cascade.executors.vertex import ( 7 | GcpAcceleratorConfig, 8 | GcpEnvironmentConfig, 9 | GcpMachineConfig, 10 | GcpResource, 11 | ) 12 | from block_cascade.executors.vertex.distributed.torch_job import TorchJob 13 | from block_cascade import remote 14 | 15 | torch = pytest.importorskip("torch") 16 | 17 | from torch import nn, optim # noqa 18 | import torch.nn.functional as F # noqa 19 | from torch.nn.parallel import DistributedDataParallel as DDP # noqa 20 | from torch.utils.data.distributed import DistributedSampler # noqa 21 | from torchvision import datasets, transforms # noqa 22 | 23 | 24 | # ====================== Setup test Pytorch model to train ====================== 25 | 26 | 27 | class ConvNet(nn.Module): 28 | def __init__(self): 29 | super().__init__() 30 | self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1) 31 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 32 | self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 33 | self.fc1 = nn.Linear(32 * 7 * 7, 64) 34 | self.dropout = nn.Dropout(p=0.5) 35 | self.fc2 = nn.Linear(64, 10) 36 | 37 | def forward(self, x): 38 | x = nn.functional.relu(self.conv1(x)) 39 | x = self.pool(x) 40 | x = nn.functional.relu(self.conv2(x)) 41 | x = self.pool(x) 42 | x = x.view(-1, 32 * 7 * 7) 43 | x = nn.functional.relu(self.fc1(x)) 44 | x = self.dropout(x) 45 | x = nn.functional.softmax(self.fc2(x), dim=1) 46 | return x 47 | 48 | 49 | class Trainer: 50 | def __init__( 51 | self, 52 | model: torch.nn.Module, 53 | train_data: torch.utils.data.DataLoader, 54 | optimizer: torch.optim.Optimizer, 55 | ) -> None: 56 | self.local_rank = int(os.environ["LOCAL_RANK"]) 57 | self.global_rank = int(os.environ["RANK"]) 58 | self.model = model.to(self.local_rank) 59 | self.train_data = train_data 60 | self.optimizer = optimizer 61 | self.epochs_run = 0 62 | self.model = DDP(self.model, device_ids=[self.local_rank]) 63 | 64 | def _run_batch(self, source, targets): 65 | self.optimizer.zero_grad() 66 | output = self.model(source) 67 | loss = F.cross_entropy(output, targets) 68 | loss.backward() 69 | self.optimizer.step() 70 | 71 | def _run_epoch(self, epoch): 72 | b_sz = len(next(iter(self.train_data))[0]) 73 | logging.info( 74 | f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}" # noqa: E501 75 | ) 76 | for source, targets in self.train_data: 77 | source = source.to(self.local_rank) # noqa: PLW2901 78 | targets = targets.to(self.local_rank) # noqa: PLW2901 79 | self._run_batch(source, targets) 80 | 81 | def train(self, max_epochs: int): 82 | for epoch in range(self.epochs_run, max_epochs): 83 | self._run_epoch(epoch) 84 | logging.info( 85 | f"Epoch {epoch} done on worker [{self.global_rank}/{self.local_rank}]" 86 | ) 87 | 88 | snapshot = {} 89 | snapshot["MODEL_STATE"] = self.model.module.state_dict() 90 | snapshot["EPOCHS_RUN"] = epoch 91 | return snapshot 92 | 93 | 94 | def load_train_objs(): 95 | train_set = datasets.MNIST( 96 | root="/app/data/", 
train=True, transform=transforms.ToTensor(), download=True 97 | ) 98 | 99 | model = ConvNet().cuda() 100 | optimizer = optim.Adam(model.parameters()) 101 | return train_set, model, optimizer 102 | 103 | 104 | def prepare_dataloader(dataset: torch.utils.data.Dataset, batch_size: int): 105 | return torch.utils.data.DataLoader( 106 | dataset, 107 | batch_size=batch_size, 108 | pin_memory=True, 109 | shuffle=False, 110 | sampler=DistributedSampler(dataset), 111 | ) 112 | 113 | 114 | # ====================== Code to run the test ====================== 115 | 116 | 117 | @pytest.mark.skipif( 118 | "BUILDKITE_PIPELINE" in os.environ, 119 | reason="Run integration tests locally only", 120 | ) 121 | def test_torchjob(): 122 | """ 123 | Test launching a Pytorch training job on Vertex 124 | """ 125 | ACCEL_MACHINE_TYPE = "NVIDIA_TESLA_V100" # noqa: N806 126 | WORKER_COUNT = 2 # noqa: N806 127 | accelerator_config = GcpAcceleratorConfig(count=2, type=ACCEL_MACHINE_TYPE) 128 | 129 | environment = GcpEnvironmentConfig( 130 | storage_location="gs://ds-cash-production-cascade/", 131 | project="ds-cash-production", 132 | service_account="ds-cash-production@ds-cash-production.iam.gserviceaccount.com", 133 | region="us-west1", 134 | network="projects/123456789123/global/networks/shared-vpc", 135 | image="us.gcr.io/ds-cash-production/cascade/cascade-test", 136 | ) 137 | 138 | resource = GcpResource( 139 | chief=GcpMachineConfig(type="n1-standard-16", accelerator=accelerator_config), 140 | workers=GcpMachineConfig( 141 | type="n1-standard-16", count=WORKER_COUNT, accelerator=accelerator_config 142 | ), 143 | environment=environment, 144 | distributed_job=TorchJob(), 145 | ) 146 | 147 | @remote(resource=resource, job_name="torchjob-cascade-test") 148 | def training_task(): 149 | total_epochs = 10 150 | 151 | dataset, model, optimizer = load_train_objs() 152 | train_data = prepare_dataloader(dataset, batch_size=64) 153 | trainer = Trainer(model, train_data, optimizer) 154 | result = trainer.train(total_epochs) 155 | 156 | return result 157 | 158 | result_dict = training_task() 159 | logging.info(f"Result: {result_dict}") 160 | 161 | # Task/Executor run is expected to return a dictionary containing at the bare 162 | # minimum a Pytorch state_dict describing a Pytorch model. This dictionary 163 | # should be directly accessible from the executor.run-returned object, i.e. it 164 | # does not need to be unpickled or loaded from a remote filestore 165 | assert "MODEL_STATE" in result_dict 166 | -------------------------------------------------------------------------------- /block_cascade/executors/databricks/job.py: -------------------------------------------------------------------------------- 1 | """ Data model for a task running on Databricks 2 | """ 3 | import logging 4 | from typing import Any, Optional 5 | 6 | from block_cascade.executors.databricks.resource import ( 7 | DatabricksAutoscaleConfig, 8 | DatabricksResource, 9 | DatabricksPythonLibrary 10 | ) 11 | 12 | from pydantic import BaseModel 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class DatabricksJob(BaseModel): 18 | """A description of a job to run on Databricks 19 | 20 | Attributes 21 | ---------- 22 | name 23 | The job display name on Databricks 24 | resource: DatabricksResource 25 | The execution cluster configuration on Databricks - see the docs for `DatabricksResource`.
26 | storage_path: str 27 | The full path to the directory for assets (input, output) on AWS (includes storage_key) 28 | storage_key: str 29 | A key suffixed to the storage location to ensure a unique path for each job. 30 | Also used as the `idempotency_token` in the Job API request. 31 | cluster_policy_id: str 32 | Generated by default by looking up using team name 33 | existing_cluster_id: str 34 | The id of an existing cluster to use, if specified job_config is ignored. 35 | run_path: str 36 | Path to run.py bootstrapping script on AWS 37 | task_args: str 38 | Additional args to be passed to the task 39 | timeout_seconds: str 40 | The maximum time this job can run for; default is 24 hours 41 | """ # noqa: E501 42 | 43 | name: str 44 | resource: DatabricksResource 45 | storage_path: str 46 | storage_key: str 47 | run_path: str 48 | cluster_policy_id: Optional[str] = None 49 | existing_cluster_id: Optional[str] = None 50 | timeout_seconds: int = 86400 51 | 52 | def create_payload(self): 53 | """""" 54 | task = self._task_spec() 55 | 56 | # Only add cluster configuration if not using serverless 57 | if not self.resource.use_serverless: 58 | task.update({"existing_cluster_id": self.existing_cluster_id}) 59 | task.update({"new_cluster": self._cluster_spec()}) 60 | 61 | payload = { 62 | "tasks": [task], 63 | "run_name": self.name, 64 | "timeout_seconds": self.timeout_seconds, 65 | "idempotency_token": self.storage_key, 66 | "access_control_list": [ 67 | { 68 | "group_name": self.resource.group_name, 69 | "permission_level": "CAN_MANAGE", 70 | }, 71 | ], 72 | } 73 | 74 | # Add environments configuration for serverless at job level 75 | if self.resource.use_serverless: 76 | payload["environments"] = [ 77 | { 78 | "environment_key": "default", 79 | "spec": { 80 | "dependencies": self._pip_dependencies(), 81 | "environment_version": self.resource.serverless_environment_version, 82 | } 83 | } 84 | ] 85 | 86 | return payload 87 | 88 | def _task_spec(self): 89 | task_args = self.resource.task_args or {} 90 | 91 | if self.existing_cluster_id is None and not self.resource.use_serverless: 92 | if task_args.get("libraries") is None: 93 | task_args["libraries"] = [] 94 | task_args["libraries"].extend(self._libraries()) 95 | elif self.existing_cluster_id and self.resource.use_serverless: 96 | # Log warning if both existing_cluster_id and use_serverless are set 97 | # This should be caught by validation, but adding defensive check 98 | logger.warning( 99 | "Both existing_cluster_id and use_serverless are set. " 100 | "Serverless mode takes precedence; existing_cluster_id will be ignored." 
101 | ) 102 | 103 | task_spec = { 104 | "task_key": f"{self.name[:32]}---{self.name[-32:]}", 105 | "description": "A function submitted from Cascade", 106 | "depends_on": [], 107 | "spark_python_task": { 108 | "python_file": self.run_path, 109 | "parameters": [self.storage_path, self.storage_key], 110 | }, 111 | **task_args, 112 | } 113 | 114 | # Add environment_key for serverless compute 115 | # The environment is defined at the job level in create_payload() 116 | if self.resource.use_serverless: 117 | task_spec["environment_key"] = "default" 118 | 119 | return task_spec 120 | 121 | def _libraries(self) -> list[dict[str, Any]]: 122 | required_libraries = ("cloudpickle", "prefect") 123 | for lib in required_libraries: 124 | if any(lib == package.name for package in self.resource.python_libraries): 125 | continue 126 | self.resource.python_libraries.append( 127 | DatabricksPythonLibrary( 128 | name=lib 129 | ) 130 | ) 131 | return [package.model_dump() for package in self.resource.python_libraries] 132 | 133 | def _pip_dependencies(self) -> list[str]: 134 | """ 135 | Convert python libraries to pip dependency strings for serverless environments. 136 | Returns a list of pip requirement specifiers. 137 | """ 138 | required_libraries = ("cloudpickle", "prefect") 139 | for lib in required_libraries: 140 | if any(lib == package.name for package in self.resource.python_libraries): 141 | continue 142 | self.resource.python_libraries.append( 143 | DatabricksPythonLibrary( 144 | name=lib 145 | ) 146 | ) 147 | 148 | # Convert DatabricksPythonLibrary objects to pip requirement strings 149 | return [str(package) for package in self.resource.python_libraries] 150 | 151 | def _cluster_spec(self): 152 | """ 153 | Creates a cluster spec for a Databricks job from the resource object 154 | passed to the DatabricksJobConfig object. 
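
        For a fixed-size (non-autoscaling) cluster the returned spec looks roughly
        like the following; the values shown here are illustrative, not defaults:

            {
                "spark_version": "14.3.x-scala2.12",
                "node_type_id": "i3.xlarge",
                "policy_id": "<policy-id>",
                "data_security_mode": "SINGLE_USER",
                "single_user_name": None,
                "num_workers": 2,
            }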
155 | """ 156 | if self.existing_cluster_id or self.resource.use_serverless: 157 | return None 158 | else: 159 | cluster_spec = { 160 | "spark_version": self.resource.spark_version, 161 | "node_type_id": self.resource.machine, 162 | "policy_id": self.cluster_policy_id, 163 | "data_security_mode": self.resource.data_security_mode, 164 | "single_user_name": None, 165 | } 166 | worker_count = self.resource.worker_count 167 | if ( 168 | isinstance(worker_count, DatabricksAutoscaleConfig) 169 | or "DatabricksAutoscaleConfig" in type(worker_count).__name__ 170 | ): 171 | workers = { 172 | "autoscale": { 173 | "min_workers": worker_count.min_workers, 174 | "max_workers": worker_count.max_workers, 175 | } 176 | } 177 | elif isinstance(worker_count, int): 178 | workers = {"num_workers": worker_count} 179 | else: 180 | raise TypeError( 181 | f"Expected `worker_count` of type `DatabricksAutoscaleConfig` or " 182 | f"`int` but received {type(worker_count)}" 183 | ) 184 | 185 | cluster_spec.update(workers) 186 | if self.resource.cluster_spec_overrides is not None: 187 | cluster_spec.update(self.resource.cluster_spec_overrides) 188 | return cluster_spec 189 | -------------------------------------------------------------------------------- /block_cascade/gcp/monitoring.py: -------------------------------------------------------------------------------- 1 | from asyncio import gather 2 | from collections import defaultdict 3 | from datetime import datetime 4 | from enum import Enum 5 | 6 | from google.cloud import monitoring_v3 7 | from google.cloud.monitoring_v3 import types as monitoring_types 8 | 9 | from block_cascade.executors.vertex.resource import GcpResource 10 | from block_cascade.prefect import get_prefect_logger 11 | 12 | SERVICE = "aiplatform.googleapis.com" 13 | RESOURCE_CATEGORY = "custom_model_training" 14 | 15 | 16 | class GcpMetrics(Enum): 17 | QUOTA_LIMIT = ("serviceruntime.googleapis.com/quota/limit", 86400 * 2) 18 | QUOTA_USAGE = ("serviceruntime.googleapis.com/quota/allocation/usage", 86400 * 2) 19 | 20 | def __init__(self, metric_type, delta): 21 | self.metric_type = metric_type 22 | self.interval = self.get_interval(delta) 23 | 24 | @staticmethod 25 | def get_interval(delta: int) -> monitoring_v3.TimeInterval: 26 | """ 27 | Produce a time interval 28 | 29 | Args: 30 | delta (int): amount of time (in seconds) in the past to subtract 31 | from the current time,this is adjusted based on metric type as each 32 | metric is updated at different intervals 33 | see: https://cloud.google.com/monitoring/api/metrics_gcp#gcp-serviceruntime 34 | 35 | Returns: 36 | monitoring_v3.TimeInterval 37 | """ 38 | start = datetime.utcnow() 39 | start_seconds = int(start.strftime("%s")) 40 | interval = monitoring_v3.TimeInterval( 41 | { 42 | "end_time": {"seconds": start_seconds, "nanos": 0}, 43 | "start_time": {"seconds": (start_seconds - delta), "nanos": 0}, 44 | } 45 | ) 46 | return interval 47 | 48 | 49 | def _create_quota_metric_suffix_from_cpu_type(machine_type: str) -> str: 50 | """ 51 | Generate a query filter to determine quota/usage for a given cpu type 52 | 53 | Args: 54 | machine_type (str): cpu type supplied by end user in GcpMachineConfig 55 | 56 | Raises: 57 | Key error, if cpu type not supported 58 | 59 | Returns: 60 | str: a metric.label.quota_metric for querying the monitoring API 61 | """ 62 | cpu_to_quota = { 63 | "a2": "a2_cpus", 64 | "n1": "cpus", 65 | "n2": "n2_cpus", 66 | "c2": "c2_cpus", 67 | "m1": "m1_cpus", 68 | "g2": "g2_cpus", 69 | } 70 | 71 | quota_type = 
cpu_to_quota[machine_type] 72 | 73 | return quota_type 74 | 75 | 76 | def _create_quota_metric_suffix_from_gpu_type(accelerator_type: str) -> str: 77 | """ 78 | Generate a query filter to determine quota/usage for a given accelerator 79 | 80 | Args: 81 | accelerator_type (str): accelerator supplied by end user in GcpAcceleratorConfig 82 | 83 | Raises: 84 | e: Key error, if accelerator type not supported 85 | 86 | Returns: 87 | str: a metric.label.quota_metric for querying the monitoring API 88 | """ 89 | 90 | return accelerator_type.lower() + "_gpus" 91 | 92 | 93 | def _get_most_recent_point( 94 | time_series: monitoring_types.metric.TimeSeries, 95 | ) -> str: 96 | """ 97 | Return the most end_time timestamp and quota value 98 | from a TimeSeries object. 99 | """ 100 | try: 101 | quota_val = time_series.points[0].value.int64_value 102 | except Exception: 103 | return "unknown" 104 | # if limit > 1*10^6, unlimited 105 | 106 | if quota_val > 10**6: 107 | return "unlimited" 108 | 109 | return str(quota_val) 110 | 111 | 112 | async def _make_metric_request( 113 | project: str, metric_type: GcpMetrics, quota_metric: str, region: str, logger 114 | ) -> str: 115 | """ 116 | Create a MetricServiceAsyncClient and try making a call to list time series for 117 | the given metric, if an error is encountered return a missing list. 118 | Args: 119 | project (str): 120 | metric_type (GcpMetrics): 121 | quota_metric (str): 122 | region (str): 123 | 124 | Returns: 125 | List[monitoring_types.metric.TimeSeries]: Return a list of Time series, 126 | if an error is encountered return an empty list 127 | """ 128 | try: 129 | client = monitoring_v3.MetricServiceAsyncClient() 130 | except Exception as e: 131 | logger.error(e) 132 | return [] 133 | 134 | try: 135 | results = await client.list_time_series( 136 | # the request should be specified so that it returns only one metric 137 | request={ 138 | "name": f"projects/{project}", 139 | "filter": f'metric.type = "{metric_type.metric_type}" AND\ 140 | metric.label.quota_metric="{quota_metric}" AND\ 141 | resource.labels.location="{region}"', 142 | "interval": metric_type.interval, 143 | "view": monitoring_v3.ListTimeSeriesRequest.TimeSeriesView.FULL, 144 | } 145 | ) 146 | except Exception as e: 147 | logger.error(e) 148 | return [] 149 | 150 | results_list = [] 151 | 152 | async for page in results: 153 | results_list.append(page) 154 | 155 | if len(results_list) == 0: 156 | return "unknown" 157 | 158 | return _get_most_recent_point(results_list[0]) 159 | 160 | 161 | def _get_resources_by_metric(resource: GcpResource) -> dict: 162 | """ 163 | Create a dictionary to store all resource types in a 164 | GcpResource object and num of each resource 165 | 166 | Args: 167 | resource (GcpResource): complete GcpResource obj 168 | 169 | Returns: 170 | dict: keys are quota_metric, values are number of that type of resource 171 | """ 172 | resources_by_metric = defaultdict(lambda: 0) 173 | 174 | cpu_resources = [resource.chief] 175 | if resource.workers: 176 | cpu_resources.append(resource.workers) 177 | 178 | gpu_resources = [] 179 | if resource.chief.accelerator: 180 | gpu_resources.append(resource.chief.accelerator) 181 | if resource.workers and resource.workers.accelerator: 182 | gpu_resources.append(resource.workers.accelerator) 183 | 184 | # cpus 185 | for pool in cpu_resources: 186 | cpu_type, _, num_cores = str.split(pool.type, "-") 187 | resources_by_metric[_create_quota_metric_suffix_from_cpu_type(cpu_type)] += ( 188 | int(num_cores) * pool.count 189 | ) 190 | 191 | 
# gpus 192 | for accelerator in gpu_resources: 193 | resources_by_metric[ 194 | _create_quota_metric_suffix_from_gpu_type(accelerator.type) 195 | ] += accelerator.count 196 | 197 | return resources_by_metric 198 | 199 | 200 | async def log_quotas_for_resource(resource: GcpResource) -> None: 201 | """ 202 | Parses all resources in a GcpResource object and queries GCP to determine 203 | current usage and quota limit for each resource type. 204 | 205 | Logs results to a Prefect run logger if available; gcp.monitoring logger if not. 206 | 207 | Args: 208 | resource (GcpResource) 209 | """ 210 | logger = get_prefect_logger(__name__) 211 | 212 | resources_by_quota_metric = _get_resources_by_metric(resource) 213 | 214 | for quota_metric_suffix, num_resouces in resources_by_quota_metric.items(): 215 | # create string of resource for logging 216 | metric_str = quota_metric_suffix 217 | if metric_str[-1] == "s": 218 | metric_str = metric_str[:-1] 219 | 220 | resource_log_str = ( 221 | f"VertexJob will consume {num_resouces} {metric_str} resources." 222 | ) 223 | 224 | # get quota limits 225 | limit_p = _make_metric_request( 226 | project=resource.environment.project, 227 | metric_type=GcpMetrics.QUOTA_LIMIT, 228 | quota_metric=f"{SERVICE}/{RESOURCE_CATEGORY}_{quota_metric_suffix}", 229 | region=resource.environment.region, 230 | logger=logger, 231 | ) 232 | 233 | # get quota usage 234 | usage_p = _make_metric_request( 235 | project=resource.environment.project, 236 | metric_type=GcpMetrics.QUOTA_USAGE, 237 | quota_metric=f"{SERVICE}/{RESOURCE_CATEGORY}_{quota_metric_suffix}", 238 | region=resource.environment.region, 239 | logger=logger, 240 | ) 241 | 242 | limit, usage = await gather(limit_p, usage_p) 243 | 244 | logger.info(resource_log_str + f"Current usage: {usage}; quota limit: {limit}.") 245 | -------------------------------------------------------------------------------- /block_cascade/executors/databricks/filesystem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom filesystem implementation for Databricks Unity Catalog Volumes. 3 | 4 | This module provides fsspec-compatible filesystem class for interacting with 5 | Unity Catalog Volumes (/Volumes////). 6 | 7 | Unity Catalog Volumes provide: 8 | - Proper security and permissions through Unity Catalog governance 9 | - Serverless compute compatibility 10 | - Fine-grained access control 11 | - Cross-workspace accessibility 12 | 13 | This filesystem uses the Databricks Files API to perform file operations 14 | from the client side. 15 | """ 16 | 17 | import io 18 | import os 19 | import logging 20 | from typing import BinaryIO 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | # Buffer size for streaming uploads 25 | BUFFER_SIZE_BYTES = 2**20 # 1MB 26 | 27 | 28 | class DatabricksFilesystem: 29 | """ 30 | An fsspec-compatible filesystem for Databricks Unity Catalog Volumes. 31 | 32 | This filesystem provides file operations (open, upload, delete) that work with 33 | Unity Catalog Volumes from the client side using the Databricks Files API. 34 | 35 | Unity Catalog Volumes provide proper security and permissions through UC governance, 36 | making them the recommended storage location for serverless Databricks compute. 
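
    Typical read usage looks roughly like this (the api_client and the volume
    path are illustrative):

        fs = DatabricksFilesystem(api_client)
        with fs.open("/Volumes/main/my_team/cascade/job_abc/output.pkl", "rb") as f:
            result = pickle.load(f)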
37 | 38 | Path format: /Volumes//// 39 | Example: /Volumes/main/my_team/cascade/job_artifacts/ 40 | 41 | Parameters 42 | ---------- 43 | api_client : ApiClient 44 | Databricks API client from databricks-cli 45 | auto_mkdir : bool 46 | Automatically create parent directories when opening files for writing 47 | """ 48 | 49 | def __init__(self, api_client, auto_mkdir: bool = True): 50 | self.api_client = api_client 51 | self.auto_mkdir = auto_mkdir 52 | 53 | def _ensure_parent_dir(self, path: str) -> None: 54 | """Ensure parent directory exists if auto_mkdir is enabled.""" 55 | if not self.auto_mkdir: 56 | return 57 | 58 | parent_path = os.path.dirname(path) 59 | if parent_path and parent_path != "/Volumes": 60 | # Create directory using Files API - use session directly (no body) 61 | url = self.api_client.get_url('/fs/directories' + parent_path) 62 | headers = { 63 | 'Authorization': self.api_client.default_headers.get('Authorization', '') 64 | } 65 | response = self.api_client.session.put(url, headers=headers) 66 | response.raise_for_status() 67 | logger.debug(f"Created parent directory: {parent_path}") 68 | 69 | def open(self, path: str, mode: str = "rb") -> BinaryIO: 70 | """ 71 | Open a file for reading or writing in Unity Catalog Volumes. 72 | 73 | Parameters 74 | ---------- 75 | path : str 76 | Path to the file in Unity Catalog Volumes 77 | Example: /Volumes/catalog/schema/volume/file.pkl 78 | mode : str 79 | File mode: 'rb' for reading, 'wb' for writing 80 | 81 | Returns 82 | ------- 83 | BinaryIO 84 | File-like object for reading/writing 85 | """ 86 | 87 | if mode == "rb": 88 | # Read mode: download file using Files API 89 | # Download using Files API - use session directly for binary data 90 | # perform_query() tries to parse as JSON, but /fs/files returns raw bytes 91 | url = self.api_client.get_url('/fs/files' + path) 92 | headers = { 93 | 'Authorization': self.api_client.default_headers.get('Authorization', '') 94 | } 95 | response = self.api_client.session.get(url, headers=headers) 96 | response.raise_for_status() 97 | return io.BytesIO(response.content) 98 | 99 | elif mode == "wb": 100 | # Write mode: return a special file object that uploads on close 101 | self._ensure_parent_dir(path) 102 | return _DatabricksUploadFile(self.api_client, path) 103 | 104 | else: 105 | raise ValueError(f"Unsupported mode: {mode}. Only 'rb' and 'wb' are supported.") 106 | 107 | def upload(self, local_path: str, remote_path: str, overwrite: bool = True) -> None: 108 | """ 109 | Upload a local file to Unity Catalog Volumes. 
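
        A minimal call might look like this (paths are illustrative):

            fs.upload("run.py", "/Volumes/main/my_team/cascade/job_abc/run.py")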
110 | 111 | Parameters 112 | ---------- 113 | local_path : str 114 | Local file path to upload 115 | remote_path : str 116 | Remote path in Unity Catalog Volumes 117 | Example: /Volumes/catalog/schema/volume/file.py 118 | overwrite : bool 119 | Whether to overwrite existing file 120 | """ 121 | self._ensure_parent_dir(remote_path) 122 | 123 | # Read local file 124 | with open(local_path, "rb") as f: 125 | content = f.read() 126 | 127 | # Upload using Files API - use session directly for binary data 128 | url = self.api_client.get_url('/fs/files' + remote_path) 129 | headers = { 130 | 'Content-Type': 'application/octet-stream', 131 | 'Authorization': self.api_client.default_headers.get('Authorization', '') 132 | } 133 | response = self.api_client.session.put( 134 | url, 135 | data=content, 136 | headers=headers, 137 | params={'overwrite': str(overwrite).lower()} 138 | ) 139 | response.raise_for_status() 140 | 141 | def rm(self, path: str, recursive: bool = False) -> None: 142 | """ 143 | Delete a file or directory. 144 | 145 | Parameters 146 | ---------- 147 | path : str 148 | Path to delete 149 | recursive : bool 150 | If True, delete directory and all its contents 151 | """ 152 | 153 | if recursive: 154 | # For recursive deletion, delete files individually first, then the directory 155 | list_url = self.api_client.get_url('/fs/directories' + path) 156 | headers = { 157 | 'Authorization': self.api_client.default_headers.get('Authorization', '') 158 | } 159 | list_response = self.api_client.session.get(list_url, headers=headers) 160 | list_response.raise_for_status() 161 | 162 | contents = list_response.json() 163 | files = contents.get('contents', []) 164 | 165 | # Delete each file 166 | for item in files: 167 | item_path = item.get('path') 168 | if item_path: 169 | file_url = self.api_client.get_url('/fs/files' + item_path) 170 | file_response = self.api_client.session.delete(file_url, headers=headers) 171 | file_response.raise_for_status() 172 | 173 | # Delete the empty directory 174 | dir_url = self.api_client.get_url('/fs/directories' + path) 175 | dir_response = self.api_client.session.delete(dir_url, headers=headers) 176 | dir_response.raise_for_status() 177 | 178 | else: 179 | # Use files API for single file deletion 180 | url = self.api_client.get_url('/fs/files' + path) 181 | headers = { 182 | 'Authorization': self.api_client.default_headers.get('Authorization', '') 183 | } 184 | response = self.api_client.session.delete(url, headers=headers) 185 | response.raise_for_status() 186 | 187 | 188 | class _DatabricksUploadFile(io.BytesIO): 189 | """ 190 | A BytesIO-like object that uploads to Databricks when closed. 191 | 192 | This class allows us to use the standard `with open(path, 'wb') as f: f.write()` 193 | pattern while uploading to Databricks via the Files API. 
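
    A sketch of the pattern this enables (path and payload are illustrative):

        with fs.open("/Volumes/main/my_team/cascade/job_abc/function.pkl", "wb") as f:
            cloudpickle.dump(func, f)  # buffered in memory, uploaded when the file closes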
194 | """ 195 | 196 | def __init__(self, api_client, remote_path: str): 197 | super().__init__() 198 | self.api_client = api_client 199 | self.remote_path = remote_path 200 | self._closed = False 201 | 202 | def close(self) -> None: 203 | """Upload the buffer contents to Databricks when closing.""" 204 | if self._closed: 205 | return 206 | 207 | try: 208 | # Get the buffer contents 209 | content = self.getvalue() 210 | 211 | # Upload using Files API - use session directly for binary data 212 | url = self.api_client.get_url('/fs/files' + self.remote_path) 213 | headers = { 214 | 'Content-Type': 'application/octet-stream', 215 | 'Authorization': self.api_client.default_headers.get('Authorization', '') 216 | } 217 | response = self.api_client.session.put( 218 | url, 219 | data=content, 220 | headers=headers, 221 | params={'overwrite': 'true'} 222 | ) 223 | response.raise_for_status() 224 | finally: 225 | self._closed = True 226 | super().close() 227 | 228 | def __enter__(self): 229 | return self 230 | 231 | def __exit__(self, exc_type, exc_val, exc_tb): 232 | if not self._closed: 233 | self.close() 234 | return False 235 | 236 | -------------------------------------------------------------------------------- /block_cascade/executors/vertex/job.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data model for running jobs on VertexAI 3 | """ 4 | 5 | from dataclasses import asdict, dataclass 6 | import os 7 | from typing import List, Mapping, Optional, Union 8 | 9 | from slugify import slugify 10 | 11 | from block_cascade.executors.vertex.resource import GcpMachineConfig, GcpResource 12 | from block_cascade.executors.vertex.tune import ( 13 | ParamCategorical, 14 | ParamDiscrete, 15 | ParamDouble, 16 | ParamInteger, 17 | Tune, 18 | ) 19 | 20 | DISTRIBUTED_JOB_FILENAME = "distributed_job.pkl" 21 | 22 | 23 | def _convert_to_gcp_compatible_label(val: str, is_key: bool = False) -> str: 24 | converted_val = slugify(val, max_length=63, regex_pattern=r"[^-a-z0-9_]+") 25 | if is_key and not converted_val: 26 | raise RuntimeError("Keys for GCP resources must be at least 1 " "character.") 27 | if is_key and not converted_val[0].isalpha(): 28 | raise RuntimeError( 29 | "Keys for GCP resources must start with " "a lowercase letter." 30 | ) 31 | return converted_val 32 | 33 | 34 | @dataclass(frozen=True) 35 | class VertexJob: 36 | """ 37 | Data model to convert a GCP resource object into a valid payload 38 | for a VertexAI CustomJob request 39 | 40 | This class is intended to be used by the VertexExecutor, and is not intended for use 41 | by the end user of this library. It is exposed for testing purposes. 
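
    A typical use by the executor looks roughly like this (names and paths are
    illustrative):

        job = VertexJob(
            display_name="my-training-job",
            resource=gcp_resource,
            storage_path="gs://my-bucket/cascade/my-training-job/",
        )
        payload = job.create_payload()  # dict ready to submit as a CustomJob request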
42 | 43 | This does not support the whole Vertex API, and assumes that you 44 | - are using custom containers (supplied via the GcpResource.image parameter) 45 | - are running a python function (or partial) inside that container (available at 46 | the staged_filepath) 47 | 48 | display_name: str 49 | The name of the job, used for logging and tracking 50 | resource: GcpResource 51 | The resource object describing the GCP environment and cluster configuration 52 | see block_cascade.executors.vertex.resource for details 53 | storage_path: str 54 | The path to the directory used to store the staged file and output 55 | tune: Optional[Tune] = None 56 | Whether to run a hyperparameter tuning job, and if so, how to configure it 57 | dashboard: Optional[bool] = False 58 | Whether to enable a hyperlink on the job page, 59 | in order to view the Dask dashboard 60 | web_console_access: Optional[bool] = False 61 | Whether to allow web console access to the job 62 | code_package: Optional[str] = None 63 | The GCS path to the users first party code that is added to sys.path 64 | at runtime to handle unpickling of a function that references 65 | first party code module. 66 | """ 67 | 68 | display_name: str 69 | resource: GcpResource 70 | storage_path: str 71 | tune: Optional[Tune] = None 72 | dashboard: Optional[bool] = False 73 | web_console: Optional[bool] = False 74 | labels: Optional[Mapping[str, str]] = None 75 | code_package: Optional[str] = None 76 | 77 | @property 78 | def distributed_job_path(self): 79 | return os.path.join(self.storage_path, DISTRIBUTED_JOB_FILENAME) 80 | 81 | def create_payload(self): 82 | """Conversion from our resource data model to a CustomJob request on vertex 83 | 84 | https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.customJobs#CustomJob 85 | """ 86 | gcp_compatible_labels = { 87 | _convert_to_gcp_compatible_label( 88 | key, is_key=True 89 | ): _convert_to_gcp_compatible_label(val) 90 | for key, val in (self.labels or {}).items() 91 | } 92 | if self.tune is None: 93 | return { 94 | "display_name": self.display_name, 95 | "job_spec": self._create_job_spec(), 96 | "labels": gcp_compatible_labels, 97 | } 98 | else: 99 | return { 100 | "display_name": self.display_name, 101 | "trial_job_spec": self._create_job_spec(), 102 | "max_trial_count": self.tune.trials, 103 | "parallel_trial_count": self.tune.parallel, 104 | "study_spec": self._create_study_spec(), 105 | "labels": gcp_compatible_labels, 106 | } 107 | 108 | def _create_job_spec(self): 109 | """ 110 | Creates a CustomJobSpec from the resource object passed by the user. 
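
        The returned dict has roughly this shape (values illustrative; the dashboard,
        web-console, and persistent-resource keys are added only when enabled):

            {
                "worker_pool_specs": [...],
                "network": "projects/<project-number>/global/networks/<network>",
                "service_account": "<sa>@<project>.iam.gserviceaccount.com",
            }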
111 | API docs at: 112 | https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec 113 | """ 114 | job_spec = { 115 | "worker_pool_specs": self._create_cluster_spec(), 116 | } 117 | environment = self.resource.environment 118 | 119 | if environment.network is not None: 120 | job_spec["network"] = environment.network 121 | 122 | if environment.service_account is not None: 123 | job_spec["service_account"] = environment.service_account 124 | 125 | if self.dashboard is True: 126 | job_spec["enable_dashboard_access"] = True 127 | 128 | if self.web_console is True: 129 | job_spec["enable_web_access"] = True 130 | 131 | if self.resource.persistent_resource_id is not None: 132 | job_spec["persistent_resource_id"] = self.resource.persistent_resource_id 133 | 134 | return job_spec 135 | 136 | def _create_machine_pool_spec(self, machine_config: GcpMachineConfig): 137 | """ 138 | Uses a machine config (descriping chief or worker pool) to a specification 139 | for a given machine pool in the Vertex API 140 | """ 141 | node_pool_spec = { 142 | "replica_count": machine_config.count, 143 | "container_spec": self._create_container_spec(), 144 | "machine_spec": self._create_machine_spec(machine_config), 145 | "nfs_mounts": self._create_nfs_specs(machine_config), 146 | } 147 | 148 | if machine_config.disk_size_gb is not None: 149 | node_pool_spec["disk_spec"] = { 150 | "boot_disk_size_gb": machine_config.disk_size_gb 151 | } 152 | 153 | return node_pool_spec 154 | 155 | def _create_cluster_spec(self) -> List[dict]: 156 | """ 157 | Creates a cluster spec from chief and (optional) worker specs 158 | https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec#WorkerPoolSpec 159 | """ 160 | 161 | # First pool spec must have exactly one replica, its intended to be the "chief" 162 | # in single machine use cases this is the only spec 163 | cluster_spec = [] 164 | if self.resource.chief.count != 1: 165 | raise ValueError( 166 | f"Chief pool must have exactly one replica, got {self.resource.chief.count}" # noqa: E501 167 | ) 168 | else: 169 | cluster_spec = [] 170 | cluster_spec.append(self._create_machine_pool_spec(self.resource.chief)) 171 | 172 | # worker pool specs are optional 173 | if self.resource.workers is not None: 174 | cluster_spec.append(self._create_machine_pool_spec(self.resource.workers)) 175 | 176 | return cluster_spec 177 | 178 | def _create_container_spec(self): 179 | """ 180 | The image to use to create a container and the entrypoint for 181 | that container 182 | """ 183 | 184 | executor_module_path = "block_cascade.executors.vertex.run" 185 | distributed_job = "False" 186 | if self.resource.distributed_job is not None: 187 | distributed_job = "True" 188 | 189 | command = [ 190 | "python", 191 | "-m", 192 | executor_module_path, 193 | self.storage_path, 194 | distributed_job, 195 | self.code_package or "", 196 | ] 197 | 198 | return { 199 | "image_uri": self.resource.environment.image, 200 | "command": command, 201 | "args": [], 202 | } 203 | 204 | @staticmethod 205 | def _create_machine_spec(machine_config: GcpMachineConfig): 206 | """ 207 | Adds chief_machine_spec information to the job_spec 208 | """ 209 | machine_spec = {"machine_type": machine_config.type} 210 | if machine_config.accelerator is not None: 211 | machine_spec["accelerator_type"] = machine_config.accelerator.type 212 | machine_spec["accelerator_count"] = machine_config.accelerator.count 213 | return machine_spec 214 | 215 | @staticmethod 216 | def _create_nfs_specs(machine_config: GcpMachineConfig) -> 
List[dict]: 217 | """ 218 | Produce a list of NFS mount specs from a machine config 219 | """ 220 | nfs_specs = [] 221 | if machine_config.nfs_mounts is not None: 222 | for nfs_mount in machine_config.nfs_mounts: 223 | nfs_specs.append(asdict(nfs_mount)) 224 | return nfs_specs 225 | 226 | @staticmethod 227 | def _create_disk_spec(machine_config: GcpMachineConfig) -> dict: 228 | """ 229 | Produce a disk spec from a machine config 230 | """ 231 | if machine_config.disk_size_gb is not None: 232 | return {"boot_disk_size_gb": machine_config.disk_size_gb} 233 | return None 234 | 235 | # tuning spec creation 236 | def _create_study_spec(self) -> dict: 237 | """ 238 | Create a study specification dictionary from the tuning metrics, goals, 239 | and parameters 240 | """ 241 | study_spec = { 242 | "metrics": [{"metric_id": self.tune.metric, "goal": self.tune.goal}], 243 | "parameters": [self._create_parameter_spec(p) for p in self.tune.params], 244 | } 245 | if self.tune.algorithm is not None: 246 | study_spec["algorithm"] = self.tune.algorithm 247 | return study_spec 248 | 249 | def _create_parameter_spec( 250 | self, param: Union[ParamDiscrete, ParamCategorical, ParamInteger, ParamDouble] 251 | ) -> dict: 252 | """ 253 | Create a parameter specification dictionary with 254 | sub-specifications for each type of value 255 | """ 256 | s = {"parameter_id": param.name} 257 | 258 | if hasattr(param, "scale") and param.scale is not None: 259 | s["scale_type"] = param.scale.name 260 | 261 | if isinstance(param, ParamDouble): 262 | s["double_value_spec"] = {"min_value": param.min, "max_value": param.max} 263 | elif isinstance(param, ParamInteger): 264 | s["integer_value_spec"] = {"min_value": param.min, "max_value": param.max} 265 | elif isinstance(param, ParamCategorical): 266 | s["categorical_value_spec"] = {"values": param.values} 267 | elif isinstance(param, ParamDiscrete): 268 | s["discrete_value_spec"] = {"values": param.values} 269 | return s 270 | -------------------------------------------------------------------------------- /block_cascade/executors/vertex/distributed/distributed_job.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass, field 3 | import json 4 | import logging 5 | import os 6 | import pickle 7 | import socket 8 | from subprocess import Popen, check_call 9 | import sys 10 | import time 11 | from typing import Callable, List 12 | import warnings 13 | 14 | import gcsfs 15 | 16 | VERTEX_DASHBOARD_PORT = "8888" 17 | 18 | 19 | @dataclass 20 | class DistributedJobBase(ABC): 21 | """ 22 | Abstract base class for distributed jobs. 23 | To define a new distributed job, inherit from this class 24 | and implement the `run` method. 25 | """ 26 | 27 | @staticmethod 28 | def get_pool_number(): 29 | """ 30 | Parse a Vertex job instance's CLUSTER_SPEC environment variable 31 | to determine whether this job is being run on a `chief`, `worker`, 32 | `parameter_server` or `evaluator`. 33 | 34 | See docs: https://cloud.google.com/vertex-ai/docs/training/distributed-training#cluster-spec-format 35 | 36 | Some cluster workloads - e.g. Dask - do not use evaluators and parameter servers. 37 | 38 | """ # noqa: E501 39 | 40 | if "CLUSTER_SPEC" not in os.environ: 41 | warnings.warn( 42 | "Did not find CLUSTER_SPEC in environment. CLUSTER_SPEC is expected to " 43 | "be in environment if running on a Vertex AIP cluster. 
" 44 | "https://cloud.google.com/vertex-ai/docs/training/distributed-training#cluster-spec-format" 45 | ) 46 | return None 47 | 48 | try: 49 | clusterspec = json.loads(os.environ.get("CLUSTER_SPEC")) 50 | except json.JSONDecodeError as e: 51 | logging.error( 52 | "Found CLUSTER_SPEC in environment but cannot parse it as JSON." 53 | ) 54 | raise e 55 | 56 | workerpool = clusterspec.get("task", {}).get("type", "") 57 | # e.g. "workerpool0", "workerpool1" 58 | workerpool_number = workerpool.replace("workerpool", "") # e.g. 0, 1 59 | return int(workerpool_number) 60 | 61 | def run_function(self, dump_output=True): 62 | """ 63 | Runs the function provided at initialization and optionally 64 | dumps the output to GCS. 65 | """ 66 | logging.info("Starting user code execution") 67 | result = self.func() 68 | output_path = f"{self.storage_path}/output.pkl" 69 | 70 | if dump_output: 71 | logging.info(f"Saving output of task to {output_path}") 72 | fs = gcsfs.GCSFileSystem() 73 | with fs.open(output_path, "wb") as f: 74 | pickle.dump(result, f) 75 | 76 | def run(self, func: Callable, storage_path: str): 77 | """ 78 | Initializes the function and storage path for the distributed job. 79 | It is necessary to initialize via the run method as the function 80 | and storage path are unknown at initialize time (when the resource 81 | object is created). 82 | """ 83 | self.func = func 84 | self.storage_path = storage_path 85 | self._run() 86 | 87 | @abstractmethod 88 | def _run(self): 89 | """To be implemented by child classes, this method is intended as the 90 | entrypoint for running distributed jobs. 91 | 92 | An example _run method for a distributed job that would run a startup 93 | script on both chief and worker nodes and then execute the job function 94 | on the chief node. 95 | 96 | task = self.get_pool_number() 97 | if task == 0: # on chief 98 | self.start_chief() 99 | self.run_function() 100 | 101 | elif task == 1: # on worker 102 | self.start_worker() 103 | """ 104 | pass 105 | 106 | 107 | @dataclass 108 | class DistributedJob(DistributedJobBase): 109 | """ 110 | A basic distributed job that launches a cluster without 111 | any preconfiguration or startup code apart from what Vertex AI provides, see 112 | https://cloud.google.com/vertex-ai/docs/training/distributed-training 113 | """ 114 | 115 | def _run(self): 116 | """ 117 | Executes the function and only dumps output from the chief node. 118 | """ 119 | task = self.get_pool_number() 120 | if task == 0: # on chief 121 | self.run_function(dump_output=True) 122 | 123 | elif task == 1: # on worker 124 | self.run_function(dump_output=False) 125 | 126 | 127 | @dataclass 128 | class DaskJob(DistributedJobBase): 129 | """ 130 | Configure a Dask Job. 131 | 132 | Parameters 133 | ---------- 134 | chief_port: str 135 | A Job object. Port on which scheduler should listen. Be sure to supply a 136 | string and not an integer. Defaults to "8786". 137 | 138 | scheduler_cli_args: list 139 | Additional arguments to pass to scheduler process in format 140 | ['--arg1', 'arg1value', '--arg2', 'arg2value']. Defaults to []. 141 | 142 | worker_cli_args: list 143 | Additional arguments to pass to worker processes in format 144 | ['--arg1', 'arg1value', '--arg2', 'arg2value']. Defaults to []. 
145 | """ 146 | 147 | chief_port: str = "8786" 148 | scheduler_cli_args: List[str] = field(default_factory=list) 149 | worker_cli_args: List[str] = field(default_factory=list) 150 | logger: logging.Logger = None 151 | 152 | @property 153 | def chief_ip(self) -> str: 154 | """ 155 | Gets the IP of the chief host. 156 | 157 | If run on chief, determines the IP of the host on the local network. 158 | 159 | If run on worker, waits for that file to materialize on GCS and returns 160 | the IP of the chief host. 161 | """ 162 | task = self.get_pool_number() 163 | 164 | fs = gcsfs.GCSFileSystem() 165 | if task == 0: # on chief 166 | host_name = socket.gethostname() 167 | return socket.gethostbyname(host_name) 168 | else: 169 | 170 | chief_ip_file = self.get_chief_ip_file() 171 | 172 | # workers (and potentially evaluators, parameter servers not yet used) 173 | # look for the file on startup 174 | times_tried = 0 175 | retry_delay = 5 # seconds 176 | max_retries = 36 # 36 tries * 5 seconds = 3 mins 177 | while times_tried <= max_retries: 178 | try: 179 | with fs.open(chief_ip_file, "r") as f: 180 | return f.read().rstrip("\n") 181 | except FileNotFoundError: 182 | if times_tried == max_retries: 183 | raise TimeoutError( 184 | "Timed out waiting for chief to stage IP file." 185 | ) 186 | self.logger.info( 187 | f"waiting for scheduler IP file to be ready at {chief_ip_file}" 188 | ) 189 | times_tried += 1 190 | time.sleep(retry_delay) 191 | 192 | def __setstate__(self, state): 193 | self.__dict__.update(state) 194 | 195 | pool_number = self.get_pool_number() 196 | if pool_number is not None: 197 | 198 | machine_type = "chief" if not pool_number else f"worker{pool_number}" 199 | 200 | self.logger = logging.getLogger(__name__) 201 | self.logger.setLevel(logging.INFO) 202 | self.logger.propagate = False 203 | 204 | handler = logging.StreamHandler(stream=sys.stdout) 205 | handler.setLevel(logging.INFO) 206 | 207 | formatter = logging.Formatter( 208 | f"%(asctime)s {machine_type} %(levelname)s: %(message)s" 209 | ) 210 | handler.setFormatter(formatter) 211 | self.logger.addHandler(handler) 212 | 213 | def get_chief_ip_file(self) -> str: 214 | """ 215 | Path on GCS at which file containing chief IP is 216 | expected to be found. 217 | """ 218 | return os.path.join(self.storage_path, "chief_ip.txt") 219 | 220 | def get_chief_address(self) -> str: 221 | """ 222 | Return IP and port of chief in format 000.000.000.000:0000 223 | """ 224 | 225 | return f"{self.chief_ip}:{self.chief_port}" 226 | 227 | def start_chief(self): 228 | """ 229 | Start the Dask scheduler. 230 | 231 | Pass `scheduler_cli_args` at object creation time in format 232 | `DaskJob(scheduler_cli_args = ['--arg1', 'arg1value', '--arg2', 'arg2value'])` 233 | to modify worker process startup. 234 | """ 235 | 236 | chief_ip_file = self.get_chief_ip_file() 237 | 238 | fs = gcsfs.GCSFileSystem() 239 | with fs.open(chief_ip_file, "w") as f: 240 | f.write(self.chief_ip) 241 | 242 | self.logger.info(f"The scheduler IP is {self.chief_ip}") 243 | 244 | # This allows users (who invoke Client in their code) point their 245 | # Dask Client at a file called `__scheduler__` 246 | # i.e. 
distributed.Client(scheduler_file="__scheduler__") 247 | with open("__scheduler__", "w") as file: 248 | json.dump({"address": self.get_chief_address()}, file) 249 | 250 | Popen( 251 | [ 252 | "dask", 253 | "scheduler", 254 | "--protocol", 255 | "tcp", 256 | "--port", 257 | self.chief_port, 258 | *self.scheduler_cli_args, 259 | "--dashboard-address", 260 | f":{VERTEX_DASHBOARD_PORT}", 261 | ], 262 | stdout=sys.stdout, 263 | stderr=sys.stderr, 264 | ) 265 | 266 | def start_worker(self): 267 | """ 268 | Starts a Dask worker process. 269 | 270 | Pass `worker_cli_args` at object creation time in format 271 | `DaskJob(worker_cli_args = ['--arg1', 'arg1value', '--arg2', 'arg2value'])` 272 | to modify worker process startup. 273 | """ 274 | self.logger.info(f"Chief Ip: {self.chief_ip}") 275 | check_call( 276 | [ 277 | "dask", 278 | "worker", 279 | self.get_chief_address(), 280 | *self.worker_cli_args, 281 | ], 282 | stdout=sys.stdout, 283 | stderr=sys.stderr, 284 | ) 285 | 286 | def _run(self): 287 | """ 288 | Runs startup on chief and worker nodes to set up Dask cluster. 289 | Runs the function on the chief (scheduler) node and dumps output to GCS. 290 | """ 291 | self.logger.info("Starting dask run function") 292 | task = self.get_pool_number() 293 | if task == 0: # on chief 294 | self.start_chief() 295 | self.run_function(dump_output=True) 296 | 297 | elif task == 1: # on worker 298 | self.start_worker() 299 | -------------------------------------------------------------------------------- /block_cascade/executors/databricks/resource.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | import logging 3 | import os 4 | from typing import Any, Iterator, List, Optional, Union 5 | 6 | from pydantic import BaseModel, Field, model_validator 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | class DatabricksSecret(BaseModel): 12 | """Databricks secret to auth to Databricks 13 | 14 | Parameters 15 | ---------- 16 | host: str 17 | token: str 18 | """ 19 | 20 | host: str 21 | token: str 22 | 23 | 24 | class DatabricksAutoscaleConfig(BaseModel): 25 | """Configuration for autoscaling on DataBricks clusters. 26 | 27 | Parameters 28 | ---------- 29 | min_workers: Optional[int] 30 | Minimum number of workers to scale to. Default is 1. 31 | max_workers: Optional[int] 32 | Maximum number of workers to scale to. Default is 8. 33 | 34 | """ 35 | 36 | min_workers: int = 1 37 | max_workers: int = 8 38 | 39 | 40 | class DatabricksPythonLibrary(BaseModel): 41 | """Configuration for a Python library to be installed on a Databricks cluster. 42 | 43 | Reference: https://docs.databricks.com/aws/en/reference/jobs-2.0-api#pythonpypilibrary 44 | 45 | Parameters 46 | ---------- 47 | name: str 48 | The name of the package. 49 | repo: Optional[str] 50 | The Python package index to install the package from. If not specified, 51 | defaults to the configured package index on the Databricks cluster which 52 | is most likely PyPI. 53 | version: Optional[str] 54 | The version of the package. 55 | infer_version: bool 56 | Whether to infer the version of the package from the current 57 | environment if not specified explicitly. Defaults to True. 58 | It is critical to recognize the Databricks runtime has preinstalled 59 | packages so version pinning can lead to an incompatible Databricks 60 | runtime. Alternatively by not pinning, an incompatible version 61 | of the dependency could be installed that is not compatible with 62 | your code. 
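
    Examples (package names and versions are illustrative):

        DatabricksPythonLibrary(name="xgboost", version="2.0.3", infer_version=False)
        DatabricksPythonLibrary(name="polars")  # version inferred from the local env when available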
63 | """ 64 | name: str 65 | repo: Optional[str] = None 66 | version: Optional[str] = None 67 | infer_version: bool = True 68 | 69 | @model_validator(mode="after") 70 | def maybe_update_version(self): 71 | if self.infer_version and not self.version: 72 | try: 73 | self.version = importlib.metadata.version(self.name) 74 | except importlib.metadata.PackageNotFoundError: 75 | logger.warning( 76 | f"Could not infer version for package '{self.name}' from runtime. " 77 | "The version will be left unspecified for Databricks runtime " 78 | "installation." 79 | ) 80 | return self 81 | 82 | def __str__(self) -> str: 83 | """Convert to pip requirement string format.""" 84 | return f"{self.name}=={self.version}" if self.version else self.name 85 | 86 | def model_dump(self, **kwargs) -> dict: 87 | package_specififer = str(self) 88 | return { 89 | "pypi": { 90 | "package": package_specififer, 91 | **({"repo": self.repo} if self.repo else {}) 92 | } 93 | } 94 | 95 | 96 | class DatabricksResource(BaseModel): 97 | """Description of a Databricks Cluster 98 | 99 | Parameters 100 | ---------- 101 | storage_location: str 102 | Path to the directory where files will be staged and output written. 103 | Storage location can be either Unity Catalog Volumes (/Volumes/) or S3 (s3://). Unity Catalog 104 | is required for serverless compute. 105 | This format for Unity Catalog Volumes is: /Volumes/// 106 | worker_count: Union[int, DatabricksAutoscaleConfig] 107 | If an integer is supplied, specifies the of workers in Databricks cluster. 108 | If a `DatabricksAutoscaleConfig` is supplied, specifies the autoscale 109 | configuration to use. Default is 1 worker without autoscaling enabled. 110 | machine: str 111 | AWS machine type for worker nodes. See https://www.databricks.com/product/aws-pricing/instance-types 112 | Default is i3.xlarge (4 vCPUs, 31 GB RAM) 113 | spark_version: Optional[str] 114 | Databricks runtime version. 115 | https://docs.databricks.com/release-notes/runtime/releases.html 116 | Required when use_serverless=False. Must not be set when use_serverless=True. 117 | data_security_mode: Optional[str] 118 | See `data_security_mode` at 119 | https://docs.databricks.com/administration-guide/clusters/policies.html#cluster-policy-attribute-paths 120 | Sets Databricks security mode. At time of writing, Delta Live Tables require `SINGLE_USER` mode. In Cascade 121 | versions <=0.9.5, default was `NONE`. 122 | cluster_spec_overrides: Optional[dict] 123 | Additional entries to add to task `new_cluster` object in Databricks API call. 124 | https://docs.databricks.com/dev-tools/api/latest/clusters.html#request-structure-of-the-cluster-definition 125 | Example: {"spark_env_vars": {'A_VARIABLE': "A_VALUE"}} 126 | cluster_policy: Optional[str] = None 127 | Databricks cluster policy name (policy ID is looked up using this name). See 128 | https://docs.databricks.com/administration-guide/clusters/policies.html for details. 129 | By default, looks up `group_name`'s default cluster policy. Most users do 130 | not need to configure. 131 | existing_cluster_id: Optional[str] = None 132 | If specified, does not start a new cluster and instead attempts to deploy this job 133 | to an existing cluster with this ID. Useful during testing to avoid lag time of 134 | repeated cluster starts between iterations. 135 | https://docs.databricks.com/clusters/create-cluster.html 136 | Get cluster ID from JSON link within cluster info in Databricks UI. 
137 | group_name: str 138 | The group name to run as in the databricks instance 139 | See "access_control_list"."group_name" in Databricks Job's API 140 | https://docs.databricks.com/api/workspace/jobs/create 141 | secret : Optional[DatabricksSecret] 142 | Token and hostname used to authenticate 143 | Required to run tasks on Databricks 144 | s3_credentials: dict 145 | Credentials to access S3, will be used to initialize S3FileSystem by 146 | calling s3fs.S3FileSystem(**s3_credentials). 147 | Required when storage_location starts with "s3://" (cluster compute). 148 | If no credentials are provided boto's credential resolver will be used. 149 | For details see: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html 150 | Not needed for Unity Catalog Volumes (/Volumes/, serverless compute). 151 | cloud_pickle_by_value: list[str] 152 | List of names of modules to be pickled by value instead of by reference. 153 | cloudpickle_infer_base_module: bool = True 154 | Whether to automatically infer the remote function's base module to be included 155 | in the `cloudpickle_by_value` 156 | task_args: Optional[dict] = None 157 | If provided, pass these additional arguments into the databricks task. This can 158 | be used 159 | python_libraries: List[PythonLibrary] = [] 160 | If provided, install these additional libraries on the cluster when the 161 | remote task is run. 162 | timeout_seconds: int = 86400 163 | The maximum time this job can run for; default is 24 hours. 164 | use_serverless: bool = False 165 | If True, use Databricks serverless compute instead of provisioning a cluster. 166 | When enabled, cluster-related parameters (worker_count, machine, spark_version, 167 | cluster_policy, existing_cluster_id) are ignored. 168 | See https://docs.databricks.com/api/workspace/jobs/submit for details. 169 | serverless_environment_version: str = "3" 170 | Serverless environment version. Each version comes with specific Python version 171 | and set of preinstalled packages. Default is "3". 172 | See https://docs.databricks.com/aws/release-notes/serverless/#serverless-environment-versions 173 | Only used when use_serverless=True. 174 | 175 | """ # noqa: E501 176 | 177 | storage_location: str 178 | worker_count: Union[int, DatabricksAutoscaleConfig] = 1 179 | machine: str = "i3.xlarge" 180 | spark_version: Optional[str] = None 181 | data_security_mode: Optional[str] = "SINGLE_USER" 182 | cluster_spec_overrides: Optional[dict] = None 183 | cluster_policy: Optional[str] = None 184 | existing_cluster_id: Optional[str] = None 185 | use_serverless: bool = False 186 | group_name: str = Field(default_factory=lambda: os.environ.get("DATABRICKS_GROUP", "default-group")) 187 | secret: Optional[DatabricksSecret] = None 188 | s3_credentials: Optional[dict] = None 189 | cloud_pickle_by_value: List[str] = Field(default_factory=list) 190 | cloud_pickle_infer_base_module: bool = True 191 | task_args: Optional[dict] = None 192 | python_libraries: list[Union[str, DatabricksPythonLibrary]] = Field(default_factory=list) 193 | timeout_seconds: int = 86400 194 | serverless_environment_version: str = "3" 195 | 196 | @model_validator(mode="after") 197 | def convert_string_libraries_to_objects(self): 198 | """Convert any string library names to DatabricksPythonLibrary objects for backwards compatibility. 
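
        For example (illustrative), ["pandas", "numpy==1.26.4"] is converted to
        [DatabricksPythonLibrary(name="pandas"),
         DatabricksPythonLibrary(name="numpy", version="1.26.4")].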
199 |
200 |         Supports formats:
201 |         - "package_name"
202 |         - "package_name==version"
203 |         """
204 |         converted_libraries = []
205 |         for lib in self.python_libraries:
206 |             if isinstance(lib, str):
207 |                 # Parse string to extract name and version if specified
208 |                 if "==" in lib:
209 |                     name, version = lib.split("==", 1)
210 |                     converted_libraries.append(DatabricksPythonLibrary(name=name.strip(), version=version.strip()))
211 |                 else:
212 |                     converted_libraries.append(DatabricksPythonLibrary(name=lib.strip()))
213 |             else:
214 |                 converted_libraries.append(lib)
215 |         self.python_libraries = converted_libraries
216 |         return self
217 |
218 |     @model_validator(mode="after")
219 |     def validate_serverless_configuration(self):
220 |         """Validate serverless vs cluster configuration parameters."""
221 |         if self.use_serverless:
222 |             # Validate storage location for serverless
223 |             if not self.storage_location.startswith("/Volumes/"):
224 |                 logger.warning(
225 |                     f"Serverless compute is enabled but storage_location is '{self.storage_location}'. "
226 |                     "Serverless compute requires Unity Catalog Volumes (format: /Volumes/<catalog>/<schema>/<volume>/<path>). "
227 |                     "This may cause the job to fail."
228 |                 )
229 |
230 |             # Warn if existing_cluster_id is set (will be ignored)
231 |             if self.existing_cluster_id:
232 |                 logger.info(
233 |                     "Serverless compute is enabled. The existing_cluster_id parameter will be ignored."
234 |                 )
235 |
236 |             # Validate spark_version is not set for serverless
237 |             if self.spark_version is not None:
238 |                 raise ValueError(
239 |                     "spark_version is not applicable for serverless compute"
240 |                 )
241 |         else:
242 |             # Validate spark_version is set for cluster compute
243 |             if self.spark_version is None:
244 |                 raise ValueError(
245 |                     "spark_version must be set for cluster compute. "
246 |                     "Please specify a Databricks runtime version (e.g., '17.3.x-scala2.13'). "
247 |                     "Note: for serverless compute, set use_serverless and specify serverless_environment_version."
248 | ) 249 | 250 | return self 251 | -------------------------------------------------------------------------------- /block_cascade/decorators.py: -------------------------------------------------------------------------------- 1 | from functools import partial, wraps 2 | from importlib.metadata import version 3 | from pathlib import Path 4 | from typing import Callable, Optional, Union 5 | 6 | import requests 7 | 8 | from block_cascade.config import find_default_configuration 9 | from block_cascade.gcp import VMMetadataServerClient 10 | from block_cascade.executors.databricks.resource import DatabricksResource 11 | from block_cascade.executors.databricks.executor import DatabricksExecutor 12 | from block_cascade.executors.local.executor import LocalExecutor 13 | from block_cascade.executors.vertex.executor import VertexExecutor 14 | from block_cascade.executors.vertex.resource import GcpEnvironmentConfig, GcpResource 15 | from block_cascade.executors.vertex.tune import Tune 16 | from block_cascade.prefect import ( 17 | PrefectEnvironmentClient, 18 | get_from_prefect_context, 19 | get_prefect_logger, 20 | is_prefect_cloud_deployment, 21 | ) 22 | from block_cascade.utils import _infer_base_module, wrapped_partial 23 | 24 | RESERVED_ARG_PREFIX = "remote_" 25 | 26 | 27 | def remote( 28 | func: Union[Callable, partial, None] = None, 29 | resource: Union[GcpResource, DatabricksResource] = None, 30 | config_name: Optional[str] = None, 31 | job_name: Optional[str] = None, 32 | web_console_access: Optional[bool] = False, 33 | tune: Optional[Tune] = None, 34 | code_package: Optional[Path] = None, 35 | remote_resource_on_local: bool = True, 36 | *args, 37 | **kwargs, 38 | ): 39 | """ 40 | Decorator factory to generate a decorator that can be used to run a 41 | function remotely. 42 | 43 | @tasks 44 | @remote(resource=GcpResource(...)) 45 | def train_model(): 46 | ... 47 | 48 | The function (along with captured arguments) will then run as a separate vertex job 49 | on the specified resource. 50 | 51 | Some parameters to the remote function that cannot be determined at function 52 | decoration time can also be passed in during function run time, through keyword 53 | args with the convention `remote_{arg}` (i.e. to overwrite `tune`, set 54 | `remote_tune` in the function). 55 | 56 | @task 57 | @remote() 58 | def train_model(remote_tune=Tune(...)) 59 | ... 60 | 61 | Parameters 62 | ---------- 63 | func: Function 64 | The function to run remotely 65 | resource: Union[GcpResource, DatabricksResource] 66 | A description of the remote resource (environment and cluster configuration) 67 | to run the function on 68 | config_name: Optional[str] 69 | The name of the configuration to use; must be a named block in cascade.yml 70 | If not provided, the job_name will be used to key the configuration 71 | job_name: Optional[str] 72 | The display name for the job; if no name is passed (default) this will 73 | be inferred from the function name and environment 74 | web_console_access: Optional[bool] 75 | For VertexAI jobs in GCP, whether to allow web console access to the job 76 | tune: Optional[Tune] 77 | An optional Tune object to use for hyperparameter tuning; only on VertexAI 78 | code_package: Optional[Path] 79 | An optional path to the first party code that your remo. 
80 | This is only necessary if the following conditions hold true: 81 | - The function is desired to run in Vertex AI 82 | - The function is not being executed from a Prefect2/3 Cloud Deployment 83 | - The function references a module that is not from a third party 84 | dependency, but from the same package the function is a member of. 85 | remote_resource_on_local: bool 86 | When running a Prefect flow locally: 87 | - If True: use specified remote resource (GCPResource requires an Image URI) 88 | - If False: set remote resource to None and fallback to LocalExecutor 89 | If the flow is running the Prefect Cloud, this argument will have no effect, regardless of the value. 90 | """ 91 | if not resource: 92 | resource_configurations = find_default_configuration() or {} 93 | if config_name: 94 | resource = resource_configurations.get(config_name) 95 | else: 96 | resource = resource_configurations.get(job_name) 97 | 98 | remote_args = locals() 99 | # Support calling this with arguments before using as a decorator, e.g. this 100 | # allows us to do 101 | # 102 | # remote = remote(resource=GcpResource(...)) 103 | # ... 104 | # @remote 105 | # def train_model(): 106 | # ...s 107 | if func is None: 108 | return partial( 109 | remote, 110 | job_name=job_name, 111 | web_console_access=web_console_access, 112 | resource=resource, 113 | tune=tune, 114 | code_package=code_package, 115 | remote_resource_on_local=remote_resource_on_local, 116 | *args, 117 | **kwargs, 118 | ) 119 | 120 | @wraps(func) 121 | def remote_func(*args, **kwargs): 122 | """ 123 | Remote function that will be returned by the decorator factory. 124 | Will inherit the docstring and name of the function it decorates. 125 | """ 126 | 127 | # list of parameters that can be overriden by supplying their value at 128 | # function call time as a keyword argument with the prefix "remote_" 129 | nonlocal remote_args 130 | for parameter in remote_args: 131 | if f"{RESERVED_ARG_PREFIX}{parameter}" in kwargs: 132 | remote_args[parameter] = kwargs[f"{RESERVED_ARG_PREFIX}{parameter}"] 133 | kwargs.pop(f"{RESERVED_ARG_PREFIX}{parameter}", None) 134 | 135 | resource = remote_args.get("resource", None) 136 | job_name = remote_args.get("job_name", None) 137 | tune = remote_args.get("tune", None) 138 | code_package = remote_args.get("code_package", None) 139 | web_console_access = remote_args.get("web_console_access", False) 140 | remote_resource_on_local = remote_args.get('remote_resource_on_local', True) 141 | 142 | # get the prefect logger and flow metadata if available 143 | # to determine if this flow is running on the cloud 144 | prefect_logger = get_prefect_logger(__name__) 145 | 146 | flow_id = get_from_prefect_context("flow_id", "LOCAL") 147 | flow_name = get_from_prefect_context("flow_name", "LOCAL") 148 | task_id = get_from_prefect_context("task_run_id", "LOCAL") 149 | task_name = get_from_prefect_context("task_run", "LOCAL") 150 | 151 | via_cloud = is_prefect_cloud_deployment() 152 | prefect_logger.info(f"Via cloud? 
{via_cloud}") 153 | 154 | # create a new wrapped partial function with the passed *args and **kwargs 155 | # so that it can be sent to the remote executor with its parameters 156 | packed_func = wrapped_partial(func, *args, **kwargs) 157 | 158 | # if running a flow locally ignore the remote resource, even if specified 159 | # necessary for running a @remote decorated task in a local flow 160 | if not via_cloud and not remote_resource_on_local: 161 | prefect_logger.info("Not running in Prefect Cloud and remote_resource_on_local=False. " 162 | "Setting the Cascade remote resource to None and using the LocalExecutor.") 163 | resource = None 164 | 165 | # if no resource is passed, run locally 166 | if resource is None: 167 | prefect_logger.info("Executing task with LocalExecutor.") 168 | executor = LocalExecutor(func=packed_func) 169 | 170 | # if a GcpResource is passed, try to run on Vertex 171 | elif isinstance(resource, GcpResource): 172 | prefect_logger.info("Executing task with GcpResource.") 173 | # Align naming with labels defined by Prefect 174 | # Infrastructure: https://github.com/PrefectHQ/prefect/blob/main/src/prefect/infrastructure/base.py#L134 175 | # and mutated to be GCP compatible: https://github.com/PrefectHQ/prefect-gcp/blob/main/prefect_gcp/aiplatform.py#L214 176 | labels = { 177 | "prefect-io_flow-run-id": flow_id, 178 | "prefect-io_flow-name": flow_name, 179 | "prefect-io_task-name": task_name, 180 | "prefect-io_task-id": task_id, 181 | "block_cascade-version": version("block_cascade"), 182 | } 183 | resource.environment = resource.environment or GcpEnvironmentConfig() 184 | if resource.environment.is_complete: 185 | executor = VertexExecutor( 186 | resource=resource, 187 | name=job_name, 188 | func=packed_func, 189 | tune=tune, 190 | labels=labels, 191 | logger=prefect_logger, 192 | code_package=code_package, 193 | web_console=web_console_access, 194 | ) 195 | else: 196 | client = ( 197 | PrefectEnvironmentClient() 198 | if via_cloud 199 | else VMMetadataServerClient() 200 | ) 201 | 202 | try: 203 | if not resource.environment.image: 204 | resource.environment.image = client.get_container_image() 205 | if not resource.environment.project: 206 | resource.environment.project = client.get_project() 207 | if not resource.environment.service_account: 208 | resource.environment.service_account = ( 209 | client.get_service_account() 210 | ) 211 | if not resource.environment.region: 212 | resource.environment.region = client.get_region() 213 | except requests.exceptions.ConnectionError: 214 | prefect_logger.warning( 215 | "Failed to connect to the metadata host. " 216 | "The execution environment is likely neither " 217 | "a GCP VM nor a Prefect " 218 | "Deployment." 219 | ) 220 | 221 | if not resource.environment.is_complete: 222 | missing_env_attributes = [ 223 | attr for attr in ["project", "service_account", "region", "image"] 224 | if getattr(resource.environment, attr) is None 225 | ] 226 | raise RuntimeError( 227 | "Unable to infer remaining environment for GcpResource. " 228 | f"Missing attributes: {missing_env_attributes}. " 229 | "Please provide a complete environment to the " 230 | "configured GcpResource."
231 | ) 232 | executor = VertexExecutor( 233 | resource=resource, 234 | func=packed_func, 235 | name=job_name, 236 | tune=tune, 237 | labels=labels, 238 | logger=prefect_logger, 239 | code_package=code_package, 240 | web_console=web_console_access, 241 | ) 242 | elif ( 243 | isinstance(resource, DatabricksResource) 244 | or "DatabricksResource" in type(resource).__name__ 245 | ): 246 | prefect_logger.info("Executing task with DatabricksResource.") 247 | failed_to_infer_base = ( 248 | "Unable to infer base module of function. Specify " 249 | "the base module in the `cloud_pickle_by_value` attribute " 250 | "of the DatabricksResource object if necessary." 251 | ) 252 | if resource.cloud_pickle_infer_base_module: 253 | base_module_name = _infer_base_module(func) 254 | # if base module is __main__ or None, it can't be registered 255 | if base_module_name is None or base_module_name.startswith("__"): 256 | prefect_logger.warn(failed_to_infer_base) 257 | else: 258 | resource.cloud_pickle_by_value.append(base_module_name) 259 | 260 | executor = DatabricksExecutor( 261 | func=packed_func, 262 | resource=resource, 263 | name=job_name, 264 | ) 265 | else: 266 | raise ValueError("No valid resource provided.") 267 | 268 | # if sucessful, this will return the result of ._result() 269 | return executor.run() 270 | 271 | return remote_func 272 | -------------------------------------------------------------------------------- /block_cascade/executors/vertex/executor.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import logging 3 | import os 4 | from pathlib import Path 5 | import pickle 6 | import time 7 | from typing import Callable, Mapping, Optional, Union 8 | 9 | import cloudpickle 10 | import gcsfs 11 | from google.cloud import aiplatform_v1beta1 as aiplatform 12 | from google.cloud.aiplatform_v1beta1.types import job_state 13 | 14 | from block_cascade.concurrency import run_async 15 | from block_cascade.executors.executor import Executor 16 | from block_cascade.executors.vertex.distributed.distributed_job import DaskJob 17 | from block_cascade.executors.vertex.job import VertexJob 18 | from block_cascade.executors.vertex.resource import GcpResource 19 | from block_cascade.executors.vertex.tune import Tune, TuneResult 20 | from block_cascade.gcp.monitoring import log_quotas_for_resource 21 | from block_cascade.utils import PREFECT_VERSION, maybe_convert 22 | 23 | 24 | if PREFECT_VERSION == 3: 25 | from block_cascade.prefect.v3 import get_current_deployment, get_storage_block 26 | elif PREFECT_VERSION == 2: 27 | from block_cascade.prefect.v2 import get_current_deployment, get_storage_block 28 | else: 29 | get_storage_block = None 30 | get_current_deployment = None 31 | 32 | 33 | class VertexError(Exception): 34 | pass 35 | 36 | 37 | class VertexCancelledError(Exception): 38 | pass 39 | 40 | 41 | @dataclass 42 | class Status: 43 | state: job_state.JobState 44 | message: str 45 | 46 | @property 47 | def is_executing(self): 48 | return self.state in { 49 | job_state.JobState.JOB_STATE_UNSPECIFIED, 50 | job_state.JobState.JOB_STATE_QUEUED, 51 | job_state.JobState.JOB_STATE_PENDING, 52 | job_state.JobState.JOB_STATE_RUNNING, 53 | job_state.JobState.JOB_STATE_PAUSED, 54 | } 55 | 56 | @property 57 | def is_cancelled(self): 58 | return self.state in { 59 | job_state.JobState.JOB_STATE_CANCELLING, 60 | job_state.JobState.JOB_STATE_CANCELLED, 61 | } 62 | 63 | @property 64 | def is_succesful(self): 65 | return self.state is 
job_state.JobState.JOB_STATE_SUCCEEDED 66 | 67 | 68 | class VertexExecutor(Executor): 69 | def __init__( 70 | self, 71 | resource: GcpResource, 72 | func: Callable, 73 | name: str = None, 74 | job: VertexJob = None, 75 | tune: Tune = None, 76 | dashboard: bool = False, 77 | web_console: bool = False, 78 | labels: Optional[Mapping[str, str]] = None, 79 | logger: Optional[Union[logging.Logger, logging.LoggerAdapter]] = None, 80 | code_package: Optional[Path] = None, 81 | ): 82 | """Executor to submit VertexJobs as CustomJobs in Vertex AI. 83 | The lifecycle of an executor is: 84 | * initialize the executor 85 | * call the run method 86 | * this in turn creates a VertexJob, stages the function and data, 87 | attempts to start the job by calling VertexAI's CustomJob API 88 | and monitors the status of the job 89 | * the _start method is responsible for managing the lifecycle of the VertexJob 90 | 91 | Parameters 92 | ---------- 93 | resource: GcpResource 94 | The environment in which to create and run a VertexJob 95 | func: Callable 96 | The function to execute 97 | Note: this needs to be a Callable that expects no arguments 98 | and has a __name__ attribute; use cascade.utils.wrapped_partial 99 | to prepare a function for execution 100 | job: VertexJob 101 | It is possible to pass an existing VertexJob to the executor; this is not 102 | recommended. The run method of the executor will create a new VertexJob if 103 | no job is passed. 104 | dashboard: bool 105 | Whether to enable access to the Dask dashboard in VertexAI, defaults 106 | to False 107 | web_console: bool 108 | Whether to enable access to the web console in VertexAI, defaults to False 109 | labels: Dict[str, str], Optional 110 | User defined metadata to attach to VertexJobs. 111 | logger: Union[logging.Logger, logging.LoggerAdapter], Optional 112 | A configured logger for logging messages. 113 | code_package: Optional[Path] 114 | An optional path to the first party code that your remote execution needs. 115 | This is only necessary if the following conditions hold true: 116 | - The function is desired to run in Vertex AI 117 | - The function is not being executed from a Prefect2/3 Cloud Deployment 118 | - The function references a module that is not from a third party 119 | dependency, but from the same package the function is a member of. 120 | """ 121 | super().__init__(func=func) 122 | self.resource = resource 123 | self.dashboard = dashboard 124 | self.web_console = web_console 125 | self.job = job 126 | self.tune = tune 127 | self.labels = labels 128 | self.name = name 129 | 130 | if code_package and code_package.is_file(): 131 | raise RuntimeError( 132 | f"{code_package} references a file. " 133 | "Please reference a directory representing a Python package."
134 | ) 135 | self.code_package = code_package 136 | 137 | # the name is not known at executor initialization, 138 | # it is set when the job is submitted 139 | self._name = name 140 | self._fs = gcsfs.GCSFileSystem() 141 | self._storage_location = self.resource.environment.storage_location 142 | self._logger = logger 143 | if not self._logger: 144 | self._logger = logging.getLogger(__name__) 145 | 146 | @property 147 | def name(self) -> str: 148 | return self._name 149 | 150 | @name.setter 151 | def name(self, name: str): 152 | self._name = name 153 | 154 | @property 155 | def vertex(self): 156 | """ 157 | Returns a Vertex client; refreshes the client for each usage 158 | This seems excessive but we've had issues with authentication expiring 159 | """ 160 | region = self.job.resource.environment.region 161 | client_options = {"api_endpoint": f"{region}-aiplatform.googleapis.com"} 162 | return aiplatform.JobServiceClient(client_options=client_options) 163 | 164 | @property 165 | def display_name(self): 166 | if self.name is not None: 167 | return self.name 168 | if self.tune is None: 169 | return f"{self.func.__name__}-cascade" 170 | return f"{self.func.__name__}-hyperparameter-tuning" 171 | 172 | @property 173 | def distributed_job_path(self): 174 | return os.path.join(self.storage_path, "distributed_job.pkl") 175 | 176 | @property 177 | def code_path(self): 178 | return os.path.join(self.storage_path, "code") 179 | 180 | def create_job(self) -> VertexJob: 181 | """ 182 | create a VertexJob to be run by this executor 183 | when we move to 3.10 can use matching 184 | """ 185 | 186 | package_path = None 187 | if not self.code_package and PREFECT_VERSION in (2, 3): 188 | self._logger.info( 189 | "Checking if flow is running from a Prefect 2/3 Cloud deployment." 190 | ) 191 | deployment = get_current_deployment() 192 | if deployment: 193 | storage = get_storage_block() 194 | if "/" in deployment.entrypoint: 195 | module_name = deployment.entrypoint.split("/")[0] 196 | elif ":" in deployment.entrypoint: 197 | module_name = deployment.entrypoint.split(":")[0] 198 | else: 199 | module_name = deployment.entrypoint 200 | 201 | module_name = module_name.lstrip("/") 202 | # 203 | # This is hardcoded to only support 204 | # prefect.filesystems.GCS or prefect_gcp.cloud_storage.GCSBucket 205 | # storage blocks. 206 | # 207 | # TODO: This should be generalized to accept any storage block. 
208 | # 209 | bucket = storage.data.get( 210 | "bucket_path", 211 | storage.data.get( 212 | "bucket" 213 | ) 214 | ) 215 | if not bucket: 216 | raise RuntimeError( 217 | f"Unable to parse bucket from storage block: {storage}" 218 | ) 219 | deployment_path = storage.data.get("bucket_folder").rstrip("/") if storage.data.get("bucket_folder") else deployment.path.rstrip("/") 220 | 221 | package_path = f"{bucket}/{deployment_path}/{module_name}" 222 | self._logger.info( 223 | f"Code package from deployment is located at gs://{package_path}" 224 | ) 225 | elif self.code_package: 226 | upload_path = os.path.join(self.code_path, self.code_package.name) 227 | fs = gcsfs.GCSFileSystem() 228 | self._logger.info(f"Uploading first party package to {upload_path}") 229 | fs.upload(str(self.code_package), upload_path, recursive=True) 230 | package_path = upload_path 231 | 232 | dashboard = self.dashboard or isinstance(self.resource.distributed_job, DaskJob) 233 | return VertexJob( 234 | display_name=self.display_name, 235 | resource=self.resource, 236 | storage_path=self.storage_path, 237 | tune=self.tune, 238 | dashboard=dashboard, 239 | web_console=self.web_console, 240 | labels=self.labels, 241 | code_package=package_path, 242 | ) 243 | 244 | def _stage_distributed_job(self): 245 | """ 246 | Stages a distributed job object in GCS; this is done instead of staging the 247 | function as the distributed job requires the function to be included in the 248 | distributed job object to perform additional setup on the cluster before 249 | function execution 250 | """ 251 | with self.fs.open(self.distributed_job_path, "wb") as f: 252 | cloudpickle.dump(self.resource.distributed_job, f) 253 | 254 | def _get_status(self): 255 | name = self.name 256 | if "hyperparameter" in name: 257 | response = self.vertex.get_hyperparameter_tuning_job(name=name) 258 | else: 259 | response = self.vertex.get_custom_job(name=name) 260 | return Status(response.state, response.error) 261 | 262 | def _run(self): 263 | """ 264 | Runs a task and return the result, called from 265 | the public run function. 266 | Submits the job to Vertex and monitors its status by polling the API. 267 | Returns the result of the job. 268 | """ 269 | 270 | if self.resource.distributed_job is not None: 271 | self._stage_distributed_job() 272 | 273 | self._stage() 274 | 275 | custom_job_name = self._start() 276 | self.name = custom_job_name 277 | 278 | status = self._get_status() 279 | while status.is_executing: 280 | time.sleep(30) 281 | status = self._get_status() 282 | 283 | if status.is_cancelled: 284 | raise VertexCancelledError( 285 | f"Job {self.name} was cancelled: {status.message}" 286 | ) 287 | 288 | if not status.is_succesful: 289 | raise VertexError(f"Job {self.name} failed: {status.message}") 290 | 291 | return self._result() 292 | 293 | def _start(self) -> name: 294 | """ 295 | Create and start a job in Vertex. 
296 | Returns the name of the job as returned by the Vertex CustomJobs API 297 | """ 298 | 299 | # if the job does not exist, create it 300 | if self.job is None: 301 | self.job = self.create_job() 302 | 303 | try: 304 | run_async(log_quotas_for_resource(self.job.resource)) 305 | except Exception as e: 306 | self._logger.warning(e) 307 | pass 308 | 309 | # create the payload for the custom job 310 | custom_job_payload = self.job.create_payload() 311 | 312 | # pull the gcp_environment from the job object 313 | gcp_environment = self.job.resource.environment 314 | project = gcp_environment.project 315 | region = gcp_environment.region 316 | 317 | # the resource name of the Location to create the custom job in 318 | parent = f"projects/{project}/locations/{region}" 319 | 320 | # submit the job to CutomJob endpoint 321 | if self.job.tune is None: 322 | resp = self.vertex.create_custom_job( 323 | parent=parent, custom_job=custom_job_payload 324 | ) 325 | self._logger.info(f"Created a remote vertex job: {resp}") 326 | else: 327 | resp = self.vertex.create_hyperparameter_tuning_job( 328 | parent=parent, hyperparameter_tuning_job=custom_job_payload 329 | ) 330 | self._logger.info( 331 | f"Created a remote vertex hyperparameter tuning job: {resp}" 332 | ) 333 | 334 | _, _, _, location, _, job_id = resp.name.split("/") 335 | path = f"locations/{location}/training/{job_id}" 336 | 337 | url_log = f"Logs for remote job can be found at https://console.cloud.google.com/vertex-ai/{path}" 338 | self._logger.info(url_log) 339 | 340 | # return the name of the job in Vertex for use monitoring status 341 | # and fetching results 342 | return resp.name 343 | 344 | def _result(self): 345 | output_filepath = self.output_filepath 346 | # We have to build hyperparameter tuning result from the Vertex api, not 347 | # the task and we save it to the task output here for consistency 348 | if "hyperparameter" in self.name: 349 | result = self._tune_result() 350 | with self.fs.open(output_filepath, "wb") as f: 351 | pickle.dump(result, f) 352 | return result 353 | 354 | # For anything else we can just load the result from the output 355 | try: 356 | with self.fs.open(output_filepath, "rb") as f: 357 | return pickle.load(f) 358 | except ValueError: 359 | self._logger.warning( 360 | f"Failed to load the output from succesful job at {output_filepath}" 361 | ) 362 | raise 363 | 364 | def _tune_result(self): 365 | response = self.vertex.get_hyperparameter_tuning_job(name=self.name) 366 | 367 | reverse = "MAXIMIZE" in str(response.study_spec.metrics[0].goal) 368 | trials = response.trials 369 | trials = sorted( 370 | trials, 371 | key=lambda trial: trial.final_measurement.metrics[0].value, 372 | reverse=reverse, 373 | ) 374 | flattened = [ 375 | { 376 | "trial_id": maybe_convert(t.id), 377 | "metric": maybe_convert(t.final_measurement.metrics[0].value), 378 | **{ 379 | param.parameter_id: maybe_convert(param.value) 380 | for param in t.parameters 381 | }, 382 | } 383 | for t in trials 384 | if t.state.name == "SUCCEEDED" 385 | ] 386 | best_trial = flattened[0] 387 | metric = best_trial["metric"] 388 | hyperparameters = { 389 | key: val 390 | for key, val in best_trial.items() 391 | if key not in ("trial_id", "metric") 392 | } 393 | return TuneResult(metric, hyperparameters, flattened) 394 | -------------------------------------------------------------------------------- /block_cascade/executors/databricks/executor.py: -------------------------------------------------------------------------------- 1 | from types import 
ModuleType 2 | from typing import Callable, Iterable, Optional 3 | 4 | from block_cascade.executors.executor import Executor 5 | 6 | try: 7 | import cloudpickle 8 | except ImportError: 9 | import pickle as cloudpickle # Databricks renames cloudpickle to pickle in Runtimes 11 + # noqa: E501 10 | 11 | import base64 12 | import importlib 13 | import os 14 | import threading 15 | import time 16 | import s3fs 17 | from dataclasses import dataclass 18 | from slugify import slugify 19 | 20 | from databricks_cli.cluster_policies.api import ClusterPolicyApi 21 | from databricks_cli.runs.api import RunsApi 22 | from databricks_cli.sdk.api_client import ApiClient 23 | 24 | from block_cascade.executors.databricks.resource import DatabricksSecret 25 | from block_cascade.executors.databricks.job import DatabricksJob 26 | from block_cascade.executors.databricks.resource import DatabricksResource 27 | from block_cascade.executors.databricks.filesystem import DatabricksFilesystem 28 | from block_cascade.prefect import get_prefect_logger 29 | 30 | from importlib.resources import files 31 | 32 | 33 | lock = threading.Lock() 34 | 35 | # must specify API version=2.1 or runs submitted from Vertex are not viewable in 36 | # Databricks UI 37 | DATABRICKS_API_VERSION = "2.1" 38 | 39 | 40 | class DatabricksError(Exception): 41 | pass 42 | 43 | 44 | class DatabricksCancelledError(Exception): 45 | pass 46 | 47 | 48 | @dataclass 49 | class Status: 50 | """ 51 | https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobsrunlifecyclestate 52 | https://docs.databricks.com/dev-tools/api/2.0/jobs.html#runresultstate 53 | """ 54 | 55 | status: dict 56 | 57 | def __post_init__(self): 58 | self.result_state = self.status["state"].get("result_state", "") 59 | self.life_cycle_state = self.status["state"]["life_cycle_state"] 60 | 61 | def is_executing(self): 62 | return self.life_cycle_state in {"PENDING", "RUNNING"} 63 | 64 | def is_cancelled(self): 65 | return self.result_state == "CANCELED" 66 | 67 | def is_succesful(self): 68 | return self.result_state == "SUCCESS" 69 | 70 | 71 | class DatabricksExecutor(Executor): 72 | def __init__( 73 | self, 74 | func: Callable, 75 | resource: DatabricksResource, 76 | name: Optional[str] = None, 77 | ): 78 | """Executor to submit tasks to run as databricks jobs 79 | 80 | Parameters 81 | ---------- 82 | func : Callable 83 | Function to run 84 | resource : DatabricksResource, optional 85 | Databricks resource, describing the cluster to run the job on 86 | job_name: str, optional 87 | An optional name for the job, by default this is None and inferred from 88 | func.__name__ 89 | """ 90 | super().__init__(func=func) 91 | self.resource = resource 92 | self.name = name 93 | self.active_job = None 94 | self._fs = None 95 | self._databricks_secret = resource.secret 96 | self._storage_location = resource.storage_location 97 | 98 | # extract params from resource 99 | self.group_name = self.resource.group_name 100 | if self.resource.cluster_policy is None: 101 | self.cluster_policy = self.group_name + "_default" 102 | else: 103 | self.cluster_policy = self.resource.cluster_policy 104 | 105 | self.logger = get_prefect_logger(__name__) 106 | 107 | @property 108 | def databricks_secret(self): 109 | """ 110 | DatabricksSecret object containing token and host 111 | Check if the user passed a DatabricksSecret object or if 112 | DATABRICKS_HOST and DATABRICKS_TOKEN are set as ENV VARS 113 | if neither are set, raise a ValueError 114 | """ 115 | if self._databricks_secret is not None: 116 | return 
self._databricks_secret 117 | elif os.environ.get("DATABRICKS_HOST") and os.environ.get("DATABRICKS_TOKEN"): 118 | return DatabricksSecret( 119 | host=os.environ["DATABRICKS_HOST"], 120 | token=os.environ["DATABRICKS_TOKEN"], 121 | ) 122 | else: 123 | raise ValueError( 124 | """Cannot locate Databricks secret. Databricks secret must 125 | be set in DatabricksResource or as environment variables 126 | DATABRICKS_HOST and DATABRICKS_TOKEN""" 127 | ) 128 | 129 | @property 130 | def fs(self): 131 | """ 132 | Get the appropriate filesystem for the storage location. 133 | 134 | - For /Volumes/ paths: Uses DatabricksFilesystem (Unity Catalog Volumes via DBFS API) 135 | Required for serverless compute. Provides UC governance and permissions. 136 | 137 | - For s3:// paths: Uses s3fs.S3FileSystem 138 | Required for traditional cluster compute. 139 | 140 | For S3, credentials are refreshed every time (1 hour validity). 141 | boto3 client creation is not threadsafe, so we wrap in retries. 142 | """ 143 | if self._fs is not None: 144 | return self._fs 145 | 146 | storage_loc = self.resource.storage_location 147 | 148 | # Unity Catalog Volumes - Required for serverless compute 149 | if storage_loc.startswith("/Volumes/"): 150 | self._fs = DatabricksFilesystem( 151 | api_client=self.api_client, 152 | auto_mkdir=True 153 | ) 154 | self.logger.info(f"Using DatabricksFilesystem for Unity Catalog Volumes: {storage_loc}") 155 | return self._fs 156 | 157 | # S3 paths - Required for traditional cluster compute 158 | if storage_loc.startswith("s3://"): 159 | wait = 1 160 | n_retries = 0 161 | while n_retries <= 6: 162 | try: 163 | if self.resource.s3_credentials is None: 164 | self._fs = s3fs.S3FileSystem() 165 | else: 166 | self._fs = s3fs.S3FileSystem(**self.resource.s3_credentials) 167 | break 168 | except KeyError: 169 | self.logger.info(f"Waiting {wait} seconds to retry STS") 170 | n_retries += 1 171 | time.sleep(wait) 172 | wait *= 1.5 173 | if self._fs is None: 174 | raise RuntimeError( 175 | "Failed to initialize S3 filesystem; job pickle cannot be staged." 176 | ) 177 | return self._fs 178 | 179 | # Unknown storage type 180 | raise ValueError( 181 | f"Unsupported storage location: {storage_loc}. " 182 | "Must be either:\n" 183 | " - /Volumes//// (for serverless compute)\n" 184 | " - s3://bucket/path/ (for traditional cluster compute)" 185 | ) 186 | 187 | @fs.setter 188 | def fs(self, fs): 189 | self._fs = fs 190 | 191 | @property 192 | def cloudpickle_by_value(self) -> Iterable[ModuleType]: 193 | """ 194 | A list of modules to pickle by value rather than by reference 195 | This list is defined by the user in the resource object and 196 | has signature List[str] 197 | 198 | If a module has not already present in sys.modules, 199 | it will be imported 200 | 201 | If a module is not found in the current Python environment, 202 | raises a Runtime error 203 | 204 | Returns 205 | ------- 206 | Iterable[str] 207 | Set of modules to pickle by value 208 | """ 209 | modules_to_pickle = set() 210 | for module in self.resource.cloud_pickle_by_value or []: 211 | try: 212 | modules_to_pickle.add(importlib.import_module(module)) 213 | except ModuleNotFoundError: 214 | raise RuntimeError( 215 | f"Unable to pickle {module} due to module not being " 216 | "found in current Python environment." 
217 | ) 218 | except ImportError: 219 | raise RuntimeError(f"Unable to pickle {module} due to import error.") 220 | return modules_to_pickle 221 | 222 | @property 223 | def api_client(self): 224 | """ 225 | TODO: We may be able to cache this/not recreate it every time; 226 | initially we're copying the previous AIP approach 227 | """ 228 | api_client = ApiClient( 229 | host=self.databricks_secret.host, token=self.databricks_secret.token 230 | ) 231 | return api_client 232 | 233 | @property 234 | def runs_api(self): 235 | return RunsApi(self.api_client) 236 | 237 | def get_cluster_policies(self): 238 | client = ClusterPolicyApi(self.api_client) 239 | policies = client.list_cluster_policies() 240 | return policies 241 | 242 | def get_cluster_policy_id_from_policy_name(self, cluster_policy_name: str) -> str: 243 | policies = self.get_cluster_policies() 244 | for i in policies["policies"]: 245 | if i["name"] == cluster_policy_name: 246 | return i["policy_id"] 247 | raise ValueError("No policy with provided name found") 248 | 249 | @property 250 | def run_path(self): 251 | """ 252 | Get the path where run.py should be stored. 253 | 254 | For serverless compute: Use /Shared/ path (uploaded to Workspace, referenced without /Workspace/ prefix) 255 | For cluster compute: Can be in storage_path (S3 or Volumes) 256 | 257 | Note: Serverless requires Shared workspace paths WITHOUT the /Workspace/ prefix in job spec, 258 | but WITH it for the upload API (handled in _upload_to_workspace). 259 | """ 260 | if self.resource.use_serverless: 261 | # Serverless python_file uses /Shared/ (not /Workspace/Shared/) 262 | # The upload will add /Workspace/ prefix for the API call 263 | return f"/Shared/.cascade/{self.storage_key}/run.py" 264 | else: 265 | # Traditional cluster compute can use storage location (S3 or Volumes) 266 | return os.path.join(self.storage_path, "run.py") 267 | 268 | def create_job(self): 269 | """ 270 | Create a DatabricksJob object 271 | """ 272 | try: 273 | self.name = self.name or self.func.__name__ 274 | except AttributeError: 275 | self.name = self.name or "unnamed" 276 | 277 | # Only lookup cluster policy ID if not using serverless 278 | if self.resource.use_serverless: 279 | cluster_policy_id = None 280 | else: 281 | cluster_policy_id = self.get_cluster_policy_id_from_policy_name( 282 | self.cluster_policy 283 | ) 284 | 285 | return DatabricksJob( 286 | name=slugify(self.name), 287 | resource=self.resource, 288 | storage_path=self.storage_path, 289 | storage_key=self.storage_key, 290 | existing_cluster_id=self.resource.existing_cluster_id, 291 | cluster_policy_id=cluster_policy_id, 292 | run_path=self.run_path, 293 | timeout_seconds=self.resource.timeout_seconds, 294 | ) 295 | 296 | def _run(self): 297 | """ 298 | Create the payload, submit it to the API, and monitor its status while it 299 | is executing 300 | """ 301 | 302 | self._stage() 303 | self._start() 304 | 305 | while self._status().is_executing(): 306 | time.sleep(30) 307 | 308 | if self._status().is_cancelled(): 309 | raise DatabricksCancelledError( 310 | f"Job {self.name} was cancelled: {self._status().status}" 311 | ) 312 | 313 | if not self._status().is_succesful(): 314 | raise DatabricksError(f"Job {self.name} failed: {self._status().status}") 315 | 316 | return self._result() 317 | 318 | def _result(self): 319 | """ 320 | Override base _result() to add better error handling for Volumes. 
321 | """ 322 | try: 323 | with self.fs.open(self.output_filepath, "rb") as f: 324 | result = cloudpickle.load(f) 325 | 326 | # Clean up storage 327 | self.fs.rm(self.storage_path, recursive=True) 328 | 329 | except FileNotFoundError as e: 330 | self.logger.error(f"Could not read output file: {e}") 331 | raise FileNotFoundError( 332 | f"Could not find output file {self.output_filepath}. " 333 | f"Original error: {e}" 334 | ) 335 | 336 | return result 337 | 338 | def _upload_to_workspace(self, local_path: str, workspace_path: str): 339 | """ 340 | Upload a file to Databricks Workspace using the Workspace API. 341 | 342 | This is required for serverless compute, which cannot access Unity Catalog Volumes 343 | for the python_file parameter. 344 | 345 | Parameters 346 | ---------- 347 | local_path : str 348 | Local file path to upload 349 | workspace_path : str 350 | Workspace path (e.g., /Shared/.cascade/uuid/run.py for serverless) 351 | 352 | Note 353 | ---- 354 | The Workspace API and job specifications use the SAME path format for serverless: 355 | - Both use: /Shared/.cascade/uuid/run.py (no /Workspace/ prefix needed) 356 | """ 357 | # For serverless, the path format is the same for API and job spec 358 | api_path = workspace_path 359 | 360 | # Read local file 361 | with open(local_path, "rb") as f: 362 | content = f.read() 363 | 364 | # Base64 encode content 365 | content_b64 = base64.b64encode(content).decode('utf-8') 366 | 367 | # Create parent directory if needed 368 | parent_dir = os.path.dirname(api_path) 369 | self.api_client.perform_query( 370 | 'POST', 371 | '/workspace/mkdirs', 372 | data={'path': parent_dir} 373 | ) 374 | 375 | # Upload file to workspace 376 | # Use AUTO format instead of SOURCE to create a regular file, not a notebook 377 | self.api_client.perform_query( 378 | 'POST', 379 | '/workspace/import', 380 | data={ 381 | 'path': api_path, 382 | 'content': content_b64, 383 | 'format': 'AUTO', 384 | 'overwrite': True 385 | } 386 | ) 387 | 388 | def _upload_run_script(self): 389 | """ 390 | Upload run.py bootstrap script. 
391 | 392 | For serverless: Upload to Workspace (serverless can't access Volumes for python_file) 393 | For cluster: Upload to storage location (S3 or Volumes) 394 | """ 395 | run_script = ( 396 | files("block_cascade.executors.databricks") 397 | .joinpath("run.py") 398 | .resolve() 399 | .as_posix() 400 | ) 401 | 402 | if self.resource.use_serverless: 403 | # Upload to Databricks Workspace for serverless compatibility 404 | self._upload_to_workspace(run_script, self.run_path) 405 | else: 406 | # Upload to storage location (S3 or Volumes) for cluster compute 407 | self.fs.upload(run_script, self.run_path) 408 | 409 | def _stage(self): 410 | """ 411 | Overwrite the base _stage method to additionally stage 412 | block_cascade.executors.databricks.run.py and register pickle by value dependencies 413 | and then unregister them 414 | """ 415 | self._upload_run_script() 416 | 417 | with lock: 418 | for dep in self.cloudpickle_by_value: 419 | cloudpickle.register_pickle_by_value(dep) 420 | 421 | with self.fs.open(self.staged_filepath, "wb") as f: 422 | cloudpickle.dump(self.func, f) 423 | 424 | for dep in self.cloudpickle_by_value: 425 | cloudpickle.unregister_pickle_by_value(dep) 426 | 427 | def _start(self): 428 | """Create a job, use it to create a payload, and submit it to the API""" 429 | 430 | # get the databricks client and submit a job to the API 431 | client = self.runs_api 432 | job = self.create_job() 433 | databricks_payload = job.create_payload() 434 | self.logger.info(f"Databricks job payload: {databricks_payload}") 435 | 436 | self.active_job = client.submit_run( 437 | databricks_payload, version=DATABRICKS_API_VERSION 438 | ) 439 | 440 | self.logger.info(f"Created Databricks job: {self.active_job}") 441 | url = client.get_run(**self.active_job)["run_page_url"] 442 | 443 | self.logger.info(f"Databricks job running: {url}") 444 | 445 | def _status(self, raw=False): 446 | runs_client = self.runs_api 447 | status = runs_client.get_run(**self.active_job) 448 | if raw: 449 | return status 450 | return Status(status) 451 | 452 | def list_runtime_versions(self): 453 | from databricks_cli.clusters.api import ClusterApi 454 | 455 | clusterapi = ClusterApi(self.api_client) 456 | versions = clusterapi.spark_versions() 457 | return sorted(versions["versions"], key=lambda x: x["key"]) 458 | --------------------------------------------------------------------------------
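For orientation, a minimal usage sketch of the `remote` decorator defined in block_cascade/decorators.py follows. It is an illustration only: the keyword arguments passed to GcpResource and GcpEnvironmentConfig, and all example values (project, region, service account, image), are assumptions, since block_cascade/executors/vertex/resource.py is not reproduced in this dump. The behaviour shown (configuration lookup, LocalExecutor fallback, and "remote_"-prefixed runtime overrides) follows the docstring in decorators.py above.

# Illustrative sketch only. The GcpResource/GcpEnvironmentConfig constructor
# arguments and all literal values below are assumptions, not verified signatures.
from block_cascade.decorators import remote
from block_cascade.executors.vertex.resource import GcpEnvironmentConfig, GcpResource

resource = GcpResource(
    environment=GcpEnvironmentConfig(
        project="my-gcp-project",
        region="us-central1",
        service_account="trainer@my-gcp-project.iam.gserviceaccount.com",
        image="us-docker.pkg.dev/my-gcp-project/training/trainer:latest",
    ),
)

@remote(resource=resource, job_name="train-model")
def train_model(n_estimators: int = 100):
    ...

# With a GcpResource the call is dispatched to VertexExecutor; with no resource
# it falls back to LocalExecutor. Decoration-time settings can be overridden per
# call with "remote_"-prefixed keyword arguments, e.g. remote_job_name below.
result = train_model(n_estimators=200, remote_job_name="train-model-v2")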