├── .dockerignore ├── .github └── workflows │ ├── cla.yml │ ├── docker.yml │ ├── integration.yaml │ ├── release.yml │ └── test.yml ├── .gitignore ├── .vscode └── settings.json ├── Cargo.lock ├── Cargo.toml ├── Justfile ├── LICENSE ├── README.md ├── examples ├── Dockerfile ├── concurrent_greeter.py ├── example.py ├── greeter.py ├── hypercorn-config.toml ├── pydantic_greeter.py ├── requirements.txt ├── virtual_object.py └── workflow.py ├── pyproject.toml ├── python └── restate │ ├── __init__.py │ ├── asyncio.py │ ├── aws_lambda.py │ ├── context.py │ ├── discovery.py │ ├── endpoint.py │ ├── exceptions.py │ ├── handler.py │ ├── harness.py │ ├── object.py │ ├── py.typed │ ├── serde.py │ ├── server.py │ ├── server_context.py │ ├── server_types.py │ ├── service.py │ ├── vm.py │ └── workflow.py ├── requirements.txt ├── rust-toolchain.toml ├── shell.nix ├── src └── lib.rs ├── test-services ├── .env ├── Dockerfile ├── README.md ├── entrypoint.sh ├── exclusions.yaml ├── hypercorn-config.toml ├── requirements.txt ├── services │ ├── __init__.py │ ├── awakeable_holder.py │ ├── block_and_wait_workflow.py │ ├── cancel_test.py │ ├── counter.py │ ├── failing.py │ ├── interpreter.py │ ├── kill_test.py │ ├── list_object.py │ ├── map_object.py │ ├── non_determinism.py │ ├── proxy.py │ ├── test_utils.py │ └── virtual_object_command_interpreter.py └── testservices.py └── tests └── serde.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # Git 2 | .git 3 | .gitignore 4 | .gitattributes 5 | 6 | # Byte-compiled / optimized / DLL files 7 | **/__pycache__/ 8 | **/*.py[cod] 9 | 10 | # Virtual environment 11 | **/.env 12 | **/.venv/ 13 | **/venv/ 14 | 15 | **/test_report/ 16 | 17 | target 18 | 19 | -------------------------------------------------------------------------------- /.github/workflows/cla.yml: -------------------------------------------------------------------------------- 1 | name: "CLA Assistant" 2 | on: 3 | issue_comment: 4 | 
types: [created] 5 | pull_request_target: 6 | types: [opened, closed, synchronize] 7 | 8 | jobs: 9 | CLAAssistant: 10 | uses: restatedev/restate/.github/workflows/cla.yml@main 11 | secrets: inherit 12 | -------------------------------------------------------------------------------- /.github/workflows/docker.yml: -------------------------------------------------------------------------------- 1 | name: Build docker image 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | tags: 7 | - v** 8 | 9 | env: 10 | REPOSITORY_OWNER: ${{ github.repository_owner }} 11 | GHCR_REGISTRY: "ghcr.io" 12 | GHCR_REGISTRY_USERNAME: ${{ github.actor }} 13 | GHCR_REGISTRY_TOKEN: ${{ secrets.GITHUB_TOKEN }} 14 | 15 | jobs: 16 | build-python-services-docker-image: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Checkout repository 20 | uses: actions/checkout@v4 21 | 22 | - name: Set up QEMU dependency 23 | uses: docker/setup-qemu-action@v2 24 | 25 | - name: Set up Docker Buildx 26 | uses: docker/setup-buildx-action@v2 27 | 28 | - name: Log into GitHub container registry 29 | uses: docker/login-action@v2 30 | with: 31 | registry: ${{ env.GHCR_REGISTRY }} 32 | username: ${{ env.GHCR_REGISTRY_USERNAME }} 33 | password: ${{ env.GHCR_REGISTRY_TOKEN }} 34 | 35 | - name: Extract metadata (tags, labels) 36 | id: meta 37 | uses: docker/metadata-action@v5 38 | with: 39 | images: | 40 | ghcr.io/restatedev/test-services-python 41 | tags: | 42 | type=ref,event=branch 43 | type=semver,pattern={{version}} 44 | type=semver,pattern={{major}}.{{minor}} 45 | 46 | - name: Build docker image 47 | uses: docker/build-push-action@v3 48 | with: 49 | context: .
50 | file: test-services/Dockerfile 51 | push: true 52 | platforms: linux/arm64,linux/amd64 53 | tags: ${{ steps.meta.outputs.tags }} 54 | labels: ${{ steps.meta.outputs.labels }} 55 | -------------------------------------------------------------------------------- /.github/workflows/integration.yaml: -------------------------------------------------------------------------------- 1 | name: Integration 2 | 3 | # Controls when the workflow will run 4 | on: 5 | pull_request: 6 | push: 7 | branches: 8 | - main 9 | schedule: 10 | - cron: "0 */6 * * *" # Every 6 hours 11 | workflow_dispatch: 12 | inputs: 13 | restateCommit: 14 | description: "restate commit" 15 | required: false 16 | default: "" 17 | type: string 18 | restateImage: 19 | description: "restate image, superseded by restate commit" 20 | required: false 21 | default: "ghcr.io/restatedev/restate:main" 22 | type: string 23 | workflow_call: 24 | inputs: 25 | restateCommit: 26 | description: "restate commit" 27 | required: false 28 | default: "" 29 | type: string 30 | restateImage: 31 | description: "restate image, superseded by restate commit" 32 | required: false 33 | default: "ghcr.io/restatedev/restate:main" 34 | type: string 35 | 36 | jobs: 37 | sdk-test-suite: 38 | if: github.repository_owner == 'restatedev' 39 | runs-on: ubuntu-latest 40 | name: Features integration test 41 | permissions: 42 | contents: read 43 | issues: read 44 | checks: write 45 | pull-requests: write 46 | actions: read 47 | 48 | steps: 49 | - uses: actions/checkout@v4 50 | with: 51 | repository: restatedev/sdk-python 52 | 53 | - name: Set up Docker containerd snapshotter 54 | uses: crazy-max/ghaction-setup-docker@v3 55 | with: 56 | set-host: true 57 | daemon-config: | 58 | { 59 | "features": { 60 | "containerd-snapshotter": true 61 | } 62 | } 63 | 64 | ### Download the Restate container image, if needed 65 | # Setup restate snapshot if necessary 66 | # Due to https://github.com/actions/upload-artifact/issues/53 67 | # We must use 
download-artifact to get artifacts created during *this* workflow run, ie by workflow call 68 | - name: Download restate snapshot from in-progress workflow 69 | if: ${{ inputs.restateCommit != '' && github.event_name != 'workflow_dispatch' }} 70 | uses: actions/download-artifact@v4 71 | with: 72 | name: restate.tar 73 | # In the workflow dispatch case where the artifact was created in a previous run, we can download as normal 74 | - name: Download restate snapshot from completed workflow 75 | if: ${{ inputs.restateCommit != '' && github.event_name == 'workflow_dispatch' }} 76 | uses: dawidd6/action-download-artifact@v3 77 | with: 78 | repo: restatedev/restate 79 | workflow: ci.yml 80 | commit: ${{ inputs.restateCommit }} 81 | name: restate.tar 82 | - name: Install restate snapshot 83 | if: ${{ inputs.restateCommit != '' }} 84 | run: | 85 | output=$(docker load --input restate.tar | head -n 1) 86 | docker tag "${output#*: }" "localhost/restatedev/restate-commit-download:latest" 87 | docker image ls -a 88 | 89 | - name: Set up QEMU 90 | uses: docker/setup-qemu-action@v3 91 | - name: Set up Docker Buildx 92 | uses: docker/setup-buildx-action@v3 93 | - name: Build Python test-services image 94 | id: build 95 | uses: docker/build-push-action@v6 96 | with: 97 | context: . 
98 | file: "test-services/Dockerfile" 99 | push: false 100 | load: true 101 | tags: restatedev/test-services-python 102 | cache-from: type=gha,scope=${{ github.workflow }} 103 | cache-to: type=gha,mode=max,scope=${{ github.workflow }} 104 | 105 | - name: Run test tool 106 | uses: restatedev/sdk-test-suite@v3.0 107 | with: 108 | restateContainerImage: ${{ inputs.restateCommit != '' && 'localhost/restatedev/restate-commit-download:latest' || (inputs.restateImage != '' && inputs.restateImage || 'ghcr.io/restatedev/restate:main') }} 109 | serviceContainerImage: "restatedev/test-services-python" 110 | exclusionsFile: "test-services/exclusions.yaml" 111 | testArtifactOutput: "sdk-python-integration-test-report" 112 | serviceContainerEnvFile: "test-services/.env" 113 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - v** 7 | workflow_dispatch: 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | version-check: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Install dasel 18 | run: | 19 | wget -q https://github.com/TomWright/dasel/releases/download/v2.8.1/dasel_linux_amd64 && \ 20 | mv dasel_linux_amd64 dasel && \ 21 | chmod +x dasel && \ 22 | ./dasel --version 23 | - name: Verify version 24 | run: | 25 | # write yaml to remove quotes 26 | version=$(./dasel --file Cargo.toml --read toml 'package.version' --write yaml) 27 | tag_version=${{ github.ref_name }} 28 | tag_version=${tag_version#v} 29 | if [ ${tag_version} != ${version} ]; then 30 | echo "::error file=release.yml,line=28::Cargo.toml version '${version}' is not equal to tag version '${tag_version}'. Please align them." 
31 | exit 1; 32 | fi 33 | 34 | linux: 35 | runs-on: ${{ matrix.platform.runner }} 36 | needs: [version-check] 37 | strategy: 38 | matrix: 39 | platform: 40 | - runner: ubuntu-latest 41 | target: x86_64 42 | - runner: ubuntu-latest 43 | target: aarch64 44 | steps: 45 | - uses: actions/checkout@v4 46 | - uses: actions/setup-python@v5 47 | with: 48 | python-version: 3.x 49 | - name: Build wheels 50 | uses: PyO3/maturin-action@v1 51 | with: 52 | target: ${{ matrix.platform.target }} 53 | args: --release --out dist --find-interpreter 54 | sccache: 'false' 55 | # See https://github.com/PyO3/maturin-action/issues/222 56 | manylinux: ${{ matrix.platform.target == 'aarch64' && '2_28' || 'auto' }} 57 | - name: Upload wheels 58 | uses: actions/upload-artifact@v4 59 | with: 60 | name: wheels-linux-${{ matrix.platform.target }} 61 | path: dist 62 | 63 | musllinux: 64 | runs-on: ${{ matrix.platform.runner }} 65 | needs: [version-check] 66 | strategy: 67 | matrix: 68 | platform: 69 | - runner: ubuntu-latest 70 | target: x86_64 71 | - runner: ubuntu-latest 72 | target: aarch64 73 | steps: 74 | - uses: actions/checkout@v4 75 | - uses: actions/setup-python@v5 76 | with: 77 | python-version: 3.x 78 | - name: Build wheels 79 | uses: PyO3/maturin-action@v1 80 | with: 81 | target: ${{ matrix.platform.target }} 82 | args: --release --out dist --find-interpreter 83 | sccache: 'false' 84 | manylinux: musllinux_1_2 85 | - name: Upload wheels 86 | uses: actions/upload-artifact@v4 87 | with: 88 | name: wheels-musllinux-${{ matrix.platform.target }} 89 | path: dist 90 | 91 | macos: 92 | runs-on: ${{ matrix.platform.runner }} 93 | needs: [version-check] 94 | strategy: 95 | matrix: 96 | platform: 97 | - runner: macos-13 98 | target: x86_64 99 | - runner: macos-14 100 | target: aarch64 101 | steps: 102 | - uses: actions/checkout@v4 103 | - uses: actions/setup-python@v5 104 | with: 105 | python-version: 3.x 106 | - name: Build wheels 107 | uses: PyO3/maturin-action@v1 108 | with: 109 | target: ${{ 
matrix.platform.target }} 110 | args: --release --out dist --find-interpreter 111 | sccache: 'true' 112 | - name: Upload wheels 113 | uses: actions/upload-artifact@v4 114 | with: 115 | name: wheels-macos-${{ matrix.platform.target }} 116 | path: dist 117 | 118 | sdist: 119 | runs-on: ubuntu-latest 120 | needs: [version-check] 121 | steps: 122 | - uses: actions/checkout@v4 123 | - name: Build sdist 124 | uses: PyO3/maturin-action@v1 125 | with: 126 | command: sdist 127 | args: --out dist 128 | - name: Upload sdist 129 | uses: actions/upload-artifact@v4 130 | with: 131 | name: wheels-sdist 132 | path: dist 133 | 134 | release: 135 | name: Release 136 | runs-on: ubuntu-latest 137 | needs: [linux, musllinux, macos, sdist] 138 | permissions: 139 | # Use to sign the release artifacts 140 | id-token: write 141 | # Used to upload release artifacts 142 | contents: write 143 | # Used to generate artifact attestation 144 | attestations: write 145 | steps: 146 | - uses: actions/download-artifact@v4 147 | - name: Generate artifact attestation 148 | uses: actions/attest-build-provenance@v2 149 | with: 150 | subject-path: 'wheels-*/*' 151 | - name: Publish to PyPI 152 | uses: PyO3/maturin-action@v1 153 | env: 154 | MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} 155 | with: 156 | command: upload 157 | args: --non-interactive --skip-existing wheels-*/* 158 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | pull_request: 5 | workflow_dispatch: 6 | push: 7 | branches: 8 | - main 9 | 10 | permissions: 11 | contents: read 12 | checks: write 13 | pull-requests: write 14 | 15 | jobs: 16 | lint-and-test: 17 | name: "Lint and Test (Python ${{ matrix.python }})" 18 | runs-on: ubuntu-latest 19 | strategy: 20 | matrix: 21 | python: [ "3.11", "3.12" ] 22 | steps: 23 | - uses: actions/checkout@v4 24 | - uses: 
extractions/setup-just@v2 25 | - uses: actions/setup-python@v5 26 | with: 27 | python-version: ${{ matrix.python }} 28 | - name: Build Rust module 29 | uses: PyO3/maturin-action@v1 30 | with: 31 | args: --out dist --interpreter python${{ matrix.python }} 32 | sccache: 'true' 33 | container: off 34 | - run: pip install -r requirements.txt 35 | - run: pip install dist/* 36 | - name: Verify 37 | run: just verify 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | # Rust 165 | debug/ 166 | 167 | # IntelliJ 168 | .idea 169 | 170 | # Restate data 171 | restate-data 172 | 173 | # Test reports 174 | test-services/test_report/ 175 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "cSpell.words": [ 3 | "endpointmanifest" 4 | ] 5 | } -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "restate-sdk-python-core" 3 | version = "0.7.3" 4 | edition = "2021" 5 | 6 | [package.metadata.maturin] 7 | name = "restate_sdk._internal" 8 | 9 | [lib] 10 | name = "restate_sdk_python_core" 11 | crate-type = ["cdylib"] 12 | doc = false 13 | 14 | [dependencies] 15 | pyo3 = { version = "0.24.1", features = ["extension-module"] } 16 | tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } 17 | restate-sdk-shared-core = { version = "0.3.0", features = ["request_identity", "sha2_random_seed"] } 18 | -------------------------------------------------------------------------------- /Justfile: -------------------------------------------------------------------------------- 1 | # Justfile 2 | 3 | python := "python3" 4 | 5 | default: 6 | @echo "Available recipes:" 7 | @echo " mypy - Run mypy for type checking" 8 | @echo " pylint - Run pylint for linting" 9 | @echo " test - Run pytest for testing" 10 | @echo " verify - Run mypy, pylint, test" 11 | 12 | # Recipe to run mypy for type checking 13 | mypy: 14 | @echo "Running mypy..." 15 | {{python}} -m mypy --check-untyped-defs --ignore-missing-imports python/restate/ 16 | {{python}} -m mypy --check-untyped-defs --ignore-missing-imports examples/ 17 | 18 | # Recipe to run pylint for linting 19 | pylint: 20 | @echo "Running pylint..."
21 | {{python}} -m pylint python/restate --ignore-paths='^.*\.?venv.*$' 22 | {{python}} -m pylint examples/ --ignore-paths='^.*\.?venv.*$' 23 | 24 | test: 25 | @echo "Running Python tests..." 26 | {{python}} -m pytest tests/* 27 | 28 | # Recipe to run mypy, pylint and the tests 29 | verify: mypy pylint test 30 | @echo "Type checking and linting completed successfully." 31 | 32 | # Recipe to build the project 33 | build: 34 | @echo "Building the project..." 35 | maturin build --release 36 | 37 | clean: 38 | @echo "Cleaning the project" 39 | cargo clean 40 | 41 | example: 42 | #!/usr/bin/env bash 43 | cd examples/ 44 | if [ -z "$PYTHONPATH" ]; then 45 | export PYTHONPATH="$PWD" 46 | else 47 | export PYTHONPATH="$PYTHONPATH:$PWD" 48 | fi 49 | hypercorn --config hypercorn-config.toml example:app 50 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Restate 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Documentation](https://img.shields.io/badge/doc-reference-blue)](https://docs.restate.dev) 2 | [![Examples](https://img.shields.io/badge/view-examples-blue)](https://github.com/restatedev/examples) 3 | [![Discord](https://img.shields.io/discord/1128210118216007792?logo=discord)](https://discord.gg/skW3AZ6uGd) 4 | [![Twitter](https://img.shields.io/twitter/follow/restatedev.svg?style=social&label=Follow)](https://twitter.com/intent/follow?screen_name=restatedev) 5 | 6 | # Python SDK for Restate 7 | 8 | [Restate](https://restate.dev/) is a system for easily building resilient applications using *distributed durable async/await*. This repository contains the Restate SDK for writing services in **Python**. 9 | 10 | ## Community 11 | 12 | * 🤗️ [Join our online community](https://discord.gg/skW3AZ6uGd) for help, sharing feedback and talking to the community. 13 | * 📖 [Check out our documentation](https://docs.restate.dev) to get started quickly! 14 | * 📣 [Follow us on Twitter](https://twitter.com/restatedev) for staying up to date. 15 | * 🙋 [Create a GitHub issue](https://github.com/restatedev/sdk-python/issues) for requesting a new feature or reporting a problem. 16 | * 🏠 [Visit our GitHub org](https://github.com/restatedev) for exploring other repositories.
17 | 18 | ## Using the SDK 19 | 20 | **Prerequisites**: 21 | - Python >= v3.11 22 | 23 | To use this SDK, add the dependency to your project: 24 | 25 | ```shell 26 | pip install restate_sdk 27 | ``` 28 | 29 | ## Versions 30 | 31 | The Python SDK is currently in active development, and might break across releases. 32 | 33 | The compatibility with Restate is described in the following table: 34 | 35 | | Restate Server\sdk-python | 0.0 - 0.2 | 0.3 - 0.5 | 0.6 - 0.7 | 36 | |---------------------------|-----------|-----------|-----------| 37 | | 1.0 | ✅ | ❌ | ❌ | 38 | | 1.1 | ✅ | ✅ | ❌ | 39 | | 1.2 | ✅ | ✅ | ❌ | 40 | | 1.3 | ✅ | ✅ | ✅ | 41 | 42 | ## Contributing 43 | 44 | We’re excited if you join the Restate community and start contributing! 45 | Whether it is feature requests, bug reports, ideas & feedback or PRs, we appreciate any and all contributions. 46 | We know that your time is precious and, therefore, deeply value any effort to contribute! 47 | 48 | ### Local development 49 | 50 | * Python >= 3.11 51 | * PyEnv or VirtualEnv 52 | * [just](https://github.com/casey/just) 53 | * [Rust toolchain](https://rustup.rs/) 54 | 55 | Set up your virtual environment using the tool of your choice, e.g. VirtualEnv: 56 | 57 | ```shell 58 | python3 -m venv .venv 59 | source .venv/bin/activate 60 | ``` 61 | 62 | Install the build tools: 63 | 64 | ```shell 65 | pip install -r requirements.txt 66 | ``` 67 | 68 | Now build the Rust module and include opt-in additional dev dependencies: 69 | 70 | ```shell 71 | maturin dev -E test,lint 72 | ``` 73 | 74 | You usually need to build the Rust module only once, but you might need to rebuild it after pulling new changes. 75 | 76 | For linting and testing: 77 | 78 | ```shell 79 | just verify 80 | ``` 81 | 82 | ## Releasing the package 83 | 84 | Pull the latest main: 85 | 86 | ```shell 87 | git checkout main && git pull 88 | ``` 89 | 90 | **Update the module version in `Cargo.toml` and run a local build to update the `Cargo.lock` too**, and commit the changes.
Then push tag, e.g.: 91 | 92 | ``` 93 | git tag -m "Release v0.1.0" v0.1.0 94 | git push origin v0.1.0 95 | ``` 96 | -------------------------------------------------------------------------------- /examples/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim 2 | 3 | WORKDIR /usr/src/app 4 | 5 | COPY requirements.txt ./ 6 | RUN pip install --no-cache-dir -r requirements.txt 7 | 8 | COPY . . 9 | 10 | EXPOSE 9080 11 | 12 | ENV PYTHONPATH="/usr/src/app/src" 13 | CMD ["hypercorn", "example:app", "--config", "hypercorn-config.toml"] 14 | -------------------------------------------------------------------------------- /examples/concurrent_greeter.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """greeter.py""" 12 | # pylint: disable=C0116 13 | # pylint: disable=W0613 14 | # pylint: disable=C0115 15 | # pylint: disable=R0903 16 | # pylint: disable=C0301 17 | 18 | import typing 19 | 20 | from pydantic import BaseModel 21 | from restate import wait_completed, Service, Context 22 | 23 | # models 24 | class GreetingRequest(BaseModel): 25 | name: str 26 | 27 | class Greeting(BaseModel): 28 | messages: typing.List[str] 29 | 30 | class Message(BaseModel): 31 | role: str 32 | content: str 33 | 34 | concurrent_greeter = Service("concurrent_greeter") 35 | 36 | @concurrent_greeter.handler() 37 | async def greet(ctx: Context, req: GreetingRequest) -> Greeting: 38 | claude = ctx.service_call(claude_sonnet, arg=Message(role="user", content=f"please greet {req.name}")) 39 | openai = 
ctx.service_call(open_ai, arg=Message(role="user", content=f"please greet {req.name}")) 40 | 41 | pending, done = await wait_completed(claude, openai) 42 | 43 | # collect the completed greetings 44 | greetings = [await f for f in done] 45 | 46 | # cancel the pending calls 47 | for f in pending: 48 | await f.cancel_invocation() # type: ignore 49 | 50 | return Greeting(messages=greetings) 51 | 52 | 53 | # not really interesting, just for this demo: 54 | 55 | @concurrent_greeter.handler() 56 | async def claude_sonnet(ctx: Context, req: Message) -> str: 57 | return f"Bonjour {req.content[13:]}!" 58 | 59 | @concurrent_greeter.handler() 60 | async def open_ai(ctx: Context, req: Message) -> str: 61 | return f"Hello {req.content[13:]}!" 62 | -------------------------------------------------------------------------------- /examples/example.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 
6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """example.py""" 12 | # pylint: disable=C0116 13 | # pylint: disable=W0613 14 | 15 | import restate 16 | 17 | from greeter import greeter 18 | from virtual_object import counter 19 | from workflow import payment 20 | from pydantic_greeter import pydantic_greeter 21 | from concurrent_greeter import concurrent_greeter 22 | 23 | app = restate.app(services=[greeter, 24 | counter, 25 | payment, 26 | pydantic_greeter, 27 | concurrent_greeter]) 28 | 29 | if __name__ == "__main__": 30 | import hypercorn 31 | import hypercorn.asyncio 32 | import asyncio 33 | conf = hypercorn.Config() 34 | conf.bind = ["0.0.0.0:9080"] 35 | asyncio.run(hypercorn.asyncio.serve(app, conf)) 36 | -------------------------------------------------------------------------------- /examples/greeter.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """example.py""" 12 | # pylint: disable=C0116 13 | # pylint: disable=W0613 14 | 15 | from restate import Service, Context 16 | 17 | greeter = Service("greeter") 18 | 19 | @greeter.handler() 20 | async def greet(ctx: Context, name: str) -> str: 21 | return f"Hello {name}!" 
22 | -------------------------------------------------------------------------------- /examples/hypercorn-config.toml: -------------------------------------------------------------------------------- 1 | bind = "0.0.0.0:9080" 2 | h2_max_concurrent_streams = 2147483647 3 | keep_alive_max_requests = 2147483647 4 | keep_alive_timeout = 2147483647 5 | workers = 8 6 | 7 | -------------------------------------------------------------------------------- /examples/pydantic_greeter.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """greeter.py""" 12 | # pylint: disable=C0116 13 | # pylint: disable=W0613 14 | # pylint: disable=C0115 15 | # pylint: disable=R0903 16 | 17 | from pydantic import BaseModel 18 | from restate import Service, Context 19 | 20 | # models 21 | class GreetingRequest(BaseModel): 22 | name: str 23 | 24 | class Greeting(BaseModel): 25 | message: str 26 | 27 | # service 28 | 29 | pydantic_greeter = Service("pydantic_greeter") 30 | 31 | @pydantic_greeter.handler() 32 | async def greet(ctx: Context, req: GreetingRequest) -> Greeting: 33 | return Greeting(message=f"Hello {req.name}!") 34 | -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | hypercorn 2 | restate_sdk 3 | pydantic 4 | dacite 5 | -------------------------------------------------------------------------------- /examples/virtual_object.py: -------------------------------------------------------------------------------- 1 | # 2 | # 
Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """example.py""" 12 | # pylint: disable=C0116 13 | # pylint: disable=W0613 14 | 15 | from restate import VirtualObject, ObjectContext, ObjectSharedContext 16 | 17 | counter = VirtualObject("counter") 18 | 19 | @counter.handler() 20 | async def increment(ctx: ObjectContext, value: int) -> int: 21 | n = await ctx.get("counter", type_hint=int) or 0 22 | n += value 23 | ctx.set("counter", n) 24 | return n 25 | 26 | @counter.handler(kind="shared") 27 | async def count(ctx: ObjectSharedContext) -> int: 28 | return await ctx.get("counter") or 0 29 | -------------------------------------------------------------------------------- /examples/workflow.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 
6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """example.py""" 12 | # pylint: disable=C0116 13 | # pylint: disable=W0613 14 | # pylint: disable=C0301 15 | 16 | from datetime import timedelta 17 | 18 | from restate import Workflow, WorkflowContext, WorkflowSharedContext 19 | from restate import select 20 | from restate import TerminalError 21 | 22 | TIMEOUT = timedelta(seconds=10) 23 | 24 | payment = Workflow("payment") 25 | 26 | @payment.main() 27 | async def pay(ctx: WorkflowContext, amount: int): 28 | workflow_key = ctx.key() 29 | ctx.set("status", "verifying payment") 30 | 31 | # Call the payment service 32 | def payment_gateway(): 33 | print("Please approve this payment: ") 34 | print("To approve use:") 35 | print(f"""curl http://localhost:8080/payment/{workflow_key}/payment_verified --json '"approved"' """) 36 | print("") 37 | print("To decline use:") 38 | print(f"""curl http://localhost:8080/payment/{workflow_key}/payment_verified --json '"declined"' """) 39 | 40 | await ctx.run("payment", payment_gateway) 41 | 42 | ctx.set("status", "waiting for the payment provider to approve") 43 | 44 | # Wait for the payment to be verified 45 | 46 | match await select(result=ctx.promise("verify.payment").value(), timeout=ctx.sleep(TIMEOUT)): 47 | case ['result', "approved"]: 48 | ctx.set("status", "payment approved") 49 | return { "success" : True } 50 | case ['result', "declined"]: 51 | ctx.set("status", "payment declined") 52 | raise TerminalError(message="Payment declined", status_code=401) 53 | case ['timeout', _]: 54 | ctx.set("status", "payment verification timed out") 55 | raise TerminalError(message="Payment verification timed out", status_code=410) 56 | 57 | @payment.handler() 58 | async def payment_verified(ctx: WorkflowSharedContext, result: str): 59 | promise = ctx.promise("verify.payment", 
type_hint=str) 60 | await promise.resolve(result) 61 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "restate-sdk" 3 | description = "A Python SDK for Restate" 4 | requires-python = ">=3.11" 5 | classifiers = [ 6 | "Programming Language :: Rust", 7 | "Programming Language :: Python :: Implementation :: CPython", 8 | "Topic :: Software Development :: Libraries :: Application Frameworks" 9 | ] 10 | dynamic = ["version"] 11 | license = { file = "LICENSE" } 12 | authors = [ 13 | { name = "Restate Developers", email = "dev@restate.dev" } 14 | ] 15 | 16 | [project.optional-dependencies] 17 | test = ["pytest", "hypercorn"] 18 | lint = ["mypy", "pylint"] 19 | harness = ["testcontainers", "hypercorn", "httpx"] 20 | serde = ["dacite", "pydantic"] 21 | 22 | [build-system] 23 | requires = ["maturin>=1.6,<2.0"] 24 | build-backend = "maturin" 25 | 26 | [tool.maturin] 27 | features = ["pyo3/extension-module"] 28 | module-name = "restate._internal" 29 | python-source = "python" 30 | -------------------------------------------------------------------------------- /python/restate/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 
6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """ 12 | Restate SDK for Python: re-exports the public API (service types, contexts, durable futures, future combinators, app and test_harness). 13 | """ 14 | 15 | from .service import Service 16 | from .object import VirtualObject 17 | from .workflow import Workflow 18 | 19 | # types 20 | from .context import Context, ObjectContext, ObjectSharedContext 21 | from .context import WorkflowContext, WorkflowSharedContext 22 | # pylint: disable=line-too-long 23 | from .context import DurablePromise, RestateDurableFuture, RestateDurableCallFuture, RestateDurableSleepFuture, SendHandle 24 | from .exceptions import TerminalError 25 | from .asyncio import as_completed, gather, wait_completed, select 26 | 27 | from .endpoint import app 28 | 29 | try: 30 | from .harness import test_harness # type: ignore 31 | except ImportError: 32 | # the optional harness dependencies are not installed (pip install restate-sdk[harness]) 33 | 34 | # pylint: disable=unused-argument, redefined-outer-name 35 | def test_harness(app, follow_logs = False, restate_image = ""): # type: ignore 36 | """Stand-in for test_harness that raises ImportError explaining how to install the harness extra.""" 37 | raise ImportError("Install restate-sdk[harness] to use this feature") 38 | 39 | __all__ = [ 40 | "Service", 41 | "VirtualObject", 42 | "Workflow", 43 | "Context", 44 | "ObjectContext", 45 | "ObjectSharedContext", 46 | "WorkflowContext", 47 | "WorkflowSharedContext", 48 | "DurablePromise", 49 | "RestateDurableFuture", 50 | "RestateDurableCallFuture", 51 | "RestateDurableSleepFuture", 52 | "SendHandle", 53 | "TerminalError", 54 | "app", 55 | "test_harness", 56 | "gather", 57 | "as_completed", 58 | "wait_completed", 59 | "select" 60 | ] 61 | -------------------------------------------------------------------------------- /python/restate/asyncio.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., 
Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | # pylint: disable=R0913,C0301,R0917 12 | # pylint: disable=line-too-long 13 | """combines multiple futures into a single future""" 14 | 15 | from typing import Any, List, Tuple 16 | from restate.exceptions import TerminalError 17 | from restate.context import RestateDurableFuture 18 | from restate.server_context import ServerDurableFuture, ServerInvocationContext 19 | 20 | async def gather(*futures: RestateDurableFuture[Any]) -> List[RestateDurableFuture[Any]]: 21 | """ 22 | Blocks until all futures are completed. 23 | 24 | Returns a list of all futures. 25 | """ 26 | async for _ in as_completed(*futures): 27 | pass 28 | return list(futures) 29 | 30 | async def select(**kws: RestateDurableFuture[Any]) -> List[Any]: 31 | """ 32 | Blocks until one of the futures is completed. 
33 | 34 | Example: 35 | 36 | who, what = await select(car=f1, hotel=f2, flight=f3) 37 | if who == "car": 38 | print(what) 39 | elif who == "hotel": 40 | print(what) 41 | elif who == "flight": 42 | print(what) 43 | 44 | works the best with matching: 45 | 46 | match await select(result=ctx.promise("verify.payment"), timeout=ctx.sleep(timedelta(seconds=10))): 47 | case ['result', "approved"]: 48 | return { "success" : True } 49 | case ['result', "declined"]: 50 | raise TerminalError(message="Payment declined", status_code=401) 51 | case ['timeout', _]: 52 | raise TerminalError(message="Payment verification timed out", status_code=410) 53 | 54 | """ 55 | if not kws: 56 | raise ValueError("At least one future must be passed.") 57 | reverse = { f: key for key, f in kws.items() } 58 | async for f in as_completed(*kws.values()): 59 | return [reverse[f], await f] 60 | assert False, "unreachable" 61 | 62 | async def as_completed(*futures: RestateDurableFuture[Any]): 63 | """ 64 | Returns an iterator that yields the futures as they are completed. 65 | 66 | example: 67 | 68 | async for future in as_completed(f1, f2, f3): 69 | # do something with the completed future 70 | print(await future) # prints the result of the future 71 | 72 | """ 73 | remaining = list(futures) 74 | while remaining: 75 | completed, waiting = await wait_completed(*remaining) 76 | for f in completed: 77 | yield f 78 | remaining = waiting 79 | 80 | async def wait_completed(*args: RestateDurableFuture[Any]) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]: 81 | """ 82 | Blocks until at least one of the futures is completed. 83 | 84 | Returns a tuple of two lists: the first list contains the futures that are completed, 85 | the second list contains the futures that are not completed. 
86 | """ 87 | handles: List[int] = [] 88 | context: ServerInvocationContext | None = None 89 | completed = [] 90 | uncompleted = [] 91 | futures = list(args) 92 | 93 | if not futures: 94 | return [], [] 95 | for f in futures: 96 | if not isinstance(f, ServerDurableFuture): 97 | raise TerminalError("All futures must SDK created futures.") 98 | if context is None: 99 | context = f.context 100 | elif context is not f.context: 101 | raise TerminalError("All futures must be created by the same SDK context.") 102 | if f.is_completed(): 103 | completed.append(f) 104 | else: 105 | handles.append(f.handle) 106 | uncompleted.append(f) 107 | 108 | if completed: 109 | # the user had passed some completed futures, so we can return them immediately 110 | return completed, uncompleted # type: ignore 111 | 112 | completed = [] 113 | uncompleted = [] 114 | assert context is not None 115 | await context.create_poll_or_cancel_coroutine(handles) 116 | 117 | for index, handle in enumerate(handles): 118 | future = futures[index] 119 | if context.vm.is_completed(handle): 120 | completed.append(future) # type: ignore 121 | else: 122 | uncompleted.append(future) # type: ignore 123 | return completed, uncompleted # type: ignore 124 | -------------------------------------------------------------------------------- /python/restate/aws_lambda.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """ 12 | This module contains the Lambda/ASGI adapter. 
13 | """ 14 | import asyncio 15 | import base64 16 | import os 17 | from typing import TypedDict, Dict, cast, Union, Any, Callable 18 | 19 | from restate.server_types import (ASGIApp, 20 | Scope, 21 | Receive, 22 | HTTPResponseStartEvent, 23 | HTTPResponseBodyEvent, 24 | HTTPRequestEvent) 25 | 26 | class RestateLambdaRequest(TypedDict): 27 | """ 28 | Restate Lambda request 29 | 30 | :see: https://github.com/restatedev/restate/blob/1a10c05b16b387191060b49faffb0335ee97e96d/crates/service-client/src/lambda.rs#L297 # pylint: disable=line-too-long 31 | """ 32 | path: str 33 | httpMethod: str 34 | headers: Dict[str, str] 35 | body: str 36 | isBase64Encoded: bool 37 | 38 | 39 | class RestateLambdaResponse(TypedDict): 40 | """ 41 | Restate Lambda response 42 | 43 | :see: https://github.com/restatedev/restate/blob/1a10c05b16b387191060b49faffb0335ee97e96d/crates/service-client/src/lambda.rs#L310 # pylint: disable=line-too-long 44 | """ 45 | statusCode: int 46 | headers: Dict[str, str] 47 | body: str 48 | isBase64Encoded: bool 49 | 50 | 51 | RestateLambdaHandler = Callable[[RestateLambdaRequest, Any], RestateLambdaResponse] 52 | 53 | 54 | def create_scope(req: RestateLambdaRequest) -> Scope: 55 | """ 56 | Create ASGI scope from lambda request 57 | """ 58 | headers = {k.lower(): v for k, v in req.get('headers', {}).items()} 59 | http_method = req["httpMethod"] 60 | path = req["path"] 61 | 62 | return { 63 | "type": "http", 64 | "method": http_method, 65 | "http_version": "1.1", 66 | "headers": [(k.encode(), v.encode()) for k, v in headers.items()], 67 | "path": path, 68 | "scheme": headers.get("x-forwarded-proto", "https"), 69 | "asgi": {"version": "3.0", "spec_version": "2.0"}, 70 | "raw_path": path.encode(), 71 | "root_path": "", 72 | "query_string": b'', 73 | "client": None, 74 | "server": None, 75 | "extensions": None 76 | } 77 | 78 | 79 | def request_to_receive(req: RestateLambdaRequest) -> Receive: 80 | """ 81 | Create ASGI Receive from lambda request 82 | """ 83 | 
assert req['isBase64Encoded'] 84 | body = base64.b64decode(req['body']) 85 | 86 | events = cast(list[HTTPRequestEvent], [{ 87 | "type": "http.request", 88 | "body": body, 89 | "more_body": False 90 | }, 91 | { 92 | "type": "http.request", 93 | "body": b'', 94 | "more_body": False 95 | }]) 96 | 97 | async def recv() -> HTTPRequestEvent: 98 | if len(events) != 0: 99 | return events.pop(0) 100 | # If we are out of events, return a future that will never complete 101 | f = asyncio.Future[HTTPRequestEvent]() 102 | return await f 103 | 104 | return recv 105 | 106 | 107 | class ResponseCollector: 108 | """ 109 | Response collector from ASGI Send to Lambda 110 | """ 111 | def __init__(self): 112 | self.body = bytearray() 113 | self.headers = {} 114 | self.status_code = 500 115 | 116 | async def __call__(self, message: Union[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> None: 117 | """ 118 | Implements ASGI send contract 119 | """ 120 | if message["type"] == "http.response.start": 121 | self.status_code = cast(int, message["status"]) 122 | self.headers = { 123 | key.decode("utf-8"): value.decode("utf-8") 124 | for key, value in message["headers"] 125 | } 126 | elif message["type"] == "http.response.body" and "body" in message: 127 | self.body.extend(message["body"]) 128 | return 129 | 130 | def to_lambda_response(self) -> RestateLambdaResponse: 131 | """ 132 | Convert collected values to lambda response 133 | """ 134 | return { 135 | "statusCode": self.status_code, 136 | "headers": self.headers, 137 | "isBase64Encoded": True, 138 | "body": base64.b64encode(self.body).decode() 139 | } 140 | 141 | 142 | def is_running_on_lambda() -> bool: 143 | """ 144 | :return: true if this Python script is running on Lambda 145 | """ 146 | # https://docs.aws.amazon.com/lambda/latest/dg/configuration-envvars.html 147 | return "AWS_LAMBDA_FUNCTION_NAME" in os.environ 148 | 149 | 150 | def wrap_asgi_as_lambda_handler(asgi_app: ASGIApp) \ 151 | -> Callable[[RestateLambdaRequest, Any], 
RestateLambdaResponse]: 152 | """ 153 | Wrap the given asgi_app in a Lambda handler 154 | """ 155 | # Setup AsyncIO 156 | loop = asyncio.new_event_loop() 157 | asyncio.set_event_loop(loop) 158 | 159 | def lambda_handler(event: RestateLambdaRequest, _context: Any) -> RestateLambdaResponse: 160 | loop = asyncio.get_event_loop() 161 | 162 | scope = create_scope(event) 163 | recv = request_to_receive(event) 164 | send = ResponseCollector() 165 | 166 | asgi_instance = asgi_app(scope, recv, send) 167 | asgi_task = loop.create_task(asgi_instance) # type: ignore[var-annotated, arg-type] 168 | loop.run_until_complete(asgi_task) 169 | 170 | return send.to_lambda_response() 171 | 172 | return lambda_handler 173 | -------------------------------------------------------------------------------- /python/restate/context.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 
6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | # pylint: disable=R0913,C0301,R0917 12 | """ 13 | Restate Context: abstract interfaces for invocation contexts, durable futures, state and durable promises. 14 | """ 15 | 16 | import abc 17 | from dataclasses import dataclass 18 | from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, Coroutine, overload 19 | import typing 20 | from datetime import timedelta 21 | from restate.serde import DefaultSerde, Serde 22 | 23 | T = TypeVar('T') 24 | I = TypeVar('I') 25 | O = TypeVar('O') 26 | 27 | RunAction = Union[Callable[..., Coroutine[Any, Any, T]], Callable[..., T]] # an action for run(): a coroutine function or a plain callable 28 | HandlerType = Union[Callable[[Any, I], Awaitable[O]], Callable[[Any], Awaitable[O]]] # an async handler taking (context, input) or only (context) 29 | 30 | # pylint: disable=R0903 31 | class RestateDurableFuture(typing.Generic[T], Awaitable[T]): 32 | """ 33 | Represents a durable future. 34 | """ 35 | 36 | @abc.abstractmethod 37 | def __await__(self) -> typing.Generator[Any, Any, T]: 38 | pass 39 | 40 | 41 | 42 | # pylint: disable=R0903 43 | class RestateDurableCallFuture(RestateDurableFuture[T]): 44 | """ 45 | Represents the durable future of a service/object/workflow call; it also exposes the id of the invocation it started. 46 | """ 47 | 48 | @abc.abstractmethod 49 | async def invocation_id(self) -> str: 50 | """ 51 | Returns the invocation id of the call. 52 | """ 53 | 54 | @abc.abstractmethod 55 | async def cancel_invocation(self) -> None: 56 | """ 57 | Cancels the invocation. 58 | 59 | Just a utility shortcut to: 60 | .. code-block:: python 61 | 62 | await ctx.cancel_invocation(await f.invocation_id()) 63 | """ 64 | 65 | 66 | class RestateDurableSleepFuture(RestateDurableFuture[None]): 67 | """ 68 | Represents the durable future returned by Context.sleep(). 69 | """ 70 | 71 | @abc.abstractmethod 72 | def __await__(self) -> typing.Generator[Any, Any, None]: 73 | pass 74 | 75 | class AttemptFinishedEvent(abc.ABC): 76 | """ 77 | Represents an attempt finished event. 
78 | 79 | This event is used to signal that an attempt has finished (either successfully or with an error), and it is now 80 | safe to clean up any attempt-related resources, such as pending ctx.run() 3rd party calls, or any other resources that 81 | are only valid for the duration of the attempt. 82 | 83 | An attempt is considered finished when either the connection to the restate server is closed, the invocation is completed, or a transient 84 | error occurs. 85 | """ 86 | 87 | @abc.abstractmethod 88 | def is_set(self) -> bool: 89 | """ 90 | Returns True if the event is set, False otherwise. 91 | """ 92 | 93 | 94 | @abc.abstractmethod 95 | async def wait(self) -> None: 96 | """ 97 | Waits for the event to be set. 98 | """ 99 | 100 | 101 | @dataclass 102 | class Request: 103 | """ 104 | Represents an ingress request. 105 | 106 | Attributes: 107 | id (str): The unique identifier of the request. 108 | headers (dict[str, str]): The headers of the request. 109 | attempt_headers (dict[str, str]): The attempt headers of the request. 110 | body (bytes): The body of the request. 111 | attempt_finished_event (AttemptFinishedEvent): Set when the current attempt finishes (see AttemptFinishedEvent). 112 | """ 113 | id: str 114 | headers: Dict[str, str] 115 | attempt_headers: Dict[str,str] 116 | body: bytes 117 | attempt_finished_event: AttemptFinishedEvent 118 | 119 | 120 | class KeyValueStore(abc.ABC): 121 | """ 122 | A key scoped key-value store. 123 | 124 | This class defines the interface for a key-value store, 125 | which allows storing and retrieving values 126 | based on a unique key. 127 | 128 | """ 129 | 130 | @abc.abstractmethod 131 | def get(self, 132 | name: str, 133 | serde: Serde[T] = DefaultSerde(), 134 | type_hint: Optional[typing.Type[T]] = None 135 | ) -> Awaitable[Optional[T]]: 136 | """ 137 | Retrieves the value associated with the given name. 138 | 139 | Args: 140 | name: The state name 141 | serde: The serialization/deserialization mechanism. 
- if the default serde is used, a default serializer will be used based on the type. 142 | See also 'type_hint'. 143 | type_hint: The type hint of the stored value. This is used to pick the serializer. If None, the type hint will be inferred from the provided serializer. 144 | """ 145 | 146 | @abc.abstractmethod 147 | def state_keys(self) -> Awaitable[List[str]]: 148 | """Returns the list of keys in the store.""" 149 | 150 | @abc.abstractmethod 151 | def set(self, 152 | name: str, 153 | value: T, 154 | serde: Serde[T] = DefaultSerde()) -> None: 155 | """Sets the value associated with the given name.""" 156 | 157 | @abc.abstractmethod 158 | def clear(self, name: str) -> None: 159 | """Clears the value associated with the given name.""" 160 | 161 | @abc.abstractmethod 162 | def clear_all(self) -> None: 163 | """Clears all the values in the store.""" 164 | 165 | # pylint: disable=R0903 166 | class SendHandle(abc.ABC): 167 | """ 168 | Handle for a message sent with one of the *_send methods; exposes the started invocation's id and allows cancelling it. 169 | """ 170 | 171 | @abc.abstractmethod 172 | async def invocation_id(self) -> str: 173 | """ 174 | Returns the invocation id of the send operation. 175 | """ 176 | 177 | @abc.abstractmethod 178 | async def cancel_invocation(self) -> None: 179 | """ 180 | Cancels the invocation. 181 | 182 | Just a utility shortcut to: 183 | .. code-block:: python 184 | 185 | await ctx.cancel_invocation(await f.invocation_id()) 186 | """ 187 | 188 | 189 | class Context(abc.ABC): 190 | """ 191 | Represents the context of the current invocation. 192 | """ 193 | 194 | @abc.abstractmethod 195 | def request(self) -> Request: 196 | """ 197 | Returns the request object. 
198 | """ 199 | 200 | @overload 201 | @abc.abstractmethod 202 | def run(self, 203 | name: str, 204 | action: Callable[..., Coroutine[Any, Any,T]], 205 | serde: Serde[T] = DefaultSerde(), 206 | max_attempts: typing.Optional[int] = None, 207 | max_retry_duration: typing.Optional[timedelta] = None, 208 | type_hint: Optional[typing.Type[T]] = None, 209 | args: Optional[typing.Tuple[Any, ...]] = None, 210 | ) -> RestateDurableFuture[T]: 211 | """ 212 | Runs the given coroutine action with the given name. 213 | 214 | Args: 215 | name: The name of the action. 216 | action: The action to run. 217 | serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type. 218 | See also 'type_hint'. 219 | max_attempts: The maximum number of retry attempts to complete the action. 220 | If None, the action will be retried indefinitely, until it succeeds. 221 | Otherwise, the action will be retried until the maximum number of attempts is reached and then it will raise a TerminalError. 222 | max_retry_duration: The maximum duration for retrying. If None, the action will be retried indefinitely, until it succeeds. 223 | Otherwise, the action will be retried until the maximum duration is reached and then it will raise a TerminalError. 224 | type_hint: The type hint of the return value of the action. 225 | This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer. 226 | 227 | """ 228 | 229 | @overload 230 | @abc.abstractmethod 231 | def run(self, 232 | name: str, 233 | action: Callable[..., T], 234 | serde: Serde[T] = DefaultSerde(), 235 | max_attempts: typing.Optional[int] = None, 236 | max_retry_duration: typing.Optional[timedelta] = None, 237 | type_hint: Optional[typing.Type[T]] = None, 238 | args: Optional[typing.Tuple[Any, ...]] = None, 239 | ) -> RestateDurableFuture[T]: 240 | """ 241 | Runs the given plain (non-coroutine) callable with the given name. 
242 | 243 | Args: 244 | name: The name of the action. 245 | action: The action to run. 246 | serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type. 247 | See also 'type_hint'. 248 | max_attempts: The maximum number of retry attempts to complete the action. 249 | If None, the action will be retried indefinitely, until it succeeds. 250 | Otherwise, the action will be retried until the maximum number of attempts is reached and then it will raise a TerminalError. 251 | max_retry_duration: The maximum duration for retrying. If None, the action will be retried indefinitely, until it succeeds. 252 | Otherwise, the action will be retried until the maximum duration is reached and then it will raise a TerminalError. 253 | type_hint: The type hint of the return value of the action. 254 | This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer. 255 | 256 | """ 257 | 258 | @abc.abstractmethod 259 | def run(self, 260 | name: str, 261 | action: RunAction[T], 262 | serde: Serde[T] = DefaultSerde(), 263 | max_attempts: typing.Optional[int] = None, 264 | max_retry_duration: typing.Optional[timedelta] = None, 265 | type_hint: Optional[typing.Type[T]] = None, 266 | args: Optional[typing.Tuple[Any, ...]] = None, 267 | ) -> RestateDurableFuture[T]: 268 | """ 269 | Runs the given action (a coroutine function or a plain callable) with the given name. 270 | 271 | Args: 272 | name: The name of the action. 273 | action: The action to run. 274 | serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type. 275 | See also 'type_hint'. 276 | max_attempts: The maximum number of retry attempts to complete the action. 277 | If None, the action will be retried indefinitely, until it succeeds. 
278 | Otherwise, the action will be retried until the maximum number of attempts is reached and then it will raise a TerminalError. 279 | max_retry_duration: The maximum duration for retrying. If None, the action will be retried indefinitely, until it succeeds. 280 | Otherwise, the action will be retried until the maximum duration is reached and then it will raise a TerminalError. 281 | type_hint: The type hint of the return value of the action. 282 | This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer. 283 | 284 | """ 285 | 286 | @abc.abstractmethod 287 | def sleep(self, delta: timedelta) -> RestateDurableSleepFuture: 288 | """ 289 | Returns a durable future that completes after the given duration; awaiting it suspends the current invocation for that long. 290 | """ 291 | 292 | @abc.abstractmethod 293 | def service_call(self, 294 | tpe: HandlerType[I, O], 295 | arg: I, 296 | idempotency_key: str | None = None, 297 | headers: typing.Dict[str, str] | None = None 298 | ) -> RestateDurableCallFuture[O]: 299 | """ 300 | Invokes the given service with the given argument. 301 | """ 302 | 303 | 304 | @abc.abstractmethod 305 | def service_send(self, 306 | tpe: HandlerType[I, O], 307 | arg: I, 308 | send_delay: Optional[timedelta] = None, 309 | idempotency_key: str | None = None, 310 | headers: typing.Dict[str, str] | None = None 311 | ) -> SendHandle: 312 | """ 313 | Sends a message to the given service with the given argument, optionally delayed by send_delay; returns a SendHandle rather than the result. 314 | """ 315 | 316 | @abc.abstractmethod 317 | def object_call(self, 318 | tpe: HandlerType[I, O], 319 | key: str, 320 | arg: I, 321 | idempotency_key: str | None = None, 322 | headers: typing.Dict[str, str] | None = None 323 | ) -> RestateDurableCallFuture[O]: 324 | """ 325 | Invokes the given object with the given argument. 
326 | """ 327 | 328 | @abc.abstractmethod 329 | def object_send(self, 330 | tpe: HandlerType[I, O], 331 | key: str, 332 | arg: I, 333 | send_delay: Optional[timedelta] = None, 334 | idempotency_key: str | None = None, 335 | headers: typing.Dict[str, str] | None = None 336 | ) -> SendHandle: 337 | """ 338 | Send a message to an object with the given argument. 339 | """ 340 | 341 | @abc.abstractmethod 342 | def workflow_call(self, 343 | tpe: HandlerType[I, O], 344 | key: str, 345 | arg: I, 346 | idempotency_key: str | None = None, 347 | headers: typing.Dict[str, str] | None = None 348 | ) -> RestateDurableCallFuture[O]: 349 | """ 350 | Invokes the given workflow with the given argument. 351 | """ 352 | 353 | @abc.abstractmethod 354 | def workflow_send(self, 355 | tpe: HandlerType[I, O], 356 | key: str, 357 | arg: I, 358 | send_delay: Optional[timedelta] = None, 359 | idempotency_key: str | None = None, 360 | headers: typing.Dict[str, str] | None = None 361 | ) -> SendHandle: 362 | """ 363 | Send a message to a workflow with the given argument. 364 | """ 365 | 366 | # pylint: disable=R0913 367 | @abc.abstractmethod 368 | def generic_call(self, 369 | service: str, 370 | handler: str, 371 | arg: bytes, 372 | key: Optional[str] = None, 373 | idempotency_key: str | None = None, 374 | headers: typing.Dict[str, str] | None = None 375 | ) -> RestateDurableCallFuture[bytes]: 376 | """ 377 | Invokes the given service/handler, addressed by name, with raw bytes as the argument. 378 | """ 379 | 380 | @abc.abstractmethod 381 | def generic_send(self, 382 | service: str, 383 | handler: str, 384 | arg: bytes, 385 | key: Optional[str] = None, 386 | send_delay: Optional[timedelta] = None, 387 | idempotency_key: str | None = None, 388 | headers: typing.Dict[str, str] | None = None 389 | ) -> SendHandle: 390 | """ 391 | Send a message to the given service/handler, addressed by name, with raw bytes as the argument. 
392 | """ 393 | 394 | @abc.abstractmethod 395 | def awakeable(self, 396 | serde: Serde[T] = DefaultSerde(), 397 | type_hint: Optional[typing.Type[T]] = None 398 | ) -> typing.Tuple[str, RestateDurableFuture[T]]: 399 | """ 400 | Returns the name of the awakeable and the future to be awaited. 401 | """ 402 | 403 | @abc.abstractmethod 404 | def resolve_awakeable(self, 405 | name: str, 406 | value: I, 407 | serde: Serde[I] = DefaultSerde()) -> None: 408 | """ 409 | Resolves the awakeable with the given name. 410 | """ 411 | 412 | @abc.abstractmethod 413 | def reject_awakeable(self, name: str, failure_message: str, failure_code: int = 500) -> None: 414 | """ 415 | Rejects the awakeable with the given name, failing it with failure_message and failure_code. 416 | """ 417 | 418 | @abc.abstractmethod 419 | def cancel_invocation(self, invocation_id: str): 420 | """ 421 | Cancels the invocation with the given id. 422 | """ 423 | 424 | @abc.abstractmethod 425 | def attach_invocation(self, invocation_id: str, serde: Serde[T] = DefaultSerde(), 426 | type_hint: typing.Optional[typing.Type[T]] = None 427 | ) -> RestateDurableFuture[T]: 428 | """ 429 | Attaches to the invocation with the given id and returns a durable future for its result. 430 | """ 431 | 432 | 433 | class ObjectContext(Context, KeyValueStore): 434 | """ 435 | Represents the context of a virtual object invocation with read/write access to the object's state. 436 | """ 437 | 438 | @abc.abstractmethod 439 | def key(self) -> str: 440 | """ 441 | Returns the key of the current object. 442 | """ 443 | 444 | 445 | class ObjectSharedContext(Context): 446 | """ 447 | Represents the context of a shared virtual object handler, with read-only access to the object's state. 448 | """ 449 | 450 | @abc.abstractmethod 451 | def key(self) -> str: 452 | """Returns the key of the current object.""" 453 | 454 | @abc.abstractmethod 455 | def get(self, 456 | name: str, 457 | serde: Serde[T] = DefaultSerde(), 458 | type_hint: Optional[typing.Type[T]] = None 459 | ) -> RestateDurableFuture[Optional[T]]: 460 | """ 461 | Retrieves the value associated with the given name. 
462 | 463 | Args: 464 | name: The state name 465 | serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type. 466 | See also 'type_hint'. 467 | type_hint: The type hint of the stored value. This is used to pick the serializer. If None, the type hint will be inferred from the provided serializer. 468 | """ 469 | 470 | @abc.abstractmethod 471 | def state_keys(self) -> Awaitable[List[str]]: 472 | """ 473 | Returns the list of keys in the store. 474 | """ 475 | 476 | class DurablePromise(typing.Generic[T]): # NOTE(review): not an abc.ABC subclass, so the @abc.abstractmethod markers below are not enforced at instantiation; confirm this is intended 477 | """ 478 | Represents a named durable promise, obtained via WorkflowContext.promise() or WorkflowSharedContext.promise(). 479 | """ 480 | 481 | def __init__(self, name: str, serde: Serde[T] = DefaultSerde()) -> None: 482 | self.name = name 483 | self.serde = serde 484 | 485 | @abc.abstractmethod 486 | def resolve(self, value: T) -> Awaitable[None]: 487 | """ 488 | Resolves the promise with the given value. 489 | """ 490 | 491 | @abc.abstractmethod 492 | def reject(self, message: str, code: int = 500) -> Awaitable[None]: 493 | """ 494 | Rejects the promise with the given message and code. 495 | """ 496 | 497 | @abc.abstractmethod 498 | def peek(self) -> Awaitable[typing.Optional[T]]: 499 | """ 500 | Returns the value of the promise if it is resolved, None otherwise. 501 | """ 502 | 503 | @abc.abstractmethod 504 | def value(self) -> RestateDurableFuture[T]: 505 | """ 506 | Returns a durable future that completes with the value of the promise once it is resolved. 507 | """ 508 | 509 | @abc.abstractmethod 510 | def __await__(self) -> typing.Generator[Any, Any, T]: 511 | """ 512 | Returns the value of the promise. This is a shortcut for calling value() and awaiting it. 513 | """ 514 | 515 | class WorkflowContext(ObjectContext): 516 | """ 517 | Represents the context of a workflow's main handler, with read/write state access and durable promises. 
518 | """ 519 | 520 | @abc.abstractmethod 521 | def promise(self, name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[typing.Type[T]] = None) -> DurablePromise[T]: 522 | """ 523 | Returns a durable promise with the given name. 524 | """ 525 | 526 | class WorkflowSharedContext(ObjectSharedContext): 527 | """ 528 | Represents the context of a workflow's shared handlers, with read-only state access and durable promises. 529 | """ 530 | 531 | @abc.abstractmethod 532 | def promise(self, name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[typing.Type[T]] = None) -> DurablePromise[T]: 533 | """ 534 | Returns a durable promise with the given name. 535 | """ 536 | -------------------------------------------------------------------------------- /python/restate/discovery.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """ 12 | Holds the discovery API objects as defined by the restate protocol. 13 | Note that the attribute names defined here do not use snake case: these 14 | objects are intended to be serialized to JSON, so their names must keep 15 | the casing that the restate server understands. 
16 | """ 17 | 18 | # disable to few parameters 19 | # pylint: disable=R0903 20 | # pylint: disable=C0301 21 | # pylint: disable=C0115 22 | # pylint: disable=C0103 23 | # pylint: disable=W0622 24 | # pylint: disable=R0913, 25 | # pylint: disable=R0917, 26 | 27 | import json 28 | import typing 29 | from enum import Enum 30 | from typing import Dict, Optional, Any, List, get_args, get_origin 31 | 32 | 33 | from restate.endpoint import Endpoint as RestateEndpoint 34 | from restate.handler import TypeHint 35 | 36 | class ProtocolMode(Enum): 37 | BIDI_STREAM = "BIDI_STREAM" 38 | REQUEST_RESPONSE = "REQUEST_RESPONSE" 39 | 40 | class ServiceType(Enum): 41 | VIRTUAL_OBJECT = "VIRTUAL_OBJECT" 42 | SERVICE = "SERVICE" 43 | WORKFLOW = "WORKFLOW" 44 | 45 | class ServiceHandlerType(Enum): 46 | WORKFLOW = "WORKFLOW" 47 | EXCLUSIVE = "EXCLUSIVE" 48 | SHARED = "SHARED" 49 | 50 | class InputPayload: 51 | def __init__(self, required: bool, contentType: str, jsonSchema: Optional[Any] = None): 52 | self.required = required 53 | self.contentType = contentType 54 | self.jsonSchema = jsonSchema 55 | 56 | class OutputPayload: 57 | def __init__(self, setContentTypeIfEmpty: bool, contentType: Optional[str] = None, jsonSchema: Optional[Any] = None): 58 | self.contentType = contentType 59 | self.setContentTypeIfEmpty = setContentTypeIfEmpty 60 | self.jsonSchema = jsonSchema 61 | 62 | class Handler: 63 | def __init__(self, name: str, ty: Optional[ServiceHandlerType] = None, input: Optional[InputPayload | Dict[str, str]] = None, output: Optional[OutputPayload] = None, description: Optional[str] = None, metadata: Optional[Dict[str, str]] = None): 64 | self.name = name 65 | self.ty = ty 66 | self.input = input 67 | self.output = output 68 | self.documentation = description 69 | self.metadata = metadata 70 | 71 | class Service: 72 | def __init__(self, name: str, ty: ServiceType, handlers: List[Handler], description: Optional[str] = None, metadata: Optional[Dict[str, str]] = None): 73 | self.name = 
name 74 | self.ty = ty 75 | self.handlers = handlers 76 | self.documentation = description 77 | self.metadata = metadata 78 | 79 | class Endpoint: 80 | def __init__(self, protocolMode: ProtocolMode, minProtocolVersion: int, maxProtocolVersion: int, services: List[Service]): 81 | self.protocolMode = protocolMode 82 | self.minProtocolVersion = minProtocolVersion 83 | self.maxProtocolVersion = maxProtocolVersion 84 | self.services = services 85 | 86 | PROTOCOL_MODES = { 87 | "bidi" : ProtocolMode.BIDI_STREAM, 88 | "request_response" : ProtocolMode.REQUEST_RESPONSE} 89 | 90 | SERVICE_TYPES = { 91 | "service": ServiceType.SERVICE, 92 | "object": ServiceType.VIRTUAL_OBJECT, 93 | "workflow": ServiceType.WORKFLOW} 94 | 95 | HANDLER_TYPES = { 96 | 'exclusive': ServiceHandlerType.EXCLUSIVE, 97 | 'shared': ServiceHandlerType.SHARED, 98 | 'workflow': ServiceHandlerType.WORKFLOW} 99 | 100 | class PythonClassEncoder(json.JSONEncoder): 101 | """ 102 | Serialize Python objects as JSON 103 | """ 104 | def default(self, o): 105 | if isinstance(o, Enum): 106 | return o.value 107 | return {key: value for key, value in o.__dict__.items() if value is not None} 108 | 109 | 110 | # pylint: disable=R0911 111 | def type_hint_to_json_schema(type_hint: Any) -> Any: 112 | """ 113 | Convert a Python type hint to a JSON schema. 
114 | 115 | """ 116 | origin = get_origin(type_hint) or type_hint 117 | args = get_args(type_hint) 118 | if origin is str: 119 | return {"type": "string"} 120 | if origin is int: 121 | return {"type": "integer"} 122 | if origin is float: 123 | return {"type": "number"} 124 | if origin is bool: 125 | return {"type": "boolean"} 126 | if origin is list: 127 | items = type_hint_to_json_schema(args[0] if args else Any) 128 | return {"type": "array", "items": items} 129 | if origin is dict: 130 | return { 131 | "type": "object" 132 | } 133 | if origin is None: 134 | return {"type": "null"} 135 | # Default to all valid schema 136 | return True 137 | 138 | def json_schema_from_type_hint(type_hint: Optional[TypeHint[Any]]) -> Any: 139 | """ 140 | Convert a type hint to a JSON schema. 141 | """ 142 | if not type_hint: 143 | return None 144 | if not type_hint.annotation: 145 | return None 146 | if type_hint.is_pydantic: 147 | return type_hint.annotation.model_json_schema(mode='serialization') 148 | return type_hint_to_json_schema(type_hint.annotation) 149 | 150 | 151 | 152 | def compute_discovery_json(endpoint: RestateEndpoint, 153 | version: int, 154 | discovered_as: typing.Literal["bidi", "request_response"]) -> typing.Tuple[typing.Dict[str, str] ,str]: 155 | """ 156 | return restate's discovery object as JSON 157 | """ 158 | if version != 1: 159 | raise ValueError(f"Unsupported protocol version {version}") 160 | 161 | ep = compute_discovery(endpoint, discovered_as) 162 | json_str = json.dumps(ep, cls=PythonClassEncoder, allow_nan=False) 163 | headers = {"content-type": "application/vnd.restate.endpointmanifest.v1+json"} 164 | return (headers, json_str) 165 | 166 | 167 | def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal["bidi", "request_response"]) -> Endpoint: 168 | """ 169 | return restate's discovery object for an endpoint 170 | """ 171 | services: typing.List[Service] = [] 172 | 173 | for service in endpoint.services.values(): 174 | 
service_type = SERVICE_TYPES[service.service_tag.kind] 175 | service_handlers = [] 176 | for handler in service.handlers.values(): 177 | # type 178 | if handler.kind: 179 | ty = HANDLER_TYPES[handler.kind] 180 | else: 181 | ty = None 182 | # input 183 | inp: Optional[InputPayload | Dict[str, str]] = None 184 | if handler.handler_io.input_type and handler.handler_io.input_type.is_void: 185 | inp = {} 186 | else: 187 | inp = InputPayload(required=False, 188 | contentType=handler.handler_io.accept, 189 | jsonSchema=json_schema_from_type_hint(handler.handler_io.input_type)) 190 | # output 191 | if handler.handler_io.output_type and handler.handler_io.output_type.is_void: 192 | out = OutputPayload(setContentTypeIfEmpty=False) 193 | else: 194 | out = OutputPayload(setContentTypeIfEmpty=False, 195 | contentType=handler.handler_io.content_type, 196 | jsonSchema=json_schema_from_type_hint(handler.handler_io.output_type)) 197 | # add the handler 198 | service_handlers.append(Handler(name=handler.name, 199 | ty=ty, 200 | input=inp, 201 | output=out, 202 | description=handler.description, 203 | metadata=handler.metadata)) 204 | # add the service 205 | description = service.service_tag.description 206 | metadata = service.service_tag.metadata 207 | services.append(Service(name=service.name, ty=service_type, handlers=service_handlers, description=description, metadata=metadata)) 208 | 209 | if endpoint.protocol: 210 | protocol_mode = PROTOCOL_MODES[endpoint.protocol] 211 | else: 212 | protocol_mode = PROTOCOL_MODES[discovered_as] 213 | return Endpoint(protocolMode=protocol_mode, 214 | minProtocolVersion=5, 215 | maxProtocolVersion=5, 216 | services=services) 217 | -------------------------------------------------------------------------------- /python/restate/endpoint.py: -------------------------------------------------------------------------------- 1 | 2 | # 3 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 4 | # 5 | # This file is part of the Restate SDK 
for Python, 6 | # which is released under the MIT license. 7 | # 8 | # You can find a copy of the license in file LICENSE in the root 9 | # directory of this repository or package, or at 10 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 11 | # 12 | """ 13 | This module defines the Endpoint class, which serves as a container for all the services and objects 14 | """ 15 | 16 | import typing 17 | 18 | from restate.service import Service 19 | from restate.object import VirtualObject 20 | from restate.workflow import Workflow 21 | 22 | 23 | # disable too few methods in a class 24 | # pylint: disable=R0903 25 | 26 | 27 | class Endpoint: 28 | """ 29 | Endpoint service that contains all the services and objects 30 | """ 31 | 32 | services: typing.Dict[str, typing.Union[Service, VirtualObject, Workflow]] 33 | protocol: typing.Optional[typing.Literal["bidi", "request_response"]] 34 | identity_keys: typing.List[str] 35 | 36 | def __init__(self): 37 | """ 38 | Create a new restate endpoint that serves as a container for all the services and objects 39 | """ 40 | self.services = {} 41 | # we will let the user to override it later perhaps, but for now let us 42 | # auto deduce it on discovery. 43 | # None means that the user did not explicitly set it. 
44 | self.protocol = None 45 | 46 | self.identity_keys = [] 47 | 48 | def bind(self, *services: typing.Union[Service, VirtualObject, Workflow]): 49 | """ 50 | Bind a service to the endpoint 51 | 52 | Args: 53 | service: The service or virtual object to bind to the endpoint 54 | 55 | Raises: 56 | ValueError: If a service with the same name already exists in the endpoint 57 | 58 | Returns: 59 | The updated Endpoint instance 60 | """ 61 | for service in services: 62 | if service.name in self.services: 63 | raise ValueError(f"Service {service.name} already exists") 64 | if isinstance(service, (Service, VirtualObject, Workflow)): 65 | self.services[service.name] = service 66 | else: 67 | raise ValueError(f"Invalid service type {service}") 68 | return self 69 | 70 | def streaming_protocol(self): 71 | """Use bidirectional streaming protocol. Use with servers that support HTTP2""" 72 | self.protocol = "bidi" 73 | return self 74 | 75 | def request_response_protocol(self): 76 | """Use request response style protocol for communication with restate.""" 77 | self.protocol = "request_response" 78 | 79 | def identity_key(self, identity_key: str): 80 | """Add an identity key to this endpoint.""" 81 | self.identity_keys.append(identity_key) 82 | 83 | def app(self): 84 | """ 85 | Returns the ASGI application for this endpoint. 86 | 87 | This method is responsible for creating and returning the ASGI application 88 | that will handle incoming requests for this endpoint. 89 | 90 | Returns: 91 | The ASGI application for this endpoint. 
92 | """ 93 | # we need to import it here to avoid circular dependencies 94 | # pylint: disable=C0415 95 | # pylint: disable=R0401 96 | from restate.server import asgi_app 97 | return asgi_app(self) 98 | 99 | def app( 100 | services: typing.Iterable[typing.Union[Service, VirtualObject, Workflow]], 101 | protocol: typing.Optional[typing.Literal["bidi", "request_response"]] = None, 102 | identity_keys: typing.Optional[typing.List[str]] = None): 103 | """A restate ASGI application that hosts the given services.""" 104 | endpoint = Endpoint() 105 | if protocol == "bidi": 106 | endpoint.streaming_protocol() 107 | elif protocol == "request_response": 108 | endpoint.request_response_protocol() 109 | for service in services: 110 | endpoint.bind(service) 111 | if identity_keys: 112 | for key in identity_keys: 113 | endpoint.identity_key(key) 114 | return endpoint.app() 115 | -------------------------------------------------------------------------------- /python/restate/exceptions.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 
6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """This module contains the restate exceptions""" 12 | 13 | class TerminalError(Exception): 14 | """This exception is raised to indicate a termination of the execution""" 15 | 16 | def __init__(self, message: str, status_code: int = 500) -> None: 17 | super().__init__(message) 18 | self.message = message 19 | self.status_code = status_code 20 | -------------------------------------------------------------------------------- /python/restate/handler.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | 12 | # pylint: disable=R0917 13 | """ 14 | This module contains the definition of the Handler class, 15 | which is used to define the handlers for the services. 16 | """ 17 | 18 | from dataclasses import dataclass 19 | from inspect import Signature 20 | from typing import Any, Callable, Awaitable, Dict, Generic, Literal, Optional, TypeVar 21 | 22 | from restate.context import HandlerType 23 | from restate.exceptions import TerminalError 24 | from restate.serde import DefaultSerde, PydanticJsonSerde, Serde, is_pydantic 25 | 26 | I = TypeVar('I') 27 | O = TypeVar('O') 28 | T = TypeVar('T') 29 | 30 | # we will use this symbol to store the handler in the function 31 | RESTATE_UNIQUE_HANDLER_SYMBOL = str(object()) 32 | 33 | @dataclass 34 | class ServiceTag: 35 | """ 36 | This class is used to identify the service. 
37 | """ 38 | kind: Literal["object", "service", "workflow"] 39 | name: str 40 | description: Optional[str] = None 41 | metadata: Optional[Dict[str, str]] = None 42 | 43 | @dataclass 44 | class TypeHint(Generic[T]): 45 | """ 46 | Represents a type hint. 47 | """ 48 | annotation: Optional[T] = None 49 | is_pydantic: bool = False 50 | is_void: bool = False 51 | 52 | @dataclass 53 | class HandlerIO(Generic[I, O]): 54 | """ 55 | Represents the input/output configuration for a handler. 56 | 57 | Attributes: 58 | accept (str): The accept header value for the handler. 59 | content_type (str): The content type header value for the handler. 60 | """ 61 | accept: str 62 | content_type: str 63 | input_serde: Serde[I] 64 | output_serde: Serde[O] 65 | input_type: Optional[TypeHint[I]] = None 66 | output_type: Optional[TypeHint[O]] = None 67 | 68 | 69 | def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Signature): 70 | """ 71 | Augment handler_io with additional information about the input and output types. 72 | 73 | This function has a special check for Pydantic models when these are provided. 74 | This method will inspect the signature of an handler and will look for 75 | the input and the return types of a function, and will: 76 | * capture any Pydantic models (to be used later at discovery) 77 | * replace the default json serializer (is unchanged by a user) with a Pydantic serde 78 | """ 79 | params = list(signature.parameters.values()) 80 | if len(params) == 1: 81 | # if there is only one parameter, it is the context. 
82 | handler_io.input_type = TypeHint(is_void=True) 83 | else: 84 | annotation = params[-1].annotation 85 | handler_io.input_type = TypeHint(annotation=annotation, is_pydantic=False) 86 | if is_pydantic(annotation): 87 | handler_io.input_type.is_pydantic = True 88 | if isinstance(handler_io.input_serde, DefaultSerde): 89 | handler_io.input_serde = PydanticJsonSerde(annotation) 90 | 91 | annotation = signature.return_annotation 92 | if annotation is None or annotation is Signature.empty: 93 | # if there is no return annotation, we assume it is void 94 | handler_io.output_type = TypeHint(is_void=True) 95 | else: 96 | handler_io.output_type = TypeHint(annotation=annotation, is_pydantic=False) 97 | if is_pydantic(annotation): 98 | handler_io.output_type.is_pydantic=True 99 | if isinstance(handler_io.output_serde, DefaultSerde): 100 | handler_io.output_serde = PydanticJsonSerde(annotation) 101 | 102 | # pylint: disable=R0902 103 | @dataclass 104 | class Handler(Generic[I, O]): 105 | """ 106 | Represents a handler for a service. 107 | """ 108 | service_tag: ServiceTag 109 | handler_io: HandlerIO[I, O] 110 | kind: Optional[Literal["exclusive", "shared", "workflow"]] 111 | name: str 112 | fn: Callable[[Any, I], Awaitable[O]] | Callable[[Any], Awaitable[O]] 113 | arity: int 114 | description: Optional[str] = None 115 | metadata: Optional[Dict[str, str]] = None 116 | 117 | 118 | # disable too many arguments warning 119 | # pylint: disable=R0913 120 | 121 | def make_handler(service_tag: ServiceTag, 122 | handler_io: HandlerIO[I, O], 123 | name: str | None, 124 | kind: Optional[Literal["exclusive", "shared", "workflow"]], 125 | wrapped: Any, 126 | signature: Signature, 127 | description: Optional[str] = None, 128 | metadata: Optional[Dict[str, str]] = None) -> Handler[I, O]: 129 | """ 130 | Factory function to create a handler. 
131 | """ 132 | # try to deduce the handler name 133 | handler_name = name 134 | if not handler_name: 135 | handler_name = wrapped.__name__ 136 | if not handler_name: 137 | raise ValueError("Handler name must be provided") 138 | 139 | if len(signature.parameters) == 0: 140 | raise ValueError("Handler must have at least one parameter") 141 | 142 | arity = len(signature.parameters) 143 | update_handler_io_with_type_hints(handler_io, signature) # mutates handler_io 144 | 145 | handler = Handler[I, O](service_tag=service_tag, 146 | handler_io=handler_io, 147 | kind=kind, 148 | name=handler_name, 149 | fn=wrapped, 150 | arity=arity, 151 | description=description, 152 | metadata=metadata) 153 | 154 | vars(wrapped)[RESTATE_UNIQUE_HANDLER_SYMBOL] = handler 155 | return handler 156 | 157 | def handler_from_callable(wrapper: HandlerType[I, O]) -> Handler[I, O]: 158 | """ 159 | Get the handler from the callable. 160 | """ 161 | try: 162 | return vars(wrapper)[RESTATE_UNIQUE_HANDLER_SYMBOL] 163 | except KeyError: 164 | raise ValueError("Handler not found") # pylint: disable=raise-missing-from 165 | 166 | async def invoke_handler(handler: Handler[I, O], ctx: Any, in_buffer: bytes) -> bytes: 167 | """ 168 | Invoke the handler with the given context and input. 169 | """ 170 | if handler.arity == 2: 171 | try: 172 | in_arg = handler.handler_io.input_serde.deserialize(in_buffer) 173 | except Exception as e: 174 | raise TerminalError(message=f"Unable to parse an input argument. 
{e}") from e 175 | out_arg = await handler.fn(ctx, in_arg) # type: ignore [call-arg, arg-type] 176 | else: 177 | out_arg = await handler.fn(ctx) # type: ignore [call-arg] 178 | out_buffer = handler.handler_io.output_serde.serialize(out_arg) 179 | return bytes(out_buffer) 180 | -------------------------------------------------------------------------------- /python/restate/harness.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """Test containers based wrapper for the restate server""" 12 | 13 | import abc 14 | import asyncio 15 | from dataclasses import dataclass 16 | import threading 17 | import typing 18 | from urllib.error import URLError 19 | import socket 20 | 21 | from hypercorn.config import Config 22 | from hypercorn.asyncio import serve 23 | from testcontainers.core.container import DockerContainer # type: ignore 24 | from testcontainers.core.waiting_utils import wait_container_is_ready # type: ignore 25 | import httpx 26 | 27 | 28 | def find_free_port(): 29 | """find the next free port to bind to on the host machine""" 30 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 31 | s.bind(("0.0.0.0", 0)) 32 | return s.getsockname()[1] 33 | 34 | def run_in_background(coro) -> threading.Thread: 35 | """run a coroutine in the background""" 36 | def runner(): 37 | asyncio.run(coro) 38 | 39 | thread = threading.Thread(target=runner, daemon=True) 40 | thread.start() 41 | return thread 42 | 43 | 44 | class BindAddress(abc.ABC): 45 | """A bind address for the ASGI server""" 46 | 47 | @abc.abstractmethod 48 | def 
get_local_bind_address(self) -> str: 49 | """return the local bind address for hypercorn to bind to""" 50 | 51 | @abc.abstractmethod 52 | def get_endpoint_connection_string(self) -> str: 53 | """return the SDK connection string to be used by restate""" 54 | 55 | @abc.abstractmethod 56 | def cleanup(self): 57 | """cleanup any resources used by the bind address""" 58 | 59 | class TcpSocketBindAddress(BindAddress): 60 | """Bind a TCP address that listens on a random TCP port""" 61 | 62 | def __init__(self): 63 | self.port = find_free_port() 64 | 65 | def get_local_bind_address(self) -> str: 66 | return f"0.0.0.0:{self.port}" 67 | 68 | def get_endpoint_connection_string(self) -> str: 69 | return f"http://host.docker.internal:{self.port}" 70 | 71 | def cleanup(self): 72 | pass 73 | 74 | 75 | class AsgiServer: 76 | """A simple ASGI server that listens on a unix domain socket""" 77 | 78 | thread: typing.Optional[threading.Thread] = None 79 | 80 | def __init__(self, asgi_app, bind_address: BindAddress): 81 | self.asgi_app = asgi_app 82 | self.bind_address = bind_address 83 | self.stop_event = asyncio.Event() 84 | self.exit_event = asyncio.Event() 85 | 86 | def stop(self): 87 | """stop the server""" 88 | self.stop_event.set() 89 | if self.thread: 90 | self.thread.join(timeout=1) 91 | self.thread = None 92 | self.exit_event.set() 93 | 94 | def start(self): 95 | """start the server""" 96 | 97 | def shutdown_trigger(): 98 | """trigger the shutdown event""" 99 | return self.stop_event.wait() 100 | 101 | async def run_asgi(): 102 | """run the asgi app on the given port""" 103 | config = Config() 104 | config.h2_max_concurrent_streams = 2147483647 105 | config.keep_alive_max_requests = 2147483647 106 | config.keep_alive_timeout = 2147483647 107 | 108 | bind = self.bind_address.get_local_bind_address() 109 | config.bind = [bind] 110 | try: 111 | print(f"Starting ASGI server on {bind}", flush=True) 112 | await serve(self.asgi_app, 113 | config=config, 114 | mode='asgi', 115 | 
shutdown_trigger=shutdown_trigger) 116 | except asyncio.CancelledError: 117 | print("ASGI server was cancelled", flush=True) 118 | except Exception as e: # pylint: disable=broad-except 119 | print(f"Failed to start the ASGI server: {e}", flush=True) 120 | raise e 121 | finally: 122 | self.exit_event.set() 123 | 124 | self.thread = run_in_background(run_asgi()) 125 | return self 126 | 127 | class RestateContainer(DockerContainer): 128 | """Create a Restate container""" 129 | 130 | log_thread: typing.Optional[threading.Thread] = None 131 | 132 | def __init__(self, image): 133 | super().__init__(image) 134 | self.with_exposed_ports(8080, 9070) 135 | self.with_env('RESTATE_LOG_FILTER', 'restate=info') 136 | self.with_env('RESTATE_BOOTSTRAP_NUM_PARTITIONS', '1') 137 | self.with_env('RESTATE_DEFAULT_NUM_PARTITIONS', '1') 138 | self.with_env('RESTATE_SHUTDOWN_TIMEOUT', '10s') 139 | self.with_env('RESTATE_ROCKSDB_TOTAL_MEMORY_SIZE', '32 MB') 140 | self.with_env('RESTATE_WORKER__INVOKER__IN_MEMORY_QUEUE_LENGTH_LIMIT', '64') 141 | self.with_env('RESTATE_WORKER__INVOKER__INACTIVITY_TIMEOUT', '10m') 142 | self.with_env('RESTATE_WORKER__INVOKER__ABORT_TIMEOUT', '10m') 143 | 144 | self.with_kwargs(extra_hosts={"host.docker.internal" : "host-gateway"}) 145 | 146 | def ingress_url(self): 147 | """return the URL to access the Restate ingress""" 148 | return f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8080)}" 149 | 150 | def admin_url(self): 151 | """return the URL to access the Restate admin""" 152 | return f"http://{self.get_container_host_ip()}:{self.get_exposed_port(9070)}" 153 | 154 | def get_admin_client(self): 155 | """return an httpx client to access the admin interface""" 156 | return httpx.Client(base_url=self.admin_url()) 157 | 158 | def get_ingress_client(self): 159 | """return an httpx client to access the ingress interface""" 160 | return httpx.Client(base_url=self.ingress_url()) 161 | 162 | @wait_container_is_ready(httpx.HTTPError, URLError) 163 | 
def _wait_healthy(self): 164 | """wait for restate's health checks to pass""" 165 | self.get_ingress_client().get("/restate/health").raise_for_status() 166 | self.get_admin_client().get("/health").raise_for_status() 167 | 168 | 169 | def start(self, stream_logs = False): 170 | """start the container and wait for health checks to pass""" 171 | super().start() 172 | 173 | def stream_log(): 174 | for line in self.get_wrapped_container().logs(stream=True): 175 | print(line.decode("utf-8"), end="", flush=True) 176 | 177 | if stream_logs: 178 | thread = threading.Thread(target=stream_log, daemon=True) 179 | thread.start() 180 | self.log_thread = thread 181 | 182 | self._wait_healthy() 183 | return self 184 | 185 | 186 | @dataclass 187 | class TestConfiguration: 188 | """A configuration for running tests""" 189 | restate_image: str = "restatedev/restate:latest" 190 | stream_logs: bool = False 191 | 192 | 193 | class RestateTestHarness: 194 | """A test harness for running Restate SDKs""" 195 | bind_address: typing.Optional[BindAddress] = None 196 | server: typing.Optional[AsgiServer] = None 197 | restate: typing.Optional[RestateContainer] = None 198 | 199 | def __init__(self, asgi_app, config: typing.Optional[TestConfiguration]): 200 | self.asgi_app = asgi_app 201 | if config: 202 | self.config = config 203 | else: 204 | self.config = TestConfiguration() 205 | 206 | def start(self): 207 | """start the restate server and the sdk""" 208 | self.bind_address = TcpSocketBindAddress() 209 | self.server = AsgiServer(self.asgi_app, self.bind_address).start() 210 | self.restate = RestateContainer(image=self.config.restate_image) \ 211 | .start(self.config.stream_logs) 212 | try: 213 | self._register_sdk() 214 | except Exception as e: 215 | self.stop() 216 | raise AssertionError("Failed to register the SDK with the Restate server") from e 217 | 218 | def _register_sdk(self): 219 | """register the sdk with the restate server""" 220 | assert self.bind_address is not None 221 | assert 
self.restate is not None 222 | 223 | uri = self.bind_address.get_endpoint_connection_string() 224 | client = self.restate.get_admin_client() 225 | res = client.post("/deployments", 226 | headers={"content-type" : "application/json"}, 227 | json={"uri": uri}) 228 | if not res.is_success: 229 | msg = f"unable to register the services at {uri} - {res.status_code} {res.text}" 230 | raise AssertionError(msg) 231 | 232 | def stop(self): 233 | """terminate the restate server and the sdk""" 234 | if self.server is not None: 235 | self.server.stop() 236 | if self.restate is not None: 237 | self.restate.stop() 238 | if self.bind_address is not None: 239 | self.bind_address.cleanup() 240 | 241 | def ingress_client(self): 242 | """return an httpx client to access the restate server's ingress""" 243 | if self.restate is None: 244 | raise AssertionError("The Restate server has not been started. Use .start()") 245 | return self.restate.get_ingress_client() 246 | 247 | 248 | def __enter__(self): 249 | self.start() 250 | return self 251 | 252 | def __exit__(self, exc_type, exc_value, traceback): 253 | self.stop() 254 | return False 255 | 256 | 257 | def test_harness(app, 258 | follow_logs: bool = False, 259 | restate_image: str = "restatedev/restate:latest") -> RestateTestHarness: 260 | """create a test harness for running Restate SDKs""" 261 | config = TestConfiguration( 262 | restate_image=restate_image, 263 | stream_logs=follow_logs, 264 | ) 265 | return RestateTestHarness(app, config) 266 | -------------------------------------------------------------------------------- /python/restate/object.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 
6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | # pylint: disable=R0917 12 | """ 13 | This module defines the Service class for representing a restate service. 14 | """ 15 | 16 | from functools import wraps 17 | import inspect 18 | import typing 19 | 20 | from restate.serde import Serde, DefaultSerde 21 | from restate.handler import Handler, HandlerIO, ServiceTag, make_handler 22 | 23 | I = typing.TypeVar('I') 24 | O = typing.TypeVar('O') 25 | 26 | 27 | # disable too many arguments warning 28 | # pylint: disable=R0913 29 | 30 | # disable line too long warning 31 | # pylint: disable=C0301 32 | 33 | class VirtualObject: 34 | """ 35 | Represents a restate virtual object. 36 | 37 | Args: 38 | name (str): The name of the object. 39 | description (str): The description of the object. 40 | metadata (dict): The metadata of the object. 41 | """ 42 | 43 | handlers: typing.Dict[str, Handler[typing.Any, typing.Any]] 44 | 45 | def __init__(self, name, 46 | description: typing.Optional[str] = None, 47 | metadata: typing.Optional[typing.Dict[str, str]]=None): 48 | self.service_tag = ServiceTag("object", name, description, metadata) 49 | self.handlers = {} 50 | 51 | @property 52 | def name(self): 53 | """ 54 | Returns the name of the object. 55 | """ 56 | return self.service_tag.name 57 | 58 | def handler(self, 59 | name: typing.Optional[str] = None, 60 | kind: typing.Optional[typing.Literal["exclusive", "shared"]] = "exclusive", 61 | accept: str = "application/json", 62 | content_type: str = "application/json", 63 | input_serde: Serde[I] = DefaultSerde(), 64 | output_serde: Serde[O] = DefaultSerde(), 65 | metadata: typing.Optional[dict] = None) -> typing.Callable: 66 | """ 67 | Decorator for defining a handler function. 68 | 69 | Args: 70 | name: The name of the handler. 71 | accept: The accept type of the request. 
Default "application/json". 72 | content_type: The content type of the response. Default "application/json". 73 | kind: The handler kind, either "exclusive" (the default) or "shared". 74 | input_serde: The serde used to deserialize the request bytes into the handler input. 75 | output_serde: The serde used to serialize the handler output into bytes. 76 | metadata: Optional metadata attached to the handler. 77 | 78 | Returns: 79 | Callable: The decorated function. 80 | 81 | Raises: 82 | ValueError: If the handler name cannot be determined, or the function takes no parameters. 83 | 84 | Example: 85 | @my_object.handler() 86 | async def my_handler_func(ctx, request): ... 87 | """ 88 | handler_io = HandlerIO[I,O](accept, content_type, input_serde, output_serde) 89 | def wrapper(fn): 90 | 91 | @wraps(fn) 92 | def wrapped(*args, **kwargs): 93 | return fn(*args, **kwargs) 94 | 95 | signature = inspect.signature(fn, eval_str=True) 96 | handler = make_handler(self.service_tag, handler_io, name, kind, wrapped, signature, inspect.getdoc(fn), metadata) 97 | self.handlers[handler.name] = handler 98 | return wrapped 99 | 100 | return wrapper 101 | -------------------------------------------------------------------------------- /python/restate/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/restatedev/sdk-python/844d864e2c638a5d6e82a056361f683e60fcf2ca/python/restate/py.typed -------------------------------------------------------------------------------- /python/restate/serde.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """ This module contains functions for serializing and deserializing data.
""" 12 | import abc 13 | import json 14 | import typing 15 | 16 | from dataclasses import asdict, is_dataclass 17 | 18 | def try_import_pydantic_base_model(): 19 | """ 20 | Try to import PydanticBaseModel from Pydantic. 21 | """ 22 | try: 23 | from pydantic import BaseModel # type: ignore # pylint: disable=import-outside-toplevel 24 | return BaseModel 25 | except ImportError: 26 | class Dummy: # pylint: disable=too-few-public-methods 27 | """a dummy class to use when Pydantic is not available""" 28 | 29 | return Dummy 30 | 31 | def try_import_from_dacite(): 32 | """ 33 | Try to import from_dict from dacite. 34 | """ 35 | try: 36 | from dacite import from_dict # type: ignore # pylint: disable=import-outside-toplevel 37 | 38 | return asdict, from_dict 39 | 40 | except ImportError: 41 | 42 | def to_dict(obj): 43 | """a dummy function when dacite is not available""" 44 | raise RuntimeError("Trying to deserialize into a @dataclass." \ 45 | "Please add the optional dependencies needed." \ 46 | "use pip install restate-sdk[serde] " 47 | "or" \ 48 | " pip install restate-sdk[all] to install all dependencies.") 49 | 50 | 51 | def from_dict(a,b): # pylint: disable=too-few-public-methods,unused-argument 52 | """a dummy function when dacite is not available""" 53 | 54 | raise RuntimeError("Trying to deserialize into a @dataclass." \ 55 | "Please add the optional dependencies needed." \ 56 | "use pip install restate-sdk[serde] " 57 | "or" \ 58 | " pip install restate-sdk[all] to install all dependencies.") 59 | 60 | return to_dict, from_dict 61 | 62 | PydanticBaseModel = try_import_pydantic_base_model() 63 | DaciteToDict, DaciteFromDict = try_import_from_dacite() 64 | 65 | T = typing.TypeVar('T') 66 | I = typing.TypeVar('I') 67 | O = typing.TypeVar('O') 68 | 69 | # disable to few parameters 70 | # pylint: disable=R0903 71 | 72 | def is_pydantic(annotation) -> bool: 73 | """ 74 | Check if an object is a Pydantic model. 
75 | """ 76 | try: 77 | return issubclass(annotation, PydanticBaseModel) 78 | except TypeError: 79 | # annotation is not a class or a type 80 | return False 81 | 82 | class Serde(typing.Generic[T], abc.ABC): 83 | """serializer/deserializer interface.""" 84 | 85 | @abc.abstractmethod 86 | def deserialize(self, buf: bytes) -> typing.Optional[T]: 87 | """ 88 | Deserializes a bytearray to an object. 89 | """ 90 | 91 | @abc.abstractmethod 92 | def serialize(self, obj: typing.Optional[T]) -> bytes: 93 | """ 94 | Serializes an object to a bytearray. 95 | """ 96 | 97 | class BytesSerde(Serde[bytes]): 98 | """A pass-trough serializer/deserializer.""" 99 | 100 | def deserialize(self, buf: bytes) -> typing.Optional[bytes]: 101 | """ 102 | Deserializes a bytearray to a bytearray. 103 | 104 | Args: 105 | buf (bytearray): The bytearray to deserialize. 106 | 107 | Returns: 108 | typing.Optional[bytes]: The deserialized bytearray. 109 | """ 110 | return buf 111 | 112 | def serialize(self, obj: typing.Optional[bytes]) -> bytes: 113 | """ 114 | Serializes a bytearray to a bytearray. 115 | 116 | Args: 117 | obj (bytes): The bytearray to serialize. 118 | 119 | Returns: 120 | bytearray: The serialized bytearray. 121 | """ 122 | if obj is None: 123 | return bytes() 124 | return obj 125 | 126 | 127 | class JsonSerde(Serde[I]): 128 | """A JSON serializer/deserializer.""" 129 | 130 | def deserialize(self, buf: bytes) -> typing.Optional[I]: 131 | """ 132 | Deserializes a bytearray to a JSON object. 133 | 134 | Args: 135 | buf (bytearray): The bytearray to deserialize. 136 | 137 | Returns: 138 | typing.Optional[I]: The deserialized JSON object. 139 | """ 140 | if not buf: 141 | return None 142 | return json.loads(buf) 143 | 144 | def serialize(self, obj: typing.Optional[I]) -> bytes: 145 | """ 146 | Serializes a JSON object to a bytearray. 147 | 148 | Args: 149 | obj (I): The JSON object to serialize. 150 | 151 | Returns: 152 | bytearray: The serialized bytearray. 
153 | """ 154 | if obj is None: 155 | return bytes() 156 | 157 | return bytes(json.dumps(obj), "utf-8") 158 | 159 | class DefaultSerde(Serde[I]): 160 | """ 161 | The default serializer/deserializer used when no explicit type hints are provided. 162 | 163 | Behavior: 164 | - Serialization: 165 | - If the object is an instance of Pydantic's `BaseModel`, 166 | it uses `model_dump_json()` for serialization. 167 | - Otherwise, it falls back to `json.dumps()`. 168 | - Deserialization: 169 | - Uses `json.loads()` to convert byte arrays into Python objects. 170 | - Does **not** automatically reconstruct Pydantic models; 171 | deserialized objects remain as generic JSON structures (dicts, lists, etc.). 172 | 173 | Serde Selection: 174 | - When using the `@handler` decorator, if a function's type hints specify a Pydantic model, 175 | `PydanticJsonSerde` is automatically selected instead of `DefaultSerde`. 176 | - `DefaultSerde` is only used if no explicit type hints are provided. 177 | 178 | This serde ensures compatibility with both structured (Pydantic) and unstructured JSON data, 179 | while allowing automatic serde selection based on type hints. 180 | """ 181 | 182 | def __init__(self, type_hint: typing.Optional[typing.Type[I]] = None): 183 | super().__init__() 184 | self.type_hint = type_hint 185 | 186 | def with_maybe_type(self, type_hint: typing.Type[I] | None = None) -> "DefaultSerde[I]": 187 | """ 188 | Returns a new instance of DefaultSerde with the provided type hint. 189 | This is useful for creating a serde that is specific to a certain type. 190 | NOTE: This method does not modify the current instance. 191 | Args: 192 | type_hint (Type[I] | None): The type hint to use for serialization/deserialization. 193 | Returns: 194 | DefaultSerde[I]: A new instance of DefaultSerde with the provided type hint. 
195 | """ 196 | return DefaultSerde(type_hint) 197 | 198 | def deserialize(self, buf: bytes) -> typing.Optional[I]: 199 | """ 200 | Deserializes a byte array into a Python object. 201 | 202 | Args: 203 | buf (bytes): The byte array to deserialize. 204 | 205 | Returns: 206 | Optional[I]: The resulting Python object, or None if the input is empty. 207 | """ 208 | if not buf: 209 | return None 210 | if is_pydantic(self.type_hint): 211 | return self.type_hint.model_validate_json(buf) # type: ignore 212 | if is_dataclass(self.type_hint): 213 | data = json.loads(buf) 214 | return DaciteFromDict(self.type_hint, data) 215 | return json.loads(buf) 216 | 217 | def serialize(self, obj: typing.Optional[I]) -> bytes: 218 | """ 219 | Serializes a Python object into a byte array. 220 | If the object is a Pydantic BaseModel, uses its model_dump_json method. 221 | 222 | Args: 223 | obj (Optional[I]): The Python object to serialize. 224 | 225 | Returns: 226 | bytes: The serialized byte array. 227 | """ 228 | if obj is None: 229 | return bytes() 230 | if is_pydantic(self.type_hint): 231 | return obj.model_dump_json().encode("utf-8") # type: ignore[attr-defined] 232 | if is_dataclass(obj): 233 | data = DaciteToDict(obj) # type: ignore 234 | return json.dumps(data).encode("utf-8") 235 | return json.dumps(obj).encode("utf-8") 236 | 237 | 238 | class PydanticJsonSerde(Serde[I]): 239 | """ 240 | Serde for Pydantic models to/from JSON 241 | """ 242 | 243 | def __init__(self, model): 244 | self.model = model 245 | 246 | def deserialize(self, buf: bytes) -> typing.Optional[I]: 247 | """ 248 | Deserializes a bytearray to a Pydantic model. 249 | 250 | Args: 251 | buf (bytearray): The bytearray to deserialize. 252 | 253 | Returns: 254 | typing.Optional[I]: The deserialized Pydantic model. 
255 | """ 256 | if not buf: 257 | return None 258 | return self.model.model_validate_json(buf) 259 | 260 | def serialize(self, obj: typing.Optional[I]) -> bytes: 261 | """ 262 | Serializes a Pydantic model to a bytearray. 263 | 264 | Args: 265 | obj (I): The Pydantic model to serialize. 266 | 267 | Returns: 268 | bytearray: The serialized bytearray. 269 | """ 270 | if obj is None: 271 | return bytes() 272 | json_str = obj.model_dump_json() # type: ignore[attr-defined] 273 | return json_str.encode("utf-8") 274 | -------------------------------------------------------------------------------- /python/restate/server.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """This module contains the ASGI server for the restate framework.""" 12 | 13 | import asyncio 14 | from typing import Dict, TypedDict, Literal 15 | import traceback 16 | from restate.discovery import compute_discovery_json 17 | from restate.endpoint import Endpoint 18 | from restate.server_context import ServerInvocationContext, DisconnectedException 19 | from restate.server_types import Receive, ReceiveChannel, Scope, Send, binary_to_header, header_to_binary # pylint: disable=line-too-long 20 | from restate.vm import VMWrapper 21 | from restate._internal import PyIdentityVerifier, IdentityVerificationException # pylint: disable=import-error,no-name-in-module 22 | from restate._internal import SDK_VERSION # pylint: disable=import-error,no-name-in-module 23 | from restate.aws_lambda import is_running_on_lambda, wrap_asgi_as_lambda_handler 24 | 25 | X_RESTATE_SERVER = 
header_to_binary([("x-restate-server", f"restate-sdk-python/{SDK_VERSION}")]) 26 | 27 | async def send_status(send, receive, status_code: int): 28 | """respond with a status code""" 29 | await send({'type': 'http.response.start', 'status': status_code, "headers": X_RESTATE_SERVER}) 30 | # For more info on why this loop, see ServerInvocationContext.leave() 31 | # pylint: disable=R0801 32 | while True: 33 | event = await receive() 34 | if event is None: 35 | break 36 | if event.get('type') == 'http.disconnect': 37 | break 38 | if event.get('type') == 'http.request' and event.get('more_body', False) is False: 39 | break 40 | await send({'type': 'http.response.body'}) 41 | 42 | async def send404(send, receive): 43 | """respond with a 404""" 44 | await send_status(send, receive, 404) 45 | 46 | async def send_discovery(scope: Scope, send: Send, endpoint: Endpoint): 47 | """respond with a discovery""" 48 | discovered_as: Literal["request_response", "bidi"] 49 | if scope['http_version'] == '1.1': 50 | discovered_as = "request_response" 51 | else: 52 | discovered_as = "bidi" 53 | headers, js = compute_discovery_json(endpoint, 1, discovered_as) 54 | bin_headers = header_to_binary(headers.items()) 55 | bin_headers.extend(X_RESTATE_SERVER) 56 | await send({ 57 | 'type': 'http.response.start', 58 | 'status': 200, 59 | 'headers': bin_headers, 60 | 'trailers': False 61 | }) 62 | await send({ 63 | 'type': 'http.response.body', 64 | 'body': js.encode('utf-8'), 65 | 'more_body': False, 66 | }) 67 | 68 | async def send_health_check(send: Send): 69 | """respond with an health check""" 70 | headers = header_to_binary([("content-type", "application/json")]) 71 | headers.extend(X_RESTATE_SERVER) 72 | await send({ 73 | 'type': 'http.response.start', 74 | 'status': 200, 75 | 'headers': headers, 76 | 'trailers': False 77 | }) 78 | await send({ 79 | 'type': 'http.response.body', 80 | 'body': b'{"status":"ok"}', 81 | 'more_body': False, 82 | }) 83 | 84 | 85 | async def 
process_invocation_to_completion(vm: VMWrapper, 86 | handler, 87 | attempt_headers: Dict[str, str], 88 | receive: ReceiveChannel, 89 | send: Send): 90 | """Invoke the user code.""" 91 | status, res_headers = vm.get_response_head() 92 | res_bin_headers = header_to_binary(res_headers) 93 | res_bin_headers.extend(X_RESTATE_SERVER) 94 | await send({ 95 | 'type': 'http.response.start', 96 | 'status': status, 97 | 'headers': res_bin_headers, 98 | 'trailers': False 99 | }) 100 | assert status == 200 101 | # ======================================== 102 | # Read the input and the journal 103 | # ======================================== 104 | while True: 105 | message = await receive() 106 | if message.get('type') == 'http.disconnect': 107 | # everything ends here really ... 108 | return 109 | if message.get('type') == 'http.request': 110 | body = message.get('body', None) 111 | assert isinstance(body, bytes) 112 | vm.notify_input(body) 113 | if not message.get('more_body', False): 114 | vm.notify_input_closed() 115 | break 116 | if vm.is_ready_to_execute(): 117 | break 118 | # ======================================== 119 | # Execute the user code 120 | # ======================================== 121 | invocation = vm.sys_input() 122 | context = ServerInvocationContext(vm=vm, 123 | handler=handler, 124 | invocation=invocation, 125 | attempt_headers=attempt_headers, 126 | send=send, 127 | receive=receive) 128 | try: 129 | await context.enter() 130 | except asyncio.exceptions.CancelledError: 131 | context.on_attempt_finished() 132 | raise 133 | except DisconnectedException: 134 | # The client disconnected before we could send the response 135 | context.on_attempt_finished() 136 | return 137 | # pylint: disable=W0718 138 | except Exception: 139 | traceback.print_exc() 140 | try: 141 | await context.leave() 142 | finally: 143 | context.on_attempt_finished() 144 | 145 | class LifeSpanNotImplemented(ValueError): 146 | """Signal to the asgi server that we didn't implement 
lifespans""" 147 | 148 | 149 | class ParsedPath(TypedDict): 150 | """Parsed path from the request.""" 151 | type: Literal["invocation", "health", "discover", "unknown"] 152 | service: str | None 153 | handler: str | None 154 | 155 | def parse_path(request: str) -> ParsedPath: 156 | """Parse the path from the request.""" 157 | # The following routes are possible 158 | # $mountpoint/health 159 | # $mountpoint/discover 160 | # $mountpoint/invoke/:service/:handler 161 | # as we don't know the mountpoint, we need to check the path carefully 162 | fragments = request.rsplit('/', 4) 163 | # /invoke/:service/:handler 164 | if len(fragments) >= 3 and fragments[-3] == 'invoke': 165 | return { "type": "invocation" , "handler" : fragments[-1], "service" : fragments[-2] } 166 | # /health 167 | if fragments[-1] == 'health': 168 | return { "type": "health", "service": None, "handler": None } 169 | # /discover 170 | if fragments[-1] == 'discover': 171 | return { "type": "discover" , "service": None, "handler": None } 172 | # anything other than invoke is 404 173 | return { "type": "unknown" , "service": None, "handler": None } 174 | 175 | 176 | def asgi_app(endpoint: Endpoint): 177 | """Create an ASGI-3 app for the given endpoint.""" 178 | 179 | # Prepare request signer 180 | identity_verifier = PyIdentityVerifier(endpoint.identity_keys) 181 | 182 | async def app(scope: Scope, receive: Receive, send: Send): 183 | try: 184 | if scope['type'] == 'lifespan': 185 | raise LifeSpanNotImplemented() 186 | if scope['type'] != 'http': 187 | raise NotImplementedError(f"Unknown scope type {scope['type']}") 188 | 189 | request_path = scope['path'] 190 | assert isinstance(request_path, str) 191 | request: ParsedPath = parse_path(request_path) 192 | 193 | # Health check 194 | if request['type'] == 'health': 195 | await send_health_check(send) 196 | return 197 | 198 | # Verify Identity 199 | assert not isinstance(scope['headers'], str) 200 | assert hasattr(scope['headers'], '__iter__') 201 | 
request_headers = binary_to_header(scope['headers']) 202 | try: 203 | identity_verifier.verify(request_headers, request_path) 204 | except IdentityVerificationException: 205 | # Identify verification failed, send back unauthorized and close 206 | await send_status(send, receive, 401) 207 | return 208 | 209 | # might be a discovery request 210 | if request['type'] == 'discover': 211 | await send_discovery(scope, send, endpoint) 212 | return 213 | # anything other than invoke is 404 214 | if request['type'] == 'unknown': 215 | await send404(send, receive) 216 | return 217 | assert request['type'] == 'invocation' 218 | assert request['service'] is not None 219 | assert request['handler'] is not None 220 | service_name, handler_name = request['service'], request['handler'] 221 | service = endpoint.services.get(service_name) 222 | if not service: 223 | await send404(send, receive) 224 | return 225 | handler = service.handlers.get(handler_name) 226 | if not handler: 227 | await send404(send, receive) 228 | return 229 | # 230 | # At this point we have a valid handler. 231 | # Let us setup restate's execution context for this invocation and handler. 
232 | # 233 | receive_channel = ReceiveChannel(receive) 234 | try: 235 | await process_invocation_to_completion(VMWrapper(request_headers), 236 | handler, 237 | dict(request_headers), 238 | receive_channel, 239 | send) 240 | finally: 241 | await receive_channel.close() 242 | except LifeSpanNotImplemented as e: 243 | raise e 244 | except Exception as e: 245 | traceback.print_exc() 246 | raise e 247 | 248 | if is_running_on_lambda(): 249 | # If we're on Lambda, just return the adapter 250 | return wrap_asgi_as_lambda_handler(app) 251 | 252 | return app 253 | -------------------------------------------------------------------------------- /python/restate/server_types.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """ 12 | This module contains the ASGI types definitions. 
13 | 14 | :see: https://github.com/django/asgiref/blob/main/asgiref/typing.py 15 | """ 16 | 17 | import asyncio 18 | from typing import (Awaitable, Callable, Dict, Iterable, List, 19 | Tuple, Union, TypedDict, Literal, Optional, NotRequired, Any) 20 | 21 | class ASGIVersions(TypedDict): 22 | """ASGI Versions""" 23 | spec_version: str 24 | version: Union[Literal["2.0"], Literal["3.0"]] 25 | 26 | class Scope(TypedDict): 27 | """ASGI Scope""" 28 | type: Literal["http"] 29 | asgi: ASGIVersions 30 | http_version: str 31 | method: str 32 | scheme: str 33 | path: str 34 | raw_path: bytes 35 | query_string: bytes 36 | root_path: str 37 | headers: Iterable[Tuple[bytes, bytes]] 38 | client: Optional[Tuple[str, int]] 39 | server: Optional[Tuple[str, Optional[int]]] 40 | state: NotRequired[Dict[str, Any]] 41 | extensions: Optional[Dict[str, Dict[object, object]]] 42 | 43 | class RestateEvent(TypedDict): 44 | """An event that represents a run completion""" 45 | type: Literal["restate.run_completed"] 46 | data: Optional[Dict[str, Any]] 47 | 48 | class HTTPRequestEvent(TypedDict): 49 | """ASGI Request event""" 50 | type: Literal["http.request"] 51 | body: bytes 52 | more_body: bool 53 | 54 | class HTTPResponseStartEvent(TypedDict): 55 | """ASGI Response start event""" 56 | type: Literal["http.response.start"] 57 | status: int 58 | headers: Iterable[Tuple[bytes, bytes]] 59 | trailers: bool 60 | 61 | class HTTPResponseBodyEvent(TypedDict): 62 | """ASGI Response body event""" 63 | type: Literal["http.response.body"] 64 | body: bytes 65 | more_body: bool 66 | 67 | 68 | ASGIReceiveEvent = HTTPRequestEvent 69 | 70 | 71 | ASGISendEvent = Union[ 72 | HTTPResponseStartEvent, 73 | HTTPResponseBodyEvent 74 | ] 75 | 76 | Receive = Callable[[], Awaitable[ASGIReceiveEvent]] 77 | Send = Callable[[ASGISendEvent], Awaitable[None]] 78 | 79 | ASGIApp = Callable[ 80 | [ 81 | Scope, 82 | Receive, 83 | Send, 84 | ], 85 | Awaitable[None], 86 | ] 87 | 88 | def header_to_binary(headers: 
Iterable[Tuple[str, str]]) -> List[Tuple[bytes, bytes]]: 89 | """Convert a list of headers to a list of binary headers.""" 90 | return [ (k.encode('utf-8'), v.encode('utf-8')) for k,v in headers ] 91 | 92 | def binary_to_header(headers: Iterable[Tuple[bytes, bytes]]) -> List[Tuple[str, str]]: 93 | """Convert a list of binary headers to a list of headers.""" 94 | return [ (k.decode('utf-8'), v.decode('utf-8')) for k,v in headers ] 95 | 96 | class ReceiveChannel: 97 | """ASGI receive channel.""" 98 | 99 | def __init__(self, receive: Receive) -> None: 100 | self._queue = asyncio.Queue[Union[ASGIReceiveEvent, RestateEvent]]() 101 | self._http_input_closed = asyncio.Event() 102 | self._disconnected = asyncio.Event() 103 | 104 | async def loop(): 105 | """Receive loop.""" 106 | while not self._disconnected.is_set(): 107 | event = await receive() 108 | if event.get('type') == 'http.request' and not event.get('more_body', False): 109 | self._http_input_closed.set() 110 | elif event.get('type') == 'http.disconnect': 111 | self._http_input_closed.set() 112 | self._disconnected.set() 113 | await self._queue.put(event) 114 | 115 | self._task = asyncio.create_task(loop()) 116 | 117 | async def __call__(self) -> ASGIReceiveEvent | RestateEvent: 118 | """Get the next message.""" 119 | what = await self._queue.get() 120 | self._queue.task_done() 121 | return what 122 | 123 | async def block_until_http_input_closed(self) -> None: 124 | """Wait until the HTTP input is closed""" 125 | await self._http_input_closed.wait() 126 | 127 | async def enqueue_restate_event(self, what: RestateEvent): 128 | """Add a message.""" 129 | await self._queue.put(what) 130 | 131 | async def close(self): 132 | """Close the channel.""" 133 | self._http_input_closed.set() 134 | self._disconnected.set() 135 | if self._task.done(): 136 | return 137 | self._task.cancel() 138 | try: 139 | await self._task 140 | except asyncio.CancelledError: 141 | pass 142 | 
-------------------------------------------------------------------------------- /python/restate/service.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | 12 | # pylint: disable=R0917 13 | """ 14 | This module defines the Service class for representing a restate service. 15 | """ 16 | 17 | from functools import wraps 18 | import inspect 19 | import typing 20 | 21 | from restate.serde import Serde, DefaultSerde 22 | from .handler import Handler, HandlerIO, ServiceTag, make_handler 23 | 24 | I = typing.TypeVar('I') 25 | O = typing.TypeVar('O') 26 | 27 | 28 | # disable too many arguments warning 29 | # pylint: disable=R0913 30 | 31 | # disable line too long warning 32 | # pylint: disable=C0301 33 | 34 | class Service: 35 | """ 36 | Represents a restate service. 37 | 38 | Args: 39 | name (str): The name of the service. 40 | """ 41 | 42 | def __init__(self, name: str, 43 | description: typing.Optional[str] = None, 44 | metadata: typing.Optional[typing.Dict[str, str]] = None) -> None: 45 | self.service_tag = ServiceTag("service", name, description, metadata) 46 | self.handlers: typing.Dict[str, Handler] = {} 47 | 48 | @property 49 | def name(self): 50 | """ 51 | Returns the name of the service. 
52 | """ 53 | return self.service_tag.name 54 | 55 | def handler(self, 56 | name: typing.Optional[str] = None, 57 | accept: str = "application/json", 58 | content_type: str = "application/json", 59 | input_serde: Serde[I] = DefaultSerde(), 60 | output_serde: Serde[O] = DefaultSerde(), 61 | metadata: typing.Optional[typing.Dict[str, str]] = None) -> typing.Callable: 62 | 63 | """ 64 | Decorator for defining a handler function. 65 | 66 | Args: 67 | name: The name of the handler. 68 | accept: The accept type of the request. Default "application/json". 69 | content_type: The content type of the request. Default "application/json". 70 | serializer: The serializer function to convert the response object to bytes. 71 | deserializer: The deserializer function to convert the request bytes to an object. 72 | 73 | Returns: 74 | Callable: The decorated function. 75 | 76 | Raises: 77 | ValueError: If the handler name is not provided. 78 | 79 | Example: 80 | @service.handler() 81 | def my_handler_func(ctx, request): 82 | # handler logic 83 | pass 84 | """ 85 | handler_io = HandlerIO[I,O](accept, content_type, input_serde, output_serde) 86 | def wrapper(fn): 87 | @wraps(fn) 88 | def wrapped(*args, **kwargs): 89 | return fn(*args, **kwargs) 90 | 91 | signature = inspect.signature(fn, eval_str=True) 92 | handler = make_handler(self.service_tag, handler_io, name, None, wrapped, signature, inspect.getdoc(fn), metadata) 93 | self.handlers[handler.name] = handler 94 | return wrapped 95 | 96 | return wrapper 97 | -------------------------------------------------------------------------------- /python/restate/vm.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 
6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """ 12 | wrap the restate._internal.PyVM class 13 | """ 14 | # pylint: disable=E1101,R0917 15 | # pylint: disable=too-many-arguments 16 | # pylint: disable=too-few-public-methods 17 | 18 | from dataclasses import dataclass 19 | import typing 20 | from restate._internal import PyVM, PyHeader, PyFailure, PySuspended, PyVoid, PyStateKeys, PyExponentialRetryConfig, PyDoProgressAnyCompleted, PyDoProgressReadFromInput, PyDoProgressExecuteRun, PyDoWaitForPendingRun, PyDoProgressCancelSignalReceived, CANCEL_NOTIFICATION_HANDLE # pylint: disable=import-error,no-name-in-module,line-too-long 21 | 22 | @dataclass 23 | class Invocation: 24 | """ 25 | Invocation dataclass 26 | """ 27 | invocation_id: str 28 | random_seed: int 29 | headers: typing.List[typing.Tuple[str, str]] 30 | input_buffer: bytes 31 | key: str 32 | 33 | @dataclass 34 | class RunRetryConfig: 35 | """ 36 | Expo Retry Configuration 37 | """ 38 | initial_interval: typing.Optional[int] = None 39 | max_attempts: typing.Optional[int] = None 40 | max_duration: typing.Optional[int] = None 41 | 42 | @dataclass 43 | class Failure: 44 | """ 45 | Failure 46 | """ 47 | code: int 48 | message: str 49 | 50 | @dataclass 51 | class NotReady: 52 | """ 53 | NotReady 54 | """ 55 | 56 | class SuspendedException(Exception): 57 | """ 58 | Suspended Exception 59 | """ 60 | def __init__(self, *args: object) -> None: 61 | super().__init__(*args) 62 | 63 | NOT_READY = NotReady() 64 | SUSPENDED = SuspendedException() 65 | CANCEL_HANDLE = CANCEL_NOTIFICATION_HANDLE 66 | 67 | NotificationType = typing.Optional[typing.Union[bytes, Failure, NotReady, list[str], str]] 68 | 69 | class DoProgressAnyCompleted: 70 | """ 71 | Represents a notification that any of the handles has completed. 
72 | """ 73 | 74 | class DoProgressReadFromInput: 75 | """ 76 | Represents a notification that the input needs to be read. 77 | """ 78 | 79 | class DoProgressExecuteRun: 80 | """ 81 | Represents a notification that a run needs to be executed. 82 | """ 83 | handle: int 84 | 85 | def __init__(self, handle): 86 | self.handle = handle 87 | 88 | class DoProgressCancelSignalReceived: 89 | """ 90 | Represents a notification that a cancel signal has been received 91 | """ 92 | 93 | class DoWaitPendingRun: 94 | """ 95 | Represents a notification that a run is pending 96 | """ 97 | 98 | DO_PROGRESS_ANY_COMPLETED = DoProgressAnyCompleted() 99 | DO_PROGRESS_READ_FROM_INPUT = DoProgressReadFromInput() 100 | DO_PROGRESS_CANCEL_SIGNAL_RECEIVED = DoProgressCancelSignalReceived() 101 | DO_WAIT_PENDING_RUN = DoWaitPendingRun() 102 | 103 | DoProgressResult = typing.Union[DoProgressAnyCompleted, 104 | DoProgressReadFromInput, 105 | DoProgressExecuteRun, 106 | DoProgressCancelSignalReceived, 107 | DoWaitPendingRun] 108 | 109 | 110 | # pylint: disable=too-many-public-methods 111 | class VMWrapper: 112 | """ 113 | A wrapper class for the restate_sdk._internal.PyVM class. 114 | It provides a type-friendly interface to our shared vm. 115 | """ 116 | 117 | def __init__(self, headers: typing.List[typing.Tuple[str, str]]): 118 | self.vm = PyVM(headers) 119 | 120 | def get_response_head(self) -> typing.Tuple[int, typing.Iterable[typing.Tuple[str, str]]]: 121 | """ 122 | Retrieves the response head from the virtual machine. 123 | 124 | Returns: 125 | A tuple containing the status code and a list of header tuples. 
126 | """ 127 | result = self.vm.get_response_head() 128 | return (result.status_code, result.headers) 129 | 130 | def notify_input(self, input_buf: bytes): 131 | """Send input to the virtual machine.""" 132 | self.vm.notify_input(input_buf) 133 | 134 | def notify_input_closed(self): 135 | """Notify the virtual machine that the input has been closed.""" 136 | self.vm.notify_input_closed() 137 | 138 | def notify_error(self, error: str, stacktrace: str): 139 | """Notify the virtual machine of an error.""" 140 | self.vm.notify_error(error, stacktrace) 141 | 142 | def take_output(self) -> typing.Optional[bytes]: 143 | """Take the output from the virtual machine.""" 144 | return self.vm.take_output() 145 | 146 | def is_ready_to_execute(self) -> bool: 147 | """Returns true when the VM is ready to operate.""" 148 | return self.vm.is_ready_to_execute() 149 | 150 | def is_completed(self, handle: int) -> bool: 151 | """Returns true when the notification handle is completed and hasn't been taken yet.""" 152 | return self.vm.is_completed(handle) 153 | 154 | def do_progress(self, handles: list[int]) -> DoProgressResult: 155 | """Do progress with notifications.""" 156 | result = self.vm.do_progress(handles) 157 | if isinstance(result, PySuspended): 158 | raise SUSPENDED 159 | if isinstance(result, PyDoProgressAnyCompleted): 160 | return DO_PROGRESS_ANY_COMPLETED 161 | if isinstance(result, PyDoProgressReadFromInput): 162 | return DO_PROGRESS_READ_FROM_INPUT 163 | if isinstance(result, PyDoProgressExecuteRun): 164 | return DoProgressExecuteRun(result.handle) 165 | if isinstance(result, PyDoProgressCancelSignalReceived): 166 | return DO_PROGRESS_CANCEL_SIGNAL_RECEIVED 167 | if isinstance(result, PyDoWaitForPendingRun): 168 | return DO_WAIT_PENDING_RUN 169 | raise ValueError(f"Unknown progress type: {result}") 170 | 171 | def take_notification(self, handle: int) -> NotificationType: 172 | """Take the result of an asynchronous operation.""" 173 | result = 
self.vm.take_notification(handle) 174 | if result is None: 175 | return NOT_READY 176 | if isinstance(result, PyVoid): 177 | # success with an empty value 178 | return None 179 | if isinstance(result, bytes): 180 | # success with a non empty value 181 | return result 182 | if isinstance(result, PyStateKeys): 183 | # success with state keys 184 | return result.keys 185 | if isinstance(result, str): 186 | # success with invocation id 187 | return result 188 | if isinstance(result, PyFailure): 189 | # a terminal failure 190 | code = result.code 191 | message = result.message 192 | return Failure(code, message) 193 | if isinstance(result, PySuspended): 194 | # the state machine had suspended 195 | raise SUSPENDED 196 | raise ValueError(f"Unknown result type: {result}") 197 | 198 | def sys_input(self) -> Invocation: 199 | """ 200 | Retrieves the system input from the virtual machine. 201 | 202 | Returns: 203 | An instance of the Invocation class containing the system input. 204 | """ 205 | inp = self.vm.sys_input() 206 | invocation_id: str = inp.invocation_id 207 | random_seed: int = inp.random_seed 208 | headers: typing.List[typing.Tuple[str, str]] = [(h.key, h.value) for h in inp.headers] 209 | input_buffer: bytes = bytes(inp.input) 210 | key: str = inp.key 211 | 212 | return Invocation( 213 | invocation_id=invocation_id, 214 | random_seed=random_seed, 215 | headers=headers, 216 | input_buffer=input_buffer, 217 | key=key) 218 | 219 | def sys_write_output_success(self, output: bytes): 220 | """ 221 | Writes the output to the system. 222 | 223 | Args: 224 | output: The output to be written. It can be either a bytes or a Failure object. 225 | 226 | Returns: 227 | None 228 | """ 229 | self.vm.sys_write_output_success(output) 230 | 231 | def sys_write_output_failure(self, output: Failure): 232 | """ 233 | Writes the output to the system. 234 | 235 | Args: 236 | output: The output to be written. It can be either a bytes or a Failure object. 
237 | 238 | Returns: 239 | None 240 | """ 241 | res = PyFailure(output.code, output.message) 242 | self.vm.sys_write_output_failure(res) 243 | 244 | 245 | def sys_get_state(self, name) -> int: 246 | """ 247 | Retrieves a key-value binding. 248 | 249 | Args: 250 | name: The name of the value to be retrieved. 251 | 252 | Returns: 253 | The value associated with the given name. 254 | """ 255 | return self.vm.sys_get_state(name) 256 | 257 | 258 | def sys_get_state_keys(self) -> int: 259 | """ 260 | Retrieves all keys. 261 | 262 | Returns: 263 | The state keys 264 | """ 265 | return self.vm.sys_get_state_keys() 266 | 267 | 268 | def sys_set_state(self, name: str, value: bytes): 269 | """ 270 | Sets a key-value binding. 271 | 272 | Args: 273 | name: The name of the value to be set. 274 | value: The value to be set. 275 | 276 | Returns: 277 | None 278 | """ 279 | self.vm.sys_set_state(name, value) 280 | 281 | def sys_clear_state(self, name: str): 282 | """Clear the state associated with the given name.""" 283 | self.vm.sys_clear_state(name) 284 | 285 | def sys_clear_all_state(self): 286 | """Clear the state associated with the given name.""" 287 | self.vm.sys_clear_all_state() 288 | 289 | def sys_sleep(self, millis: int): 290 | """Ask to sleep for a given duration""" 291 | return self.vm.sys_sleep(millis) 292 | 293 | def sys_call(self, 294 | service: str, 295 | handler: str, 296 | parameter: bytes, 297 | key: typing.Optional[str] = None, 298 | idempotency_key: typing.Optional[str] = None, 299 | headers: typing.Optional[typing.List[typing.Tuple[str, str]]] = None 300 | ): 301 | """Call a service""" 302 | if headers: 303 | headers = [PyHeader(key=h[0], value=h[1]) for h in headers] 304 | return self.vm.sys_call(service, handler, parameter, key, idempotency_key, headers) 305 | 306 | # pylint: disable=too-many-arguments 307 | def sys_send(self, 308 | service: str, 309 | handler: str, 310 | parameter: bytes, 311 | key: typing.Optional[str] = None, 312 | delay: 
typing.Optional[int] = None, 313 | idempotency_key: typing.Optional[str] = None, 314 | headers: typing.Optional[typing.List[typing.Tuple[str, str]]] = None 315 | ) -> int: 316 | """ 317 | send an invocation to a service, and return the handle 318 | to the promise that will resolve with the invocation id 319 | """ 320 | if headers: 321 | headers = [PyHeader(key=h[0], value=h[1]) for h in headers] 322 | return self.vm.sys_send(service, handler, parameter, key, delay, idempotency_key, headers) 323 | 324 | def sys_run(self, name: str) -> int: 325 | """ 326 | Register a run 327 | """ 328 | return self.vm.sys_run(name) 329 | 330 | def sys_awakeable(self) -> typing.Tuple[str, int]: 331 | """ 332 | Return a fresh awaitable 333 | """ 334 | return self.vm.sys_awakeable() 335 | 336 | def sys_resolve_awakeable(self, name: str, value: bytes): 337 | """ 338 | Resolve 339 | """ 340 | self.vm.sys_complete_awakeable_success(name, value) 341 | 342 | def sys_reject_awakeable(self, name: str, failure: Failure): 343 | """ 344 | Reject 345 | """ 346 | py_failure = PyFailure(failure.code, failure.message) 347 | self.vm.sys_complete_awakeable_failure(name, py_failure) 348 | 349 | def propose_run_completion_success(self, handle: int, output: bytes) -> int: 350 | """ 351 | Exit a side effect 352 | 353 | Args: 354 | output: The output of the side effect. 
355 | 356 | Returns: 357 | handle 358 | """ 359 | return self.vm.propose_run_completion_success(handle, output) 360 | 361 | def sys_get_promise(self, name: str) -> int: 362 | """Returns the promise handle""" 363 | return self.vm.sys_get_promise(name) 364 | 365 | def sys_peek_promise(self, name: str) -> int: 366 | """Peek into the workflow promise""" 367 | return self.vm.sys_peek_promise(name) 368 | 369 | def sys_complete_promise_success(self, name: str, value: bytes) -> int: 370 | """Complete the promise""" 371 | return self.vm.sys_complete_promise_success(name, value) 372 | 373 | def sys_complete_promise_failure(self, name: str, failure: Failure) -> int: 374 | """reject the promise on failure""" 375 | res = PyFailure(failure.code, failure.message) 376 | return self.vm.sys_complete_promise_failure(name, res) 377 | 378 | def propose_run_completion_failure(self, handle: int, output: Failure) -> int: 379 | """ 380 | Exit a side effect 381 | 382 | Args: 383 | name: The name of the side effect. 384 | output: The output of the side effect. 385 | """ 386 | res = PyFailure(output.code, output.message) 387 | return self.vm.propose_run_completion_failure(handle, res) 388 | 389 | # pylint: disable=line-too-long 390 | def propose_run_completion_transient(self, handle: int, failure: Failure, attempt_duration_ms: int, config: RunRetryConfig) -> int | None: 391 | """ 392 | Exit a side effect with a transient Error. 393 | This requires a retry policy to be provided. 394 | """ 395 | py_failure = PyFailure(failure.code, failure.message) 396 | py_config = PyExponentialRetryConfig(config.initial_interval, config.max_attempts, config.max_duration) 397 | try: 398 | handle = self.vm.propose_run_completion_failure_transient(handle, py_failure, attempt_duration_ms, py_config) 399 | # The VM decided not to retry, therefore we get back an handle that will be resolved 400 | # with a terminal failure. 
401 | return handle 402 | # pylint: disable=bare-except 403 | except: 404 | # The VM decided to retry, therefore we tear down the current execution 405 | return None 406 | 407 | def sys_end(self): 408 | """ 409 | This method is responsible for ending the system. 410 | 411 | It calls the `sys_end` method of the `vm` object. 412 | """ 413 | self.vm.sys_end() 414 | 415 | def sys_cancel(self, invocation_id: str): 416 | """ 417 | Cancel a running invocation 418 | """ 419 | self.vm.sys_cancel(invocation_id) 420 | 421 | def attach_invocation(self, invocation_id: str) -> int: 422 | """ 423 | Attach to an invocation 424 | """ 425 | return self.vm.attach_invocation(invocation_id) 426 | -------------------------------------------------------------------------------- /python/restate/workflow.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | 12 | # pylint: disable=R0917 13 | """ 14 | This module defines the Service class for representing a restate service. 15 | """ 16 | 17 | from functools import wraps 18 | import inspect 19 | import typing 20 | 21 | from restate.serde import DefaultSerde, Serde 22 | from restate.handler import Handler, HandlerIO, ServiceTag, make_handler 23 | 24 | I = typing.TypeVar('I') 25 | O = typing.TypeVar('O') 26 | 27 | 28 | # disable too many arguments warning 29 | # pylint: disable=R0913 30 | 31 | # disable line too long warning 32 | # pylint: disable=C0301 33 | 34 | # disable similar lines warning 35 | # pylint: disable=R0801 36 | 37 | class Workflow: 38 | """ 39 | Represents a restate workflow. 
40 | 41 | Args: 42 | name (str): The name of the object. 43 | """ 44 | 45 | handlers: typing.Dict[str, Handler[typing.Any, typing.Any]] 46 | 47 | def __init__(self, name, description: typing.Optional[str] = None, metadata: typing.Optional[typing.Dict[str, str]] = None): 48 | self.service_tag = ServiceTag("workflow", name, description, metadata) 49 | self.handlers = {} 50 | 51 | @property 52 | def name(self): 53 | """ 54 | Returns the name of the object. 55 | """ 56 | return self.service_tag.name 57 | 58 | def main(self, 59 | name: typing.Optional[str] = None, 60 | accept: str = "application/json", 61 | content_type: str = "application/json", 62 | input_serde: Serde[I] = DefaultSerde[I](), # type: ignore 63 | output_serde: Serde[O] = DefaultSerde[O](), # type: ignore 64 | metadata: typing.Optional[typing.Dict[str, str]] = None) -> typing.Callable: # type: ignore 65 | """Mark this handler as a workflow entry point""" 66 | return self._add_handler(name, 67 | kind="workflow", 68 | accept=accept, 69 | content_type=content_type, 70 | input_serde=input_serde, 71 | output_serde=output_serde, 72 | metadata=metadata) 73 | 74 | def handler(self, 75 | name: typing.Optional[str] = None, 76 | accept: str = "application/json", 77 | content_type: str = "application/json", 78 | input_serde: Serde[I] = DefaultSerde[I](), # type: ignore 79 | output_serde: Serde[O] = DefaultSerde[O](), # type: ignore 80 | metadata: typing.Optional[typing.Dict[str, str]] = None) -> typing.Callable: 81 | """ 82 | Decorator for defining a handler function. 
83 | """ 84 | return self._add_handler(name, "shared", accept, content_type, input_serde, output_serde, metadata) 85 | 86 | def _add_handler(self, 87 | name: typing.Optional[str] = None, 88 | kind: typing.Literal["workflow", "shared", "exclusive"] = "shared", 89 | accept: str = "application/json", 90 | content_type: str = "application/json", 91 | input_serde: Serde[I] = DefaultSerde[I](), # type: ignore 92 | output_serde: Serde[O] = DefaultSerde[O](), # type: ignore 93 | metadata: typing.Optional[typing.Dict[str, str]] = None) -> typing.Callable: # type: ignore 94 | """ 95 | Decorator for defining a handler function. 96 | 97 | Args: 98 | name: The name of the handler. 99 | accept: The accept type of the request. Default "application/json". 100 | content_type: The content type of the request. Default "application/json". 101 | serializer: The serializer function to convert the response object to bytes. 102 | deserializer: The deserializer function to convert the request bytes to an object. 103 | metadata: An optional dictionary of metadata. 104 | 105 | Returns: 106 | Callable: The decorated function. 107 | 108 | Raises: 109 | ValueError: If the handler name is not provided. 
110 | 111 | Example: 112 | @service.handler() 113 | def my_handler_func(ctx, request): 114 | # handler logic 115 | pass 116 | """ 117 | handler_io = HandlerIO[I,O](accept, content_type, input_serde, output_serde) 118 | def wrapper(fn): 119 | 120 | @wraps(fn) 121 | def wrapped(*args, **kwargs): 122 | return fn(*args, **kwargs) 123 | 124 | signature = inspect.signature(fn, eval_str=True) 125 | description = inspect.getdoc(fn) 126 | handler = make_handler(service_tag=self.service_tag, 127 | handler_io=handler_io, 128 | name=name, 129 | kind=kind, 130 | wrapped=wrapped, 131 | signature=signature, 132 | description=description, 133 | metadata=metadata) 134 | self.handlers[handler.name] = handler 135 | return wrapped 136 | 137 | return wrapper 138 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mypy 2 | pylint 3 | hypercorn 4 | maturin 5 | pytest 6 | pydantic 7 | httpx 8 | testcontainers 9 | -------------------------------------------------------------------------------- /rust-toolchain.toml: -------------------------------------------------------------------------------- 1 | [toolchain] 2 | channel = "stable" 3 | profile = "minimal" 4 | components = ["rustfmt", "clippy"] 5 | -------------------------------------------------------------------------------- /shell.nix: -------------------------------------------------------------------------------- 1 | { pkgs ? 
import {} }: 2 | 3 | (pkgs.buildFHSEnv { 4 | name = "sdk-python"; 5 | targetPkgs = pkgs: (with pkgs; [ 6 | python3 7 | python3Packages.pip 8 | python3Packages.virtualenv 9 | just 10 | 11 | # rust 12 | rustup 13 | cargo 14 | clang 15 | llvmPackages.bintools 16 | protobuf 17 | cmake 18 | liburing 19 | pkg-config 20 | ]); 21 | 22 | RUSTC_VERSION = 23 | builtins.elemAt 24 | (builtins.match 25 | ".*channel *= *\"([^\"]*)\".*" 26 | (pkgs.lib.readFile ./rust-toolchain.toml) 27 | ) 28 | 0; 29 | 30 | LIBCLANG_PATH = pkgs.lib.makeLibraryPath [ pkgs.llvmPackages_latest.libclang.lib ]; 31 | 32 | runScript = '' 33 | bash 34 | ''; 35 | }).env 36 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | use pyo3::create_exception; 2 | use pyo3::prelude::*; 3 | use pyo3::types::{PyBytes, PyNone, PyString}; 4 | use restate_sdk_shared_core::{ 5 | CallHandle, CoreVM, DoProgressResponse, Error, Header, IdentityVerifier, Input, NonEmptyValue, 6 | NotificationHandle, ResponseHead, RetryPolicy, RunExitResult, SuspendedOrVMError, 7 | TakeOutputResult, Target, TerminalFailure, VMOptions, Value, CANCEL_NOTIFICATION_HANDLE, VM, 8 | }; 9 | use std::time::{Duration, SystemTime}; 10 | 11 | // Current crate version 12 | const CURRENT_VERSION: &str = env!("CARGO_PKG_VERSION"); 13 | 14 | // Data model 15 | 16 | #[pyclass] 17 | #[derive(Clone)] 18 | struct PyHeader { 19 | #[pyo3(get, set)] 20 | key: String, 21 | #[pyo3(get, set)] 22 | value: String, 23 | } 24 | 25 | #[pymethods] 26 | impl PyHeader { 27 | #[new] 28 | fn new(key: String, value: String) -> PyHeader { 29 | Self { key, value } 30 | } 31 | } 32 | 33 | impl From
for PyHeader { 34 | fn from(h: Header) -> Self { 35 | PyHeader { 36 | key: h.key.into(), 37 | value: h.value.into(), 38 | } 39 | } 40 | } 41 | 42 | impl From for Header { 43 | fn from(h: PyHeader) -> Self { 44 | Header { 45 | key: h.key.into(), 46 | value: h.value.into(), 47 | } 48 | } 49 | } 50 | 51 | #[pyclass] 52 | struct PyResponseHead { 53 | #[pyo3(get, set)] 54 | status_code: u16, 55 | #[pyo3(get, set)] 56 | headers: Vec<(String, String)>, 57 | } 58 | 59 | impl From for PyResponseHead { 60 | fn from(value: ResponseHead) -> Self { 61 | PyResponseHead { 62 | status_code: value.status_code, 63 | headers: value 64 | .headers 65 | .into_iter() 66 | .map(|Header { key, value }| (key.into(), value.into())) 67 | .collect(), 68 | } 69 | } 70 | } 71 | 72 | fn take_output_result_into_py( 73 | py: Python, 74 | take_output_result: TakeOutputResult, 75 | ) -> Bound<'_, PyAny> { 76 | match take_output_result { 77 | TakeOutputResult::Buffer(b) => PyBytes::new_bound(py, &b).into_any(), 78 | TakeOutputResult::EOF => PyNone::get_bound(py).to_owned().into_any(), 79 | } 80 | } 81 | 82 | type PyNotificationHandle = u32; 83 | 84 | #[pyclass] 85 | struct PyVoid; 86 | 87 | #[pyclass] 88 | struct PySuspended; 89 | 90 | #[pyclass] 91 | #[derive(Clone)] 92 | struct PyFailure { 93 | #[pyo3(get, set)] 94 | code: u16, 95 | #[pyo3(get, set)] 96 | message: String, 97 | } 98 | 99 | #[pymethods] 100 | impl PyFailure { 101 | #[new] 102 | fn new(code: u16, message: String) -> PyFailure { 103 | Self { code, message } 104 | } 105 | } 106 | 107 | #[pyclass] 108 | #[derive(Clone)] 109 | struct PyExponentialRetryConfig { 110 | #[pyo3(get, set)] 111 | initial_interval: Option, 112 | #[pyo3(get, set)] 113 | max_attempts: Option, 114 | #[pyo3(get, set)] 115 | max_duration: Option, 116 | } 117 | 118 | #[pymethods] 119 | impl PyExponentialRetryConfig { 120 | #[pyo3(signature = (initial_interval=None, max_attempts=None, max_duration=None))] 121 | #[new] 122 | fn new( 123 | initial_interval: Option, 124 | 
max_attempts: Option, 125 | max_duration: Option, 126 | ) -> Self { 127 | Self { 128 | initial_interval, 129 | max_attempts, 130 | max_duration, 131 | } 132 | } 133 | } 134 | 135 | impl From for RetryPolicy { 136 | fn from(value: PyExponentialRetryConfig) -> Self { 137 | RetryPolicy::Exponential { 138 | initial_interval: Duration::from_millis(value.initial_interval.unwrap_or(10)), 139 | max_attempts: value.max_attempts, 140 | max_duration: value.max_duration.map(Duration::from_millis), 141 | factor: 2.0, 142 | max_interval: None, 143 | } 144 | } 145 | } 146 | 147 | impl From for PyFailure { 148 | fn from(value: TerminalFailure) -> Self { 149 | PyFailure { 150 | code: value.code, 151 | message: value.message, 152 | } 153 | } 154 | } 155 | 156 | impl From for TerminalFailure { 157 | fn from(value: PyFailure) -> Self { 158 | TerminalFailure { 159 | code: value.code, 160 | message: value.message, 161 | } 162 | } 163 | } 164 | 165 | impl From for Error { 166 | fn from(value: PyFailure) -> Self { 167 | Self::new(value.code, value.message) 168 | } 169 | } 170 | 171 | #[pyclass] 172 | #[derive(Clone)] 173 | struct PyStateKeys { 174 | #[pyo3(get, set)] 175 | keys: Vec, 176 | } 177 | 178 | #[pyclass] 179 | pub struct PyInput { 180 | #[pyo3(get, set)] 181 | invocation_id: String, 182 | #[pyo3(get, set)] 183 | random_seed: u64, 184 | #[pyo3(get, set)] 185 | key: String, 186 | #[pyo3(get, set)] 187 | headers: Vec, 188 | #[pyo3(get, set)] 189 | input: Vec, 190 | } 191 | 192 | impl From for PyInput { 193 | fn from(value: Input) -> Self { 194 | PyInput { 195 | invocation_id: value.invocation_id, 196 | random_seed: value.random_seed, 197 | key: value.key, 198 | headers: value.headers.into_iter().map(Into::into).collect(), 199 | input: value.input.into(), 200 | } 201 | } 202 | } 203 | 204 | #[pyclass] 205 | struct PyDoProgressReadFromInput; 206 | 207 | #[pyclass] 208 | struct PyDoProgressAnyCompleted; 209 | 210 | #[pyclass] 211 | struct PyDoProgressExecuteRun { 212 | #[pyo3(get)] 
213 | handle: PyNotificationHandle, 214 | } 215 | 216 | #[pyclass] 217 | struct PyDoProgressCancelSignalReceived; 218 | 219 | #[pyclass] 220 | struct PyDoWaitForPendingRun; 221 | 222 | #[pyclass] 223 | pub struct PyCallHandle { 224 | #[pyo3(get)] 225 | invocation_id_handle: PyNotificationHandle, 226 | #[pyo3(get)] 227 | result_handle: PyNotificationHandle, 228 | } 229 | 230 | impl From for PyCallHandle { 231 | fn from(value: CallHandle) -> Self { 232 | PyCallHandle { 233 | invocation_id_handle: value.invocation_id_notification_handle.into(), 234 | result_handle: value.call_notification_handle.into(), 235 | } 236 | } 237 | } 238 | 239 | // Errors and Exceptions 240 | 241 | #[derive(Debug)] 242 | struct PyVMError(Error); 243 | 244 | // Python representation of restate_sdk_shared_core::Error 245 | create_exception!( 246 | restate_sdk_python_core, 247 | VMException, 248 | pyo3::exceptions::PyException, 249 | "Restate VM exception." 250 | ); 251 | 252 | impl From for PyErr { 253 | fn from(value: PyVMError) -> Self { 254 | VMException::new_err(value.0.to_string()) 255 | } 256 | } 257 | 258 | impl From for PyVMError { 259 | fn from(value: Error) -> Self { 260 | PyVMError(value) 261 | } 262 | } 263 | 264 | // VM implementation 265 | 266 | #[pyclass] 267 | struct PyVM { 268 | vm: CoreVM, 269 | } 270 | 271 | #[pymethods] 272 | impl PyVM { 273 | #[new] 274 | fn new(headers: Vec<(String, String)>) -> Result { 275 | Ok(Self { 276 | vm: CoreVM::new(headers, VMOptions::default())?, 277 | }) 278 | } 279 | 280 | fn get_response_head(self_: PyRef<'_, Self>) -> PyResponseHead { 281 | self_.vm.get_response_head().into() 282 | } 283 | 284 | // Notifications 285 | 286 | fn notify_input(mut self_: PyRefMut<'_, Self>, buffer: &Bound<'_, PyBytes>) { 287 | let buf = buffer.as_bytes().to_vec().into(); 288 | self_.vm.notify_input(buf); 289 | } 290 | 291 | fn notify_input_closed(mut self_: PyRefMut<'_, Self>) { 292 | self_.vm.notify_input_closed(); 293 | } 294 | 295 | #[pyo3(signature = 
(error, stacktrace=None))] 296 | fn notify_error(mut self_: PyRefMut<'_, Self>, error: String, stacktrace: Option) { 297 | let mut error = Error::new(restate_sdk_shared_core::error::codes::INTERNAL, error); 298 | if let Some(desc) = stacktrace { 299 | error = error.with_stacktrace(desc); 300 | } 301 | CoreVM::notify_error(&mut self_.vm, error, None); 302 | } 303 | 304 | // Take(s) 305 | 306 | /// Returns either bytes or None, indicating EOF 307 | fn take_output(mut self_: PyRefMut<'_, Self>) -> Bound<'_, PyAny> { 308 | take_output_result_into_py(self_.py(), self_.vm.take_output()) 309 | } 310 | 311 | fn is_ready_to_execute(self_: PyRef<'_, Self>) -> Result { 312 | self_.vm.is_ready_to_execute().map_err(Into::into) 313 | } 314 | 315 | fn is_completed(self_: PyRef<'_, Self>, handle: PyNotificationHandle) -> bool { 316 | self_.vm.is_completed(handle.into()) 317 | } 318 | 319 | fn do_progress( 320 | mut self_: PyRefMut<'_, Self>, 321 | any_handle: Vec, 322 | ) -> Result, PyVMError> { 323 | let res = self_.vm.do_progress( 324 | any_handle 325 | .into_iter() 326 | .map(NotificationHandle::from) 327 | .collect(), 328 | ); 329 | 330 | let py = self_.py(); 331 | 332 | match res { 333 | Err(SuspendedOrVMError::VM(e)) => Err(e.into()), 334 | Err(SuspendedOrVMError::Suspended(_)) => { 335 | Ok(PySuspended.into_py(py).into_bound(py).into_any()) 336 | } 337 | Ok(DoProgressResponse::AnyCompleted) => Ok(PyDoProgressAnyCompleted 338 | .into_py(py) 339 | .into_bound(py) 340 | .into_any()), 341 | Ok(DoProgressResponse::ReadFromInput) => Ok(PyDoProgressReadFromInput 342 | .into_py(py) 343 | .into_bound(py) 344 | .into_any()), 345 | Ok(DoProgressResponse::ExecuteRun(handle)) => Ok(PyDoProgressExecuteRun { 346 | handle: handle.into(), 347 | } 348 | .into_py(py) 349 | .into_bound(py) 350 | .into_any()), 351 | Ok(DoProgressResponse::CancelSignalReceived) => Ok(PyDoProgressCancelSignalReceived 352 | .into_py(py) 353 | .into_bound(py) 354 | .into_any()), 355 | 
Ok(DoProgressResponse::WaitingPendingRun) => Ok(PyDoWaitForPendingRun 356 | .into_py(py) 357 | .into_bound(py) 358 | .into_any()), 359 | } 360 | } 361 | 362 | /// Returns either: 363 | /// 364 | /// * `PyBytes` in case the async result holds success value 365 | /// * `PyFailure` in case the async result holds failure value 366 | /// * `PyVoid` in case the async result holds Void value 367 | /// * `PyStateKeys` in case the async result holds StateKeys 368 | /// * `PyString` in case the async result holds invocation id 369 | /// * `PySuspended` in case the state machine is suspended 370 | /// * `None` in case the async result is not yet present 371 | fn take_notification( 372 | mut self_: PyRefMut<'_, Self>, 373 | handle: PyNotificationHandle, 374 | ) -> Result, PyVMError> { 375 | let res = self_.vm.take_notification(NotificationHandle::from(handle)); 376 | 377 | let py = self_.py(); 378 | 379 | match res { 380 | Err(SuspendedOrVMError::VM(e)) => Err(e.into()), 381 | Err(SuspendedOrVMError::Suspended(_)) => { 382 | Ok(PySuspended.into_py(py).into_bound(py).into_any()) 383 | } 384 | Ok(None) => Ok(PyNone::get_bound(py).to_owned().into_any()), 385 | Ok(Some(Value::Void)) => Ok(PyVoid.into_py(py).into_bound(py).into_any()), 386 | Ok(Some(Value::Success(b))) => Ok(PyBytes::new_bound(py, &b).into_any()), 387 | Ok(Some(Value::Failure(f))) => { 388 | Ok(PyFailure::from(f).into_py(py).into_bound(py).into_any()) 389 | } 390 | Ok(Some(Value::StateKeys(keys))) => { 391 | Ok(PyStateKeys { keys }.into_py(py).into_bound(py).into_any()) 392 | } 393 | Ok(Some(Value::InvocationId(invocation_id))) => { 394 | Ok(PyString::new_bound(py, &invocation_id).into_any()) 395 | } 396 | } 397 | } 398 | 399 | // Syscall(s) 400 | 401 | fn sys_input(mut self_: PyRefMut<'_, Self>) -> Result { 402 | self_.vm.sys_input().map(Into::into).map_err(Into::into) 403 | } 404 | 405 | fn sys_get_state( 406 | mut self_: PyRefMut<'_, Self>, 407 | key: String, 408 | ) -> Result { 409 | self_ 410 | .vm 411 | 
.sys_state_get(key) 412 | .map(Into::into) 413 | .map_err(Into::into) 414 | } 415 | 416 | fn sys_get_state_keys( 417 | mut self_: PyRefMut<'_, Self>, 418 | ) -> Result { 419 | self_ 420 | .vm 421 | .sys_state_get_keys() 422 | .map(Into::into) 423 | .map_err(Into::into) 424 | } 425 | 426 | fn sys_set_state( 427 | mut self_: PyRefMut<'_, Self>, 428 | key: String, 429 | buffer: &Bound<'_, PyBytes>, 430 | ) -> Result<(), PyVMError> { 431 | self_ 432 | .vm 433 | .sys_state_set(key, buffer.as_bytes().to_vec().into()) 434 | .map_err(Into::into) 435 | } 436 | 437 | fn sys_clear_state(mut self_: PyRefMut<'_, Self>, key: String) -> Result<(), PyVMError> { 438 | self_.vm.sys_state_clear(key).map_err(Into::into) 439 | } 440 | 441 | fn sys_clear_all_state(mut self_: PyRefMut<'_, Self>) -> Result<(), PyVMError> { 442 | self_.vm.sys_state_clear_all().map_err(Into::into) 443 | } 444 | 445 | fn sys_sleep( 446 | mut self_: PyRefMut<'_, Self>, 447 | millis: u64, 448 | ) -> Result { 449 | let now = SystemTime::now() 450 | .duration_since(SystemTime::UNIX_EPOCH) 451 | .expect("Duration since unix epoch cannot fail"); 452 | self_ 453 | .vm 454 | .sys_sleep(String::default(), now + Duration::from_millis(millis), Some(now)) 455 | .map(Into::into) 456 | .map_err(Into::into) 457 | } 458 | 459 | #[pyo3(signature = (service, handler, buffer, key=None, idempotency_key=None, headers=None))] 460 | fn sys_call( 461 | mut self_: PyRefMut<'_, Self>, 462 | service: String, 463 | handler: String, 464 | buffer: &Bound<'_, PyBytes>, 465 | key: Option, 466 | idempotency_key: Option, 467 | headers: Option>, 468 | ) -> Result { 469 | self_ 470 | .vm 471 | .sys_call( 472 | Target { 473 | service, 474 | handler, 475 | key, 476 | idempotency_key, 477 | headers: headers 478 | .unwrap_or_default() 479 | .into_iter() 480 | .map(Into::into) 481 | .collect(), 482 | }, 483 | buffer.as_bytes().to_vec().into(), 484 | ) 485 | .map(Into::into) 486 | .map_err(Into::into) 487 | } 488 | 489 | #[pyo3(signature = (service, 
handler, buffer, key=None, delay=None, idempotency_key=None, headers=None))] 490 | fn sys_send( 491 | mut self_: PyRefMut<'_, Self>, 492 | service: String, 493 | handler: String, 494 | buffer: &Bound<'_, PyBytes>, 495 | key: Option, 496 | delay: Option, 497 | idempotency_key: Option, 498 | headers: Option>, 499 | ) -> Result { 500 | self_ 501 | .vm 502 | .sys_send( 503 | Target { 504 | service, 505 | handler, 506 | key, 507 | idempotency_key, 508 | headers: headers 509 | .unwrap_or_default() 510 | .into_iter() 511 | .map(Into::into) 512 | .collect(), 513 | }, 514 | buffer.as_bytes().to_vec().into(), 515 | delay.map(|millis| { 516 | SystemTime::now() 517 | .duration_since(SystemTime::UNIX_EPOCH) 518 | .expect("Duration since unix epoch cannot fail") 519 | + Duration::from_millis(millis) 520 | }), 521 | ) 522 | .map(|s| s.invocation_id_notification_handle.into()) 523 | .map_err(Into::into) 524 | } 525 | 526 | fn sys_awakeable( 527 | mut self_: PyRefMut<'_, Self>, 528 | ) -> Result<(String, PyNotificationHandle), PyVMError> { 529 | self_ 530 | .vm 531 | .sys_awakeable() 532 | .map(|(id, handle)| (id, handle.into())) 533 | .map_err(Into::into) 534 | } 535 | 536 | fn sys_complete_awakeable_success( 537 | mut self_: PyRefMut<'_, Self>, 538 | id: String, 539 | buffer: &Bound<'_, PyBytes>, 540 | ) -> Result<(), PyVMError> { 541 | self_ 542 | .vm 543 | .sys_complete_awakeable( 544 | id, 545 | NonEmptyValue::Success(buffer.as_bytes().to_vec().into()), 546 | ) 547 | .map_err(Into::into) 548 | } 549 | 550 | fn sys_complete_awakeable_failure( 551 | mut self_: PyRefMut<'_, Self>, 552 | id: String, 553 | value: PyFailure, 554 | ) -> Result<(), PyVMError> { 555 | self_ 556 | .vm 557 | .sys_complete_awakeable(id, NonEmptyValue::Failure(value.into())) 558 | .map_err(Into::into) 559 | } 560 | 561 | fn sys_get_promise( 562 | mut self_: PyRefMut<'_, Self>, 563 | key: String, 564 | ) -> Result { 565 | self_ 566 | .vm 567 | .sys_get_promise(key) 568 | .map(Into::into) 569 | 
.map_err(Into::into) 570 | } 571 | 572 | fn sys_peek_promise( 573 | mut self_: PyRefMut<'_, Self>, 574 | key: String, 575 | ) -> Result { 576 | self_ 577 | .vm 578 | .sys_peek_promise(key) 579 | .map(Into::into) 580 | .map_err(Into::into) 581 | } 582 | 583 | fn sys_complete_promise_success( 584 | mut self_: PyRefMut<'_, Self>, 585 | key: String, 586 | buffer: &Bound<'_, PyBytes>, 587 | ) -> Result { 588 | self_ 589 | .vm 590 | .sys_complete_promise( 591 | key, 592 | NonEmptyValue::Success(buffer.as_bytes().to_vec().into()), 593 | ) 594 | .map(Into::into) 595 | .map_err(Into::into) 596 | } 597 | 598 | fn sys_complete_promise_failure( 599 | mut self_: PyRefMut<'_, Self>, 600 | key: String, 601 | value: PyFailure, 602 | ) -> Result { 603 | self_ 604 | .vm 605 | .sys_complete_promise(key, NonEmptyValue::Failure(value.into())) 606 | .map(Into::into) 607 | .map_err(Into::into) 608 | } 609 | 610 | /// Returns the associated `PyNotificationHandle`. 611 | fn sys_run( 612 | mut self_: PyRefMut<'_, Self>, 613 | name: String, 614 | ) -> Result { 615 | self_.vm.sys_run(name).map(Into::into).map_err(Into::into) 616 | } 617 | 618 | fn sys_cancel( 619 | mut self_: PyRefMut<'_, Self>, 620 | invocation_id: String, 621 | ) -> Result<(), PyVMError> { 622 | self_.vm.sys_cancel_invocation(invocation_id).map_err(Into::into) 623 | } 624 | 625 | fn propose_run_completion_success( 626 | mut self_: PyRefMut<'_, Self>, 627 | handle: PyNotificationHandle, 628 | buffer: &Bound<'_, PyBytes>, 629 | ) -> Result<(), PyVMError> { 630 | CoreVM::propose_run_completion( 631 | &mut self_.vm, 632 | handle.into(), 633 | RunExitResult::Success(buffer.as_bytes().to_vec().into()), 634 | RetryPolicy::None, 635 | ) 636 | .map_err(Into::into) 637 | } 638 | 639 | fn propose_run_completion_failure( 640 | mut self_: PyRefMut<'_, Self>, 641 | handle: PyNotificationHandle, 642 | value: PyFailure, 643 | ) -> Result<(), PyVMError> { 644 | self_ 645 | .vm 646 | .propose_run_completion( 647 | handle.into(), 648 | 
RunExitResult::TerminalFailure(value.into()), 649 | RetryPolicy::None, 650 | ) 651 | .map_err(Into::into) 652 | } 653 | 654 | fn propose_run_completion_failure_transient( 655 | mut self_: PyRefMut<'_, Self>, 656 | handle: PyNotificationHandle, 657 | value: PyFailure, 658 | attempt_duration: u64, 659 | config: PyExponentialRetryConfig, 660 | ) -> Result<(), PyVMError> { 661 | self_ 662 | .vm 663 | .propose_run_completion( 664 | handle.into(), 665 | RunExitResult::RetryableFailure { 666 | attempt_duration: Duration::from_millis(attempt_duration), 667 | error: value.into(), 668 | }, 669 | config.into(), 670 | ) 671 | .map_err(Into::into) 672 | } 673 | 674 | fn sys_write_output_success( 675 | mut self_: PyRefMut<'_, Self>, 676 | buffer: &Bound<'_, PyBytes>, 677 | ) -> Result<(), PyVMError> { 678 | self_ 679 | .vm 680 | .sys_write_output(NonEmptyValue::Success(buffer.as_bytes().to_vec().into())) 681 | .map(Into::into) 682 | .map_err(Into::into) 683 | } 684 | 685 | fn sys_write_output_failure( 686 | mut self_: PyRefMut<'_, Self>, 687 | value: PyFailure, 688 | ) -> Result<(), PyVMError> { 689 | self_ 690 | .vm 691 | .sys_write_output(NonEmptyValue::Failure(value.into())) 692 | .map(Into::into) 693 | .map_err(Into::into) 694 | } 695 | 696 | fn attach_invocation( 697 | mut self_: PyRefMut<'_, Self>, 698 | invocation_id: String, 699 | ) -> Result { 700 | self_ 701 | .vm 702 | .sys_attach_invocation(restate_sdk_shared_core::AttachInvocationTarget::InvocationId(invocation_id)) 703 | .map(Into::into) 704 | .map_err(Into::into) 705 | } 706 | 707 | fn sys_end(mut self_: PyRefMut<'_, Self>) -> Result<(), PyVMError> { 708 | self_.vm.sys_end().map(Into::into).map_err(Into::into) 709 | } 710 | } 711 | 712 | #[pyclass] 713 | struct PyIdentityVerifier { 714 | verifier: IdentityVerifier, 715 | } 716 | 717 | // Exceptions 718 | create_exception!( 719 | restate_sdk_python_core, 720 | IdentityKeyException, 721 | pyo3::exceptions::PyException, 722 | "Restate identity key exception." 
723 | ); 724 | 725 | create_exception!( 726 | restate_sdk_python_core, 727 | IdentityVerificationException, 728 | pyo3::exceptions::PyException, 729 | "Restate identity verification exception." 730 | ); 731 | 732 | #[pymethods] 733 | impl PyIdentityVerifier { 734 | #[new] 735 | fn new(keys: Vec) -> PyResult { 736 | Ok(Self { 737 | verifier: IdentityVerifier::new(&keys.iter().map(|x| &**x).collect::>()) 738 | .map_err(|e| IdentityKeyException::new_err(e.to_string()))?, 739 | }) 740 | } 741 | 742 | fn verify( 743 | self_: PyRef<'_, Self>, 744 | headers: Vec<(String, String)>, 745 | path: String, 746 | ) -> PyResult<()> { 747 | self_ 748 | .verifier 749 | .verify_identity(&headers, &path) 750 | .map_err(|e| IdentityVerificationException::new_err(e.to_string())) 751 | } 752 | } 753 | 754 | #[pymodule] 755 | fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> { 756 | use tracing_subscriber::EnvFilter; 757 | 758 | tracing_subscriber::fmt() 759 | .with_env_filter(EnvFilter::from_env("RESTATE_CORE_LOG")) 760 | .init(); 761 | 762 | m.add_class::()?; 763 | m.add_class::()?; 764 | m.add_class::()?; 765 | m.add_class::()?; 766 | m.add_class::()?; 767 | m.add_class::()?; 768 | m.add_class::()?; 769 | m.add_class::()?; 770 | m.add_class::()?; 771 | m.add_class::()?; 772 | m.add_class::()?; 773 | m.add_class::()?; 774 | m.add_class::()?; 775 | m.add_class::()?; 776 | m.add_class::()?; 777 | m.add_class::()?; 778 | 779 | m.add("VMException", m.py().get_type_bound::())?; 780 | m.add( 781 | "IdentityKeyException", 782 | m.py().get_type_bound::(), 783 | )?; 784 | m.add( 785 | "IdentityVerificationException", 786 | m.py().get_type_bound::(), 787 | )?; 788 | m.add("SDK_VERSION", CURRENT_VERSION)?; 789 | m.add( 790 | "CANCEL_NOTIFICATION_HANDLE", 791 | PyNotificationHandle::from(CANCEL_NOTIFICATION_HANDLE), 792 | )?; 793 | Ok(()) 794 | } 795 | -------------------------------------------------------------------------------- /test-services/.env: 
-------------------------------------------------------------------------------- 1 | RESTATE_CORE_LOG=trace -------------------------------------------------------------------------------- /test-services/Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker.io/docker/dockerfile:1.7-labs 2 | 3 | FROM ghcr.io/pyo3/maturin AS build-sdk 4 | 5 | WORKDIR /usr/src/app 6 | 7 | COPY src ./src/ 8 | COPY python ./python/ 9 | COPY Cargo.lock . 10 | COPY Cargo.toml . 11 | COPY rust-toolchain.toml . 12 | COPY requirements.txt . 13 | COPY pyproject.toml . 14 | COPY LICENSE . 15 | 16 | RUN maturin build --out dist --interpreter python3.12 17 | 18 | FROM python:3.12-slim AS test-services 19 | 20 | WORKDIR /usr/src/app 21 | 22 | COPY --from=build-sdk /usr/src/app/dist/* /usr/src/app/deps/ 23 | 24 | RUN pip install deps/* && pip install hypercorn 25 | COPY test-services/ . 26 | 27 | EXPOSE 9080 28 | 29 | ENV RESTATE_CORE_LOG=debug 30 | ENV RUST_BACKTRACE=1 31 | ENV PORT 9080 32 | 33 | ENTRYPOINT ["./entrypoint.sh"] 34 | -------------------------------------------------------------------------------- /test-services/README.md: -------------------------------------------------------------------------------- 1 | # Test services to run the sdk-test-suite 2 | 3 | ## To run locally 4 | 5 | * Grab the release of sdk-test-suite: https://github.com/restatedev/sdk-test-suite/releases 6 | 7 | * Prepare the docker image: 8 | ```shell 9 | docker build . 
-f test-services/Dockerfile -t restatedev/test-services 10 | ``` 11 | 12 | * Run the tests (requires JVM >= 17): 13 | ```shell 14 | java -jar restate-sdk-test-suite.jar run --exclusions-file test-services/exclusions.yaml restatedev/test-services 15 | ``` 16 | 17 | ## To debug a single test: 18 | 19 | * Run the python service using your IDE 20 | * Run the test runner in debug mode specifying test suite and test: 21 | ```shell 22 | java -jar restate-sdk-test-suite.jar debug --test-suite=lazyState --test-name=dev.restate.sdktesting.tests.State default-service=9080 23 | ``` 24 | 25 | For more info: https://github.com/restatedev/sdk-test-suite -------------------------------------------------------------------------------- /test-services/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | PORT=${PORT:-"9080"} 4 | 5 | python3 -m hypercorn testservices:app --config hypercorn-config.toml --bind "0.0.0.0:${PORT}" 6 | -------------------------------------------------------------------------------- /test-services/exclusions.yaml: -------------------------------------------------------------------------------- 1 | exclusions: {} 2 | -------------------------------------------------------------------------------- /test-services/hypercorn-config.toml: -------------------------------------------------------------------------------- 1 | h2_max_concurrent_streams = 2147483647 2 | keep_alive_max_requests = 2147483647 3 | keep_alive_timeout = 2147483647 4 | workers = 8 5 | 6 | -------------------------------------------------------------------------------- /test-services/requirements.txt: -------------------------------------------------------------------------------- 1 | hypercorn 2 | restate_sdk -------------------------------------------------------------------------------- /test-services/services/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # 
Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | 11 | from typing import Dict, Union 12 | from restate import Service, VirtualObject, Workflow 13 | 14 | from .counter import counter_object as s1 15 | from .proxy import proxy as s2 16 | from .awakeable_holder import awakeable_holder as s3 17 | from. block_and_wait_workflow import workflow as s4 18 | from .cancel_test import runner, blocking_service as s5 19 | from .failing import failing as s6 20 | from .kill_test import kill_runner, kill_singleton as s7 21 | from .list_object import list_object as s8 22 | from .map_object import map_object as s9 23 | from .non_determinism import non_deterministic as s10 24 | from .test_utils import test_utils as s11 25 | from .virtual_object_command_interpreter import virtual_object_command_interpreter as s16 26 | 27 | from .interpreter import layer_0 as s12 28 | from .interpreter import layer_1 as s13 29 | from .interpreter import layer_2 as s14 30 | from .interpreter import helper as s15 31 | 32 | def list_services(bindings): 33 | """List all services in this module""" 34 | return {obj.name : obj for _, obj in bindings.items() if isinstance(obj, (Service, VirtualObject, Workflow))} 35 | 36 | def services_named(service_names): 37 | return [ _all_services[name] for name in service_names ] 38 | 39 | def all_services(): 40 | return _all_services.values() 41 | 42 | _all_services = list_services(locals()) 43 | -------------------------------------------------------------------------------- /test-services/services/awakeable_holder.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, 
Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """example.py""" 12 | # pylint: disable=C0116 13 | # pylint: disable=W0613 14 | # pylint: disable=W0622 15 | 16 | from restate import VirtualObject, ObjectContext 17 | from restate.exceptions import TerminalError 18 | 19 | awakeable_holder = VirtualObject("AwakeableHolder") 20 | 21 | @awakeable_holder.handler() 22 | async def hold(ctx: ObjectContext, id: str): 23 | ctx.set("id", id) 24 | 25 | @awakeable_holder.handler(name="hasAwakeable") 26 | async def has_awakeable(ctx: ObjectContext) -> bool: 27 | res = await ctx.get("id") 28 | return res is not None 29 | 30 | @awakeable_holder.handler() 31 | async def unlock(ctx: ObjectContext, payload: str): 32 | id = await ctx.get("id") 33 | if id is None: 34 | raise TerminalError(message="No awakeable is registered") 35 | ctx.resolve_awakeable(id, payload) 36 | -------------------------------------------------------------------------------- /test-services/services/block_and_wait_workflow.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 
6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """example.py""" 12 | # pylint: disable=C0116 13 | # pylint: disable=W0613 14 | # pylint: disable=W0622 15 | 16 | from restate import Workflow, WorkflowContext, WorkflowSharedContext 17 | from restate.exceptions import TerminalError 18 | 19 | workflow = Workflow("BlockAndWaitWorkflow") 20 | 21 | @workflow.main() 22 | async def run(ctx: WorkflowContext, input: str): 23 | ctx.set("my-state", input) 24 | output = await ctx.promise("durable-promise").value() 25 | 26 | peek = await ctx.promise("durable-promise").peek() 27 | if peek is None: 28 | raise TerminalError(message="Durable promise should be completed") 29 | 30 | return output 31 | 32 | 33 | @workflow.handler() 34 | async def unblock(ctx: WorkflowSharedContext, output: str): 35 | await ctx.promise("durable-promise").resolve(output) 36 | 37 | @workflow.handler(name="getState") 38 | async def get_state(ctx: WorkflowSharedContext, output: str) -> str | None: 39 | return await ctx.get("my-state") 40 | -------------------------------------------------------------------------------- /test-services/services/cancel_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 
6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """example.py""" 12 | # pylint: disable=C0116 13 | # pylint: disable=W0613 14 | 15 | from datetime import timedelta 16 | from typing import Literal 17 | 18 | from restate import VirtualObject, ObjectContext 19 | from restate.exceptions import TerminalError 20 | 21 | from . import awakeable_holder 22 | 23 | BlockingOperation = Literal["CALL", "SLEEP", "AWAKEABLE"] 24 | 25 | runner = VirtualObject("CancelTestRunner") 26 | 27 | @runner.handler(name="startTest") 28 | async def start_test(ctx: ObjectContext, op: BlockingOperation): 29 | try: 30 | await ctx.object_call(block, key=ctx.key(), arg=op) 31 | except TerminalError as t: 32 | if t.status_code == 409: 33 | ctx.set("state", True) 34 | else: 35 | raise t 36 | 37 | @runner.handler(name="verifyTest") 38 | async def verify_test(ctx: ObjectContext) -> bool: 39 | state = await ctx.get("state") 40 | if state is None: 41 | return False 42 | return state 43 | 44 | 45 | blocking_service = VirtualObject("CancelTestBlockingService") 46 | 47 | @blocking_service.handler() 48 | async def block(ctx: ObjectContext, op: BlockingOperation): 49 | name, awakeable = ctx.awakeable() 50 | await ctx.object_call(awakeable_holder.hold, key=ctx.key(), arg=name) 51 | await awakeable 52 | 53 | if op == "CALL": 54 | await ctx.object_call(block, key=ctx.key(), arg=op) 55 | elif op == "SLEEP": 56 | await ctx.sleep(timedelta(days=1024)) 57 | elif op == "AWAKEABLE": 58 | name, uncompleteable = ctx.awakeable() 59 | await uncompleteable 60 | 61 | @blocking_service.handler(name="isUnlocked") 62 | async def is_unlocked(ctx: ObjectContext): 63 | return None 64 | -------------------------------------------------------------------------------- /test-services/services/counter.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """example.py""" 12 | # pylint: disable=C0116 13 | # pylint: disable=W0613 14 | 15 | from typing import TypedDict 16 | from restate import VirtualObject, ObjectContext 17 | from restate.exceptions import TerminalError 18 | 19 | counter_object = VirtualObject("Counter") 20 | 21 | COUNTER_KEY = "counter" 22 | 23 | 24 | @counter_object.handler() 25 | async def reset(ctx: ObjectContext): 26 | ctx.clear(COUNTER_KEY) 27 | 28 | 29 | @counter_object.handler() 30 | async def get(ctx: ObjectContext) -> int: 31 | c: int | None = await ctx.get(COUNTER_KEY) 32 | if c is None: 33 | return 0 34 | return c 35 | 36 | 37 | class CounterUpdateResponse(TypedDict): 38 | oldValue: int 39 | newValue: int 40 | 41 | 42 | @counter_object.handler() 43 | async def add(ctx: ObjectContext, addend: int) -> CounterUpdateResponse: 44 | old_value: int | None = await ctx.get(COUNTER_KEY) 45 | if old_value is None: 46 | old_value = 0 47 | new_value = old_value + addend 48 | ctx.set(COUNTER_KEY, new_value) 49 | return CounterUpdateResponse(oldValue=old_value, newValue=new_value) 50 | 51 | 52 | @counter_object.handler(name="addThenFail") 53 | async def add_then_fail(ctx: ObjectContext, addend: int): 54 | old_value: int | None = await ctx.get(COUNTER_KEY) 55 | if old_value is None: 56 | old_value = 0 57 | new_value = old_value + addend 58 | ctx.set(COUNTER_KEY, new_value) 59 | 60 | raise TerminalError(message=ctx.key()) -------------------------------------------------------------------------------- /test-services/services/failing.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """example.py""" 12 | # pylint: disable=C0116 13 | # pylint: disable=W0613 14 | # pylint: disable=W0622 15 | 16 | from restate import VirtualObject, ObjectContext 17 | from restate.exceptions import TerminalError 18 | 19 | failing = VirtualObject("Failing") 20 | 21 | @failing.handler(name="terminallyFailingCall") 22 | async def terminally_failing_call(ctx: ObjectContext, msg: str): 23 | raise TerminalError(message=msg) 24 | 25 | @failing.handler(name="callTerminallyFailingCall") 26 | async def call_terminally_failing_call(ctx: ObjectContext, msg: str) -> str: 27 | await ctx.object_call(terminally_failing_call, key="random-583e1bf2", arg=msg) 28 | 29 | raise Exception("Should not reach here") 30 | 31 | failures = 0 32 | 33 | @failing.handler(name="failingCallWithEventualSuccess") 34 | async def failing_call_with_eventual_success(ctx: ObjectContext) -> int: 35 | global failures 36 | failures += 1 37 | if failures >= 4: 38 | failures = 0 39 | return 4 40 | raise ValueError(f"Failed at attempt: {failures}") 41 | 42 | @failing.handler(name="terminallyFailingSideEffect") 43 | async def terminally_failing_side_effect(ctx: ObjectContext, error_message: str): 44 | 45 | def side_effect(): 46 | raise TerminalError(message=error_message) 47 | 48 | await ctx.run("sideEffect", side_effect) 49 | raise ValueError("Should not reach here") 50 | 51 | 52 | eventual_success_side_effects = 0 53 | 54 | @failing.handler(name="sideEffectSucceedsAfterGivenAttempts") 55 | async def side_effect_succeeds_after_given_attempts(ctx: 
ObjectContext, minimum_attempts: int) -> int: 56 | 57 | def side_effect(): 58 | global eventual_success_side_effects 59 | eventual_success_side_effects += 1 60 | if eventual_success_side_effects >= minimum_attempts: 61 | return eventual_success_side_effects 62 | raise ValueError(f"Failed at attempt: {eventual_success_side_effects}") 63 | 64 | return await ctx.run("sideEffect", side_effect, max_attempts=minimum_attempts + 1) # type: ignore 65 | 66 | eventual_failure_side_effects = 0 67 | 68 | @failing.handler(name="sideEffectFailsAfterGivenAttempts") 69 | async def side_effect_fails_after_given_attempts(ctx: ObjectContext, retry_policy_max_retry_count: int) -> int: 70 | 71 | def side_effect(): 72 | global eventual_failure_side_effects 73 | eventual_failure_side_effects += 1 74 | raise ValueError(f"Failed at attempt: {eventual_failure_side_effects}") 75 | 76 | try: 77 | await ctx.run("sideEffect", side_effect, max_attempts=retry_policy_max_retry_count) 78 | raise ValueError("Side effect did not fail.") 79 | except TerminalError as t: 80 | global eventual_failure_side_effects 81 | return eventual_failure_side_effects 82 | 83 | -------------------------------------------------------------------------------- /test-services/services/interpreter.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 
6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """Verification test""" 12 | 13 | from datetime import timedelta 14 | import json 15 | from typing import TypedDict 16 | import typing 17 | import random 18 | 19 | 20 | from restate.context import Context, ObjectContext, ObjectSharedContext 21 | from restate.exceptions import TerminalError 22 | from restate.object import VirtualObject 23 | from restate.serde import JsonSerde 24 | 25 | import restate 26 | 27 | SET_STATE = 1 28 | GET_STATE = 2 29 | CLEAR_STATE = 3 30 | INCREMENT_STATE_COUNTER = 4 31 | INCREMENT_STATE_COUNTER_INDIRECTLY = 5 32 | SLEEP = 6 33 | CALL_SERVICE = 7 34 | CALL_SLOW_SERVICE = 8 35 | INCREMENT_VIA_DELAYED_CALL = 9 36 | SIDE_EFFECT = 10 37 | THROWING_SIDE_EFFECT = 11 38 | SLOW_SIDE_EFFECT = 12 39 | RECOVER_TERMINAL_CALL = 13 40 | RECOVER_TERMINAL_MAYBE_UN_AWAITED = 14 41 | AWAIT_PROMISE = 15 42 | RESOLVE_AWAKEABLE = 16 43 | REJECT_AWAKEABLE = 17 44 | INCREMENT_STATE_COUNTER_VIA_AWAKEABLE = 18 45 | CALL_NEXT_LAYER_OBJECT = 19 46 | 47 | # suppress missing docstring 48 | # pylint: disable=C0115 49 | # pylint: disable=C0116 50 | # pylint: disable=C0301 51 | # pylint: disable=R0914, R0912, R0915, R0913 52 | 53 | 54 | helper = restate.Service("ServiceInterpreterHelper") 55 | 56 | @helper.handler() 57 | async def ping(ctx: Context) -> None: # pylint: disable=unused-argument 58 | pass 59 | 60 | @helper.handler() 61 | async def echo(ctx: Context, parameters: str) -> str: # pylint: disable=unused-argument 62 | return parameters 63 | 64 | @helper.handler(name = "echoLater") 65 | async def echo_later(ctx: Context, parameter: dict[str, typing.Any]) -> str: 66 | await ctx.sleep(timedelta(milliseconds=parameter['sleep'])) 67 | return parameter['parameter'] 68 | 69 | @helper.handler(name="terminalFailure") 70 | async def terminal_failure(ctx: Context) -> 
str: 71 | raise TerminalError("bye") 72 | 73 | @helper.handler(name="incrementIndirectly") 74 | async def increment_indirectly(ctx: Context, parameter) -> None: 75 | 76 | layer = parameter['layer'] 77 | key = parameter['key'] 78 | 79 | program = { 80 | "commands": [ 81 | { 82 | "kind": INCREMENT_STATE_COUNTER, 83 | }, 84 | ], 85 | } 86 | 87 | program_bytes = json.dumps(program).encode('utf-8') 88 | 89 | ctx.generic_send(f"ObjectInterpreterL{layer}", "interpret", program_bytes, key) 90 | 91 | @helper.handler(name="resolveAwakeable") 92 | async def resolve_awakeable(ctx: Context, aid: str) -> None: 93 | ctx.resolve_awakeable(aid, "ok") 94 | 95 | @helper.handler(name="rejectAwakeable") 96 | async def reject_awakeable(ctx: Context, aid: str) -> None: 97 | ctx.reject_awakeable(aid, "error") 98 | 99 | @helper.handler(name="incrementViaAwakeableDance") 100 | async def increment_via_awakeable_dance(ctx: Context, input: dict[str, typing.Any]) -> None: 101 | tx_promise_id = input['txPromiseId'] 102 | layer = input['interpreter']['layer'] 103 | key = input['interpreter']['key'] 104 | 105 | aid, promise = ctx.awakeable() 106 | ctx.resolve_awakeable(tx_promise_id, aid) 107 | await promise 108 | 109 | program = { 110 | "commands": [ 111 | { 112 | "kind": INCREMENT_STATE_COUNTER, 113 | }, 114 | ], 115 | } 116 | 117 | program_bytes = json.dumps(program).encode('utf-8') 118 | 119 | ctx.generic_send(f"ObjectInterpreterL{layer}", "interpret", program_bytes, key) 120 | 121 | 122 | class SupportService: 123 | 124 | def __init__(self, ctx: ObjectContext) -> None: 125 | self.ctx = ctx 126 | self.serde = JsonSerde[typing.Any]() 127 | 128 | async def call(self, method: str, arg: typing.Any) -> typing.Any: 129 | buffer = self.serde.serialize(arg) 130 | out_buffer = await self.ctx.generic_call("ServiceInterpreterHelper", method, buffer) 131 | return self.serde.deserialize(out_buffer) 132 | 133 | def send(self, method: str, arg: typing.Any, delay: int | None = None) -> None: 134 | buffer = 
self.serde.serialize(arg) 135 | if delay is None: 136 | send_delay = None 137 | else: 138 | send_delay = timedelta(milliseconds=delay) 139 | self.ctx.generic_send("ServiceInterpreterHelper", method, buffer, send_delay=send_delay) 140 | 141 | async def ping(self) -> None: 142 | return await self.call(method="ping", arg=None) 143 | 144 | async def echo(self, parameters: str) -> str: 145 | return await self.call(method="echo", arg=parameters) 146 | 147 | async def echo_later(self, parameter: str, sleep: int) -> str: 148 | arg = {"parameter": parameter, "sleep": sleep} 149 | return await self.call(method="echoLater", arg=arg) 150 | 151 | async def terminal_failure(self) -> str: 152 | return await self.call(method="terminalFailure", arg=None) 153 | 154 | async def increment_indirectly(self, layer: int, key: str, delay: typing.Optional[int] = None) -> None: 155 | arg = {"layer": layer, "key": key} 156 | self.send(method="incrementIndirectly", arg=arg, delay=delay) 157 | 158 | def resolve_awakeable(self, aid: str) -> None: 159 | self.send("resolveAwakeable", aid) 160 | 161 | def reject_awakeable(self, aid: str) -> None: 162 | self.send("rejectAwakeable", aid) 163 | 164 | def increment_via_awakeable_dance(self, layer: int, key: str, tx_promise_id: str) -> None: 165 | arg = { "interpreter" : { "layer": layer, "key": key} , "txPromiseId": tx_promise_id } 166 | self.send("incrementViaAwakeableDance", arg) 167 | 168 | 169 | class Command(TypedDict): 170 | kind: int 171 | key: int 172 | duration: int 173 | sleep: int 174 | index: int 175 | program: typing.Any # avoid circular type 176 | 177 | 178 | Program = dict[typing.Literal['commands'], 179 | typing.List[Command]] 180 | 181 | 182 | async def interpreter(layer: int, 183 | ctx: ObjectContext, 184 | program: Program) -> None: 185 | """Interprets a command and executes it.""" 186 | service = SupportService(ctx) 187 | coros: dict[int, 188 | typing.Tuple[typing.Any, typing.Awaitable[typing.Any]]] = {} 189 | 190 | async def 
await_promise(index: int) -> None: 191 | if index not in coros: 192 | return 193 | 194 | expected, coro = coros[index] 195 | del coros[index] 196 | try: 197 | result = await coro 198 | except TerminalError: 199 | result = "rejected" 200 | 201 | if result != expected: 202 | raise TerminalError(f"Expected {expected} but got {result}") 203 | 204 | for i, command in enumerate(program['commands']): 205 | command_type = command['kind'] 206 | if command_type == SET_STATE: 207 | ctx.set(f"key-{command['key']}", f"value-{command['key']}") 208 | elif command_type == GET_STATE: 209 | await ctx.get(f"key-{command['key']}") 210 | elif command_type == CLEAR_STATE: 211 | ctx.clear(f"key-{command['key']}") 212 | elif command_type == INCREMENT_STATE_COUNTER: 213 | c = await ctx.get("counter") or 0 214 | c += 1 215 | ctx.set("counter", c) 216 | elif command_type == SLEEP: 217 | duration = timedelta(milliseconds=command['duration']) 218 | await ctx.sleep(duration) 219 | elif command_type == CALL_SERVICE: 220 | expected = f"hello-{i}" 221 | coros[i] = (expected, service.echo(expected)) 222 | elif command_type == INCREMENT_VIA_DELAYED_CALL: 223 | delay = command['duration'] 224 | await service.increment_indirectly(layer=layer, key=ctx.key(), delay=delay) 225 | elif command_type == CALL_SLOW_SERVICE: 226 | expected = f"hello-{i}" 227 | coros[i] = (expected, service.echo_later(expected, command['sleep'])) 228 | elif command_type == SIDE_EFFECT: 229 | expected = f"hello-{i}" 230 | result = await ctx.run("sideEffect", lambda : expected) # pylint: disable=W0640 231 | if result != expected: 232 | raise TerminalError(f"Expected {expected} but got {result}") 233 | elif command_type == SLOW_SIDE_EFFECT: 234 | pass 235 | elif command_type == RECOVER_TERMINAL_CALL: 236 | try: 237 | await service.terminal_failure() 238 | except TerminalError: 239 | pass 240 | else: 241 | raise TerminalError("Expected terminal error") 242 | elif command_type == RECOVER_TERMINAL_MAYBE_UN_AWAITED: 243 | pass 244 | 
elif command_type == THROWING_SIDE_EFFECT: 245 | async def side_effect(): 246 | if bool(random.getrandbits(1)): 247 | raise ValueError("Random error") 248 | 249 | await ctx.run("throwingSideEffect", side_effect) 250 | elif command_type == INCREMENT_STATE_COUNTER_INDIRECTLY: 251 | await service.increment_indirectly(layer=layer, key=ctx.key()) 252 | elif command_type == AWAIT_PROMISE: 253 | index = command['index'] 254 | await await_promise(index) 255 | elif command_type == RESOLVE_AWAKEABLE: 256 | name, promise = ctx.awakeable() 257 | coros[i] = ("ok", promise) 258 | service.resolve_awakeable(name) 259 | elif command_type == REJECT_AWAKEABLE: 260 | name, promise = ctx.awakeable() 261 | coros[i] = ("rejected", promise) 262 | service.reject_awakeable(name) 263 | elif command_type == INCREMENT_STATE_COUNTER_VIA_AWAKEABLE: 264 | tx_promise_id, tx_promise = ctx.awakeable() 265 | service.increment_via_awakeable_dance(layer=layer, key=ctx.key(), tx_promise_id=tx_promise_id) 266 | their_promise_for_us_to_resolve: str = await tx_promise 267 | ctx.resolve_awakeable(their_promise_for_us_to_resolve, "ok") 268 | elif command_type == CALL_NEXT_LAYER_OBJECT: 269 | next_layer = f"ObjectInterpreterL{layer + 1}" 270 | key = f"{command['key']}" 271 | program = command['program'] 272 | js_program = json.dumps(program) 273 | raw_js_program = js_program.encode('utf-8') 274 | promise = ctx.generic_call(next_layer, "interpret", raw_js_program, key) 275 | coros[i] = (b'', promise) 276 | else: 277 | raise ValueError(f"Unknown command type: {command_type}") 278 | await await_promise(i) 279 | 280 | def make_layer(i): 281 | layer = VirtualObject(f"ObjectInterpreterL{i}") 282 | 283 | @layer.handler() 284 | async def interpret(ctx: ObjectContext, program: Program) -> None: 285 | await interpreter(i, ctx, program) 286 | 287 | @layer.handler(kind="shared") 288 | async def counter(ctx: ObjectSharedContext) -> int: 289 | return await ctx.get("counter") or 0 290 | 291 | return layer 292 | 293 | 294 | 
layer_0 = make_layer(0) 295 | layer_1 = make_layer(1) 296 | layer_2 = make_layer(2) 297 | -------------------------------------------------------------------------------- /test-services/services/kill_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """example.py""" 12 | # pylint: disable=C0116 13 | # pylint: disable=W0613 14 | 15 | from restate import VirtualObject, ObjectContext 16 | 17 | from . import awakeable_holder 18 | 19 | kill_runner = VirtualObject("KillTestRunner") 20 | 21 | @kill_runner.handler(name="startCallTree") 22 | async def start_call_tree(ctx: ObjectContext): 23 | await ctx.object_call(recursive_call, key=ctx.key(), arg=None) 24 | 25 | kill_singleton = VirtualObject("KillTestSingleton") 26 | 27 | @kill_singleton.handler(name="recursiveCall") 28 | async def recursive_call(ctx: ObjectContext): 29 | name, promise = ctx.awakeable() 30 | ctx.object_send(awakeable_holder.hold, key=ctx.key(), arg=name) 31 | await promise 32 | 33 | await ctx.object_call(recursive_call, key=ctx.key(), arg=None) 34 | 35 | @kill_singleton.handler(name="isUnlocked") 36 | async def is_unlocked(ctx: ObjectContext): 37 | return None 38 | -------------------------------------------------------------------------------- /test-services/services/list_object.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 
6 | # 7 | # You can find a copy of the license in file LICENSE in the root 8 | # directory of this repository or package, or at 9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE 10 | # 11 | """example.py""" 12 | # pylint: disable=C0116 13 | # pylint: disable=W0613 14 | 15 | from restate import VirtualObject, ObjectContext 16 | 17 | list_object = VirtualObject("ListObject") 18 | 19 | @list_object.handler() 20 | async def append(ctx: ObjectContext, value: str): 21 | list = await ctx.get("list") or [] 22 | ctx.set("list", list + [value]) 23 | 24 | @list_object.handler() 25 | async def get(ctx: ObjectContext) -> list[str]: 26 | return await ctx.get("list") or [] 27 | 28 | @list_object.handler() 29 | async def clear(ctx: ObjectContext) -> list[str]: 30 | result = await ctx.get("list") or [] 31 | ctx.clear("list") 32 | return result 33 | -------------------------------------------------------------------------------- /test-services/services/map_object.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH 3 | # 4 | # This file is part of the Restate SDK for Python, 5 | # which is released under the MIT license. 
6 | #
7 | # You can find a copy of the license in file LICENSE in the root
8 | # directory of this repository or package, or at
9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
10 | #
11 | """map_object.py: MapObject virtual object exposing its state entries as a string key/value map."""
12 | # pylint: disable=C0116
13 | # pylint: disable=W0613
14 |
15 | from typing import TypedDict
16 | from restate import VirtualObject, ObjectContext
17 |
18 | map_object = VirtualObject("MapObject")
19 |
20 |
21 | class Entry(TypedDict):
22 |     key: str
23 |     value: str
24 |
25 | @map_object.handler(name="set")
26 | async def map_set(ctx: ObjectContext, entry: Entry):
27 |     ctx.set(entry["key"], entry["value"])
28 |
29 | @map_object.handler(name="get")
30 | async def map_get(ctx: ObjectContext, key: str) -> str:
31 |     return await ctx.get(key) or ""
32 |
33 | @map_object.handler(name="clearAll")
34 | async def map_clear_all(ctx: ObjectContext) -> list[Entry]:
35 |     entries: list[Entry] = []
36 |     for key in await ctx.state_keys():
37 |         value: str = await ctx.get(key) # type: ignore
38 |         entry = Entry(key=key, value=value)
39 |         entries.append(entry)
40 |         ctx.clear(key)
41 |     return entries
--------------------------------------------------------------------------------
/test-services/services/non_determinism.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
3 | #
4 | # This file is part of the Restate SDK for Python,
5 | # which is released under the MIT license.
6 | #
7 | # You can find a copy of the license in file LICENSE in the root
8 | # directory of this repository or package, or at
9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
10 | #
11 | """non_determinism.py: NonDeterministic virtual object whose handlers alternate branches on successive calls for the same key."""
12 | # pylint: disable=C0116
13 | # pylint: disable=W0613
14 |
15 | from datetime import timedelta
16 | from typing import Dict
17 | from restate import VirtualObject, ObjectContext
18 |
19 | from . import counter
20 |
21 | invoke_counts: Dict[str, int] = {}  # in-process dict, not Restate state: lets do_left_action alternate per key
22 |
23 | def do_left_action(ctx: ObjectContext) -> bool:
24 |     count_key = ctx.key()
25 |     invoke_counts[count_key] = invoke_counts.get(count_key, 0) + 1
26 |     return invoke_counts[count_key] % 2 == 1
27 |
28 | def increment_counter(ctx: ObjectContext):
29 |     ctx.object_send(counter.add, key=ctx.key(), arg=1)
30 |
31 | non_deterministic = VirtualObject("NonDeterministic")
32 |
33 | @non_deterministic.handler(name="setDifferentKey")
34 | async def set_different_key(ctx: ObjectContext):
35 |     if do_left_action(ctx):
36 |         ctx.set("a", "my-state")
37 |     else:
38 |         ctx.set("b", "my-state")
39 |     await ctx.sleep(timedelta(milliseconds=100))
40 |     increment_counter(ctx)
41 |
42 | @non_deterministic.handler(name="backgroundInvokeWithDifferentTargets")
43 | async def background_invoke_with_different_targets(ctx: ObjectContext):
44 |     if do_left_action(ctx):
45 |         ctx.object_send(counter.get, key="abc", arg=None)
46 |     else:
47 |         ctx.object_send(counter.reset, key="abc", arg=None)
48 |     await ctx.sleep(timedelta(milliseconds=100))
49 |     increment_counter(ctx)
50 |
51 | @non_deterministic.handler(name="callDifferentMethod")
52 | async def call_different_method(ctx: ObjectContext):
53 |     if do_left_action(ctx):
54 |         await ctx.object_call(counter.get, key="abc", arg=None)
55 |     else:
56 |         await ctx.object_call(counter.reset, key="abc", arg=None)
57 |     await ctx.sleep(timedelta(milliseconds=100))
58 |     increment_counter(ctx)
59 |
60 | @non_deterministic.handler(name="eitherSleepOrCall")
61 | async def either_sleep_or_call(ctx: ObjectContext):
62 |     if do_left_action(ctx):
63 |         await ctx.sleep(timedelta(milliseconds=100))
64 |     else:
65 |         await ctx.object_call(counter.get, key="abc", arg=None)
66 |     await ctx.sleep(timedelta(milliseconds=100))
67 |     increment_counter(ctx)
68 |
--------------------------------------------------------------------------------
/test-services/services/proxy.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
3 | #
4 | # This file is part of the Restate SDK for Python,
5 | # which is released under the MIT license.
6 | #
7 | # You can find a copy of the license in file LICENSE in the root
8 | # directory of this repository or package, or at
9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
10 | #
11 | """proxy.py: Proxy service that forwards raw-byte calls and one-way sends to arbitrary handlers."""
12 | # pylint: disable=C0116
13 | # pylint: disable=W0613
14 |
15 | from datetime import timedelta
16 | from typing import TypedDict, Optional, Iterable
17 | from restate import Service, Context
18 |
19 | proxy = Service("Proxy")
20 |
21 |
22 | class ProxyRequest(TypedDict):
23 |     serviceName: str
24 |     virtualObjectKey: Optional[str]
25 |     handlerName: str
26 |     message: Iterable[int]
27 |     delayMillis: Optional[int]
28 |     idempotencyKey: Optional[str]
29 |
30 |
31 | @proxy.handler()
32 | async def call(ctx: Context, req: ProxyRequest) -> Iterable[int]:
33 |     response = await ctx.generic_call(
34 |         req['serviceName'],
35 |         req['handlerName'],
36 |         bytes(req['message']),
37 |         req.get('virtualObjectKey'),
38 |         idempotency_key=req.get('idempotencyKey'))  # by keyword, as in many_calls
39 |     return list(response)
40 |
41 |
42 | @proxy.handler(name="oneWayCall")
43 | async def one_way_call(ctx: Context, req: ProxyRequest) -> str:
44 |     send_delay = None
45 |     if req.get('delayMillis'):
46 |         send_delay = timedelta(milliseconds=req['delayMillis'])
47 |     handle = ctx.generic_send(
48 |         req['serviceName'],
49 |         req['handlerName'],
50 |         bytes(req['message']),
51 |         req.get('virtualObjectKey'),
52 |         send_delay=send_delay,
53 |         idempotency_key=req.get('idempotencyKey')
54 |     )
55 |     invocation_id = await handle.invocation_id()
56 |     return invocation_id
57 |
58 |
59 | class ManyCallRequest(TypedDict):
60 |     proxyRequest: ProxyRequest
61 |     oneWayCall: bool
62 |     awaitAtTheEnd: bool
63 |
64 | @proxy.handler(name="manyCalls")
65 | async def many_calls(ctx: Context, requests: Iterable[ManyCallRequest]):
66 |     to_await = []
67 |
68 |     for req in requests:
69 |         if req['oneWayCall']:
70 |             send_delay = None
71 |             if req['proxyRequest'].get('delayMillis'):
72 |                 send_delay = timedelta(milliseconds=req['proxyRequest']['delayMillis'])
73 |             ctx.generic_send(
74 |                 req['proxyRequest']['serviceName'],
75 |                 req['proxyRequest']['handlerName'],
76 |                 bytes(req['proxyRequest']['message']),
77 |                 req['proxyRequest'].get('virtualObjectKey'),
78 |                 send_delay=send_delay,
79 |                 idempotency_key=req['proxyRequest'].get('idempotencyKey')
80 |             )
81 |         else:
82 |             awaitable = ctx.generic_call(
83 |                 req['proxyRequest']['serviceName'],
84 |                 req['proxyRequest']['handlerName'],
85 |                 bytes(req['proxyRequest']['message']),
86 |                 req['proxyRequest'].get('virtualObjectKey'),
87 |                 idempotency_key=req['proxyRequest'].get('idempotencyKey'))
88 |             if req['awaitAtTheEnd']:
89 |                 to_await.append(awaitable)
90 |
91 |     for awaitable in to_await:
92 |         await awaitable
93 |
--------------------------------------------------------------------------------
/test-services/services/test_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
3 | #
4 | # This file is part of the Restate SDK for Python,
5 | # which is released under the MIT license.
6 | #
7 | # You can find a copy of the license in file LICENSE in the root
8 | # directory of this repository or package, or at
9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
10 | #
11 | """test_utils.py: TestUtilsService with echo, header, sleep, side-effect counting and cancellation helpers."""
12 | # pylint: disable=C0116
13 | # pylint: disable=W0613
14 |
15 | from datetime import timedelta
16 | from typing import Dict, List
17 | from restate import Service, Context
18 | from restate.serde import BytesSerde
19 |
20 | test_utils = Service("TestUtilsService")
21 |
22 | @test_utils.handler()
23 | async def echo(context: Context, input: str) -> str:
24 |     return input
25 |
26 | @test_utils.handler(name="uppercaseEcho")
27 | async def uppercase_echo(context: Context, input: str) -> str:
28 |     return input.upper()
29 |
30 | @test_utils.handler(name="echoHeaders")
31 | async def echo_headers(context: Context) -> Dict[str, str]:
32 |     return context.request().headers
33 |
34 | @test_utils.handler(name="rawEcho", accept="*/*", content_type="application/octet-stream", input_serde=BytesSerde(), output_serde=BytesSerde())
35 | async def raw_echo(context: Context, input: bytes) -> bytes:
36 |     return input
37 |
38 | @test_utils.handler(name="sleepConcurrently")
39 | async def sleep_concurrently(context: Context, millis_duration: List[int]) -> None:
40 |     timers = [context.sleep(timedelta(milliseconds=duration)) for duration in millis_duration]  # create every timer before awaiting any
41 |
42 |     for timer in timers:
43 |         await timer
44 |
45 |
46 | @test_utils.handler(name="countExecutedSideEffects")
47 | async def count_executed_side_effects(context: Context, increments: int) -> int:
48 |     invoked_side_effects = 0
49 |
50 |     def effect():
51 |         nonlocal invoked_side_effects
52 |         invoked_side_effects += 1
53 |
54 |     for _ in range(increments):
55 |         await context.run("count", effect)
56 |
57 |     return invoked_side_effects
58 |
59 | @test_utils.handler(name="cancelInvocation")
60 | async def cancel_invocation(context: Context, invocation_id: str) -> None:
61 |     context.cancel_invocation(invocation_id)
62 |
--------------------------------------------------------------------------------
/test-services/services/virtual_object_command_interpreter.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
3 | #
4 | # This file is part of the Restate SDK for Python,
5 | # which is released under the MIT license.
6 | #
7 | # You can find a copy of the license in file LICENSE in the root
8 | # directory of this repository or package, or at
9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
10 | #
11 | """virtual_object_command_interpreter.py: VirtualObjectCommandInterpreter, which runs awakeable/sleep/run commands and records results in state."""
12 | # pylint: disable=C0116
13 | # pylint: disable=W0613
14 |
15 | import os
16 | from datetime import timedelta
17 | from typing import Iterable, List, Union, TypedDict, Literal, Any
18 | from restate import VirtualObject, ObjectSharedContext, ObjectContext, RestateDurableFuture, RestateDurableSleepFuture
19 | from restate import select, wait_completed, as_completed
20 | from restate.exceptions import TerminalError
21 |
22 | virtual_object_command_interpreter = VirtualObject("VirtualObjectCommandInterpreter")
23 |
24 | @virtual_object_command_interpreter.handler(name="getResults", kind="shared")
25 | async def get_results(ctx: ObjectSharedContext | ObjectContext) -> List[str]:
26 |     return (await ctx.get("results")) or []
27 |
28 | @virtual_object_command_interpreter.handler(name="hasAwakeable", kind="shared")
29 | async def has_awakeable(ctx: ObjectSharedContext, awk_key: str) -> bool:
30 |     awk_id = await ctx.get("awk-" + awk_key)
31 |     if awk_id:
32 |         return True
33 |     return False
34 |
35 | class CreateAwakeable(TypedDict):
36 |     type: Literal["createAwakeable"]
37 |     awakeableKey: str
38 |
39 | class Sleep(TypedDict):
40 |     type: Literal["sleep"]
41 |     timeoutMillis: int
42 |
43 | class RunThrowTerminalException(TypedDict):
44 |     type: Literal["runThrowTerminalException"]
45 |     reason: str
46 |
47 | AwaitableCommand = Union[
48 |     CreateAwakeable,
49 |     Sleep,
50 |     RunThrowTerminalException
51 | ]
52 |
53 | class AwaitOne(TypedDict):
54 |     type: Literal["awaitOne"]
55 |     command: AwaitableCommand
56 |
57 | class AwaitAnySuccessful(TypedDict):
58 |     type: Literal["awaitAnySuccessful"]
59 |     commands: List[AwaitableCommand]
60 |
61 | class AwaitAny(TypedDict):
62 |     type: Literal["awaitAny"]
63 |     commands: List[AwaitableCommand]
64 |
65 | class AwaitAwakeableOrTimeout(TypedDict):
66 |     type: Literal["awaitAwakeableOrTimeout"]
67 |     awakeableKey: str
68 |     timeoutMillis: int
69 |
70 | class ResolveAwakeable(TypedDict):
71 |     type: Literal["resolveAwakeable"]
72 |     awakeableKey: str
73 |     value: str
74 |
75 | class RejectAwakeable(TypedDict):
76 |     type: Literal["rejectAwakeable"]
77 |     awakeableKey: str
78 |     reason: str
79 |
80 | class GetEnvVariable(TypedDict):
81 |     type: Literal["getEnvVariable"]
82 |     envName: str
83 |
84 | Command = Union[
85 |     AwaitOne,
86 |     AwaitAny,
87 |     AwaitAnySuccessful,
88 |     AwaitAwakeableOrTimeout,
89 |     ResolveAwakeable,
90 |     RejectAwakeable,
91 |     GetEnvVariable
92 | ]
93 |
94 | class InterpretRequest(TypedDict):
95 |     commands: Iterable[Command]
96 |
97 | @virtual_object_command_interpreter.handler(name="resolveAwakeable", kind="shared")
98 | async def resolve_awakeable(ctx: ObjectSharedContext | ObjectContext, req: ResolveAwakeable):
99 |     awk_id = await ctx.get("awk-" + req['awakeableKey'])
100 |     if not awk_id:
101 |         raise TerminalError(message="No awakeable is registered")
102 |     ctx.resolve_awakeable(awk_id, req['value'])
103 |
104 | @virtual_object_command_interpreter.handler(name="rejectAwakeable", kind="shared")
105 | async def reject_awakeable(ctx: ObjectSharedContext | ObjectContext, req: RejectAwakeable):
106 |     awk_id = await ctx.get("awk-" + req['awakeableKey'])
107 |     if not awk_id:
108 |         raise TerminalError(message="No awakeable is registered")
109 |     ctx.reject_awakeable(awk_id, req['reason'])
110 |
111 | def to_durable_future(ctx: ObjectContext, cmd: AwaitableCommand) -> RestateDurableFuture[Any]:
112 |     if cmd['type'] == "createAwakeable":
113 |         awk_id, awakeable = ctx.awakeable()
114 |         ctx.set("awk-" + cmd['awakeableKey'], awk_id)
115 |         return awakeable
116 |     elif cmd['type'] == "sleep":
117 |         return ctx.sleep(timedelta(milliseconds=cmd['timeoutMillis']))
118 |     elif cmd['type'] == "runThrowTerminalException":
119 |         def side_effect(reason):
120 |             raise TerminalError(message=reason)
121 |         return ctx.run("run should fail command", side_effect, args=(cmd['reason'],))
122 |     raise TerminalError(message=f"Unknown awaitable command type: {cmd['type']}")  # previously fell through and returned None
123 |
124 | @virtual_object_command_interpreter.handler(name="interpretCommands")
125 | async def interpret_commands(ctx: ObjectContext, req: InterpretRequest):
126 |     result = ""
127 |
128 |     for cmd in req['commands']:
129 |         if cmd['type'] == "awaitAwakeableOrTimeout":
130 |             awk_id, awakeable = ctx.awakeable()
131 |             ctx.set("awk-" + cmd['awakeableKey'], awk_id)
132 |             match await select(awakeable=awakeable, timeout=ctx.sleep(timedelta(milliseconds=cmd['timeoutMillis']))):
133 |                 case ['awakeable', awk_res]:
134 |                     result = awk_res
135 |                 case ['timeout', _]:
136 |                     raise TerminalError(message="await-timeout", status_code=500)
137 |         elif cmd['type'] == "resolveAwakeable":
138 |             await resolve_awakeable(ctx, cmd)
139 |             result = ""
140 |         elif cmd['type'] == "rejectAwakeable":
141 |             await reject_awakeable(ctx, cmd)
142 |             result = ""
143 |         elif cmd['type'] == "getEnvVariable":
144 |             env_name = cmd['envName']
145 |             result = await ctx.run("get_env", lambda e=env_name: os.environ.get(e, ""))
146 |         elif cmd['type'] == "awaitOne":
147 |             awaitable = to_durable_future(ctx, cmd['command'])
148 |             # We need this dance because the Python SDK doesn't support .map on futures
149 |             if isinstance(awaitable, RestateDurableSleepFuture):
150 |                 await awaitable
151 |                 result = "sleep"
152 |             else:
153 |                 result = await awaitable
154 |         elif cmd['type'] == "awaitAny":
155 |             futures = [to_durable_future(ctx, c) for c in cmd['commands']]
156 |             done, _ = await wait_completed(*futures)
157 |             done_fut = done[0]
158 |             # We need this dance because the Python SDK doesn't support .map on futures
159 |             if isinstance(done_fut, RestateDurableSleepFuture):
160 |                 await done_fut
161 |                 result = "sleep"
162 |             else:
163 |                 result = await done_fut
164 |         elif cmd['type'] == "awaitAnySuccessful":
165 |             futures = [to_durable_future(ctx, c) for c in cmd['commands']]
166 |             async for done_fut in as_completed(*futures):
167 |                 try:
168 |                     # We need this dance because the Python SDK doesn't support .map on futures
169 |                     if isinstance(done_fut, RestateDurableSleepFuture):
170 |                         await done_fut
171 |                         result = "sleep"
172 |                         break
173 |                     result = await done_fut
174 |                     break
175 |                 except TerminalError:
176 |                     pass
177 |             # NOTE(review): if every future fails with TerminalError, `result` keeps the previous command's value - confirm intended
178 |         last_results = await get_results(ctx)
179 |         last_results.append(result)
180 |         ctx.set("results", last_results)
181 |
182 |     return result
183 |
184 |
--------------------------------------------------------------------------------
/test-services/testservices.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
3 | #
4 | # This file is part of the Restate SDK for Python,
5 | # which is released under the MIT license.
6 | #
7 | # You can find a copy of the license in file LICENSE in the root
8 | # directory of this repository or package, or at
9 | # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
10 | #
11 | """testservices.py"""
12 | # pylint: disable=C0116
13 | # pylint: disable=W0613
14 |
15 | import os
16 | import restate
17 | import services
18 |
19 | def test_services():
20 |     names = os.environ.get('SERVICES')
21 |     return services.services_named(names.split(',')) if names else services.all_services()
22 |
23 | identity_keys = None
24 | e2e_signing_key_env = os.environ.get('E2E_REQUEST_SIGNING_ENV')  # read once; a non-empty value becomes the only identity key
25 | if e2e_signing_key_env:
26 |     identity_keys = [e2e_signing_key_env]
27 |
28 | app = restate.app(services=test_services(), identity_keys=identity_keys)
29 |
--------------------------------------------------------------------------------
/tests/serde.py:
--------------------------------------------------------------------------------
1 | from restate.serde import BytesSerde
2 | # NOTE(review): pytest's default discovery only collects test_*.py / *_test.py - confirm this module is actually collected
3 | def test_bytes_serde():
4 |     s = BytesSerde()
5 |     assert bytes(range(20)) == s.serialize(bytes(range(20)))