├── .dockerignore ├── .github └── workflows │ ├── ci.yml │ ├── cla.yml │ └── integration.yaml ├── .gitignore ├── .gitmodules ├── Cargo.toml ├── LICENSE ├── README.md ├── build.rs ├── endpoint_manifest_schema.json ├── examples ├── counter.rs ├── cron.rs ├── failures.rs ├── greeter.rs ├── run.rs ├── schema.rs ├── services │ ├── mod.rs │ ├── my_service.rs │ ├── my_virtual_object.rs │ └── my_workflow.rs └── tracing.rs ├── justfile ├── macros ├── Cargo.toml └── src │ ├── ast.rs │ ├── gen.rs │ └── lib.rs ├── rust-toolchain.toml ├── src ├── context │ ├── macro_support.rs │ ├── mod.rs │ ├── request.rs │ ├── run.rs │ └── select.rs ├── discovery.rs ├── endpoint │ ├── context.rs │ ├── futures │ │ ├── async_result_poll.rs │ │ ├── durable_future_impl.rs │ │ ├── handler_state_aware.rs │ │ ├── intercept_error.rs │ │ ├── mod.rs │ │ ├── select_poll.rs │ │ └── trap.rs │ ├── handler_state.rs │ └── mod.rs ├── errors.rs ├── filter.rs ├── http_server.rs ├── hyper.rs ├── lib.rs ├── serde.rs └── service.rs ├── test-services ├── Cargo.toml ├── Dockerfile ├── README.md ├── exclusions.yaml └── src │ ├── awakeable_holder.rs │ ├── block_and_wait_workflow.rs │ ├── cancel_test.rs │ ├── counter.rs │ ├── failing.rs │ ├── kill_test.rs │ ├── list_object.rs │ ├── main.rs │ ├── map_object.rs │ ├── non_deterministic.rs │ ├── proxy.rs │ ├── test_utils_service.rs │ └── virtual_object_command_interpreter.rs ├── testcontainers ├── Cargo.toml ├── src │ └── lib.rs └── tests │ └── test_container.rs └── tests ├── compiletest.rs ├── schema.rs ├── service.rs └── ui ├── shared_handler_in_service.rs └── shared_handler_in_service.stderr /.dockerignore: -------------------------------------------------------------------------------- 1 | .github 2 | .idea 3 | target 4 | tests -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | pull_request: 5 | 
workflow_call: 6 | workflow_dispatch: 7 | push: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | build-and-test: 13 | name: Build and test (${{ matrix.os }}) 14 | runs-on: ${{ matrix.os }} 15 | permissions: 16 | contents: read 17 | packages: read 18 | timeout-minutes: 30 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | os: [ubuntu-22.04] 23 | env: 24 | RUST_BACKTRACE: full 25 | steps: 26 | - uses: actions/checkout@v4 27 | 28 | - name: Install Rust toolchain 29 | uses: actions-rust-lang/setup-rust-toolchain@v1 30 | with: 31 | components: clippy 32 | rustflags: "" 33 | 34 | - name: Install nextest 35 | uses: taiki-e/install-action@nextest 36 | 37 | - name: Setup just 38 | uses: extractions/setup-just@v2 39 | env: 40 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 41 | 42 | - name: Run verify 43 | run: just verify 44 | -------------------------------------------------------------------------------- /.github/workflows/cla.yml: -------------------------------------------------------------------------------- 1 | name: "CLA Assistant" 2 | on: 3 | issue_comment: 4 | types: [created] 5 | pull_request_target: 6 | types: [opened, closed, synchronize] 7 | 8 | jobs: 9 | CLAAssistant: 10 | uses: restatedev/restate/.github/workflows/cla.yml@main 11 | secrets: inherit 12 | -------------------------------------------------------------------------------- /.github/workflows/integration.yaml: -------------------------------------------------------------------------------- 1 | name: Integration 2 | 3 | # Controls when the workflow will run 4 | on: 5 | pull_request: 6 | push: 7 | branches: 8 | - main 9 | schedule: 10 | - cron: '0 */6 * * *' # Every 6 hours 11 | workflow_dispatch: 12 | inputs: 13 | restateCommit: 14 | description: 'restate commit' 15 | required: false 16 | default: '' 17 | type: string 18 | restateImage: 19 | description: 'restate image, superseded by restate commit' 20 | required: false 21 | default: 'ghcr.io/restatedev/restate:main' 22 | type: string 23 | workflow_call: 
24 | inputs: 25 | restateCommit: 26 | description: 'restate commit' 27 | required: false 28 | default: '' 29 | type: string 30 | restateImage: 31 | description: 'restate image, superseded by restate commit' 32 | required: false 33 | default: 'ghcr.io/restatedev/restate:main' 34 | type: string 35 | 36 | jobs: 37 | 38 | sdk-test-suite: 39 | if: github.repository_owner == 'restatedev' 40 | runs-on: ubuntu-latest 41 | name: Features integration test 42 | permissions: 43 | contents: read 44 | issues: read 45 | checks: write 46 | pull-requests: write 47 | actions: read 48 | 49 | steps: 50 | - uses: actions/checkout@v4 51 | with: 52 | repository: restatedev/sdk-rust 53 | 54 | - name: Set up Docker containerd snapshotter 55 | uses: crazy-max/ghaction-setup-docker@v3 56 | with: 57 | set-host: true 58 | daemon-config: | 59 | { 60 | "features": { 61 | "containerd-snapshotter": true 62 | } 63 | } 64 | 65 | ### Download the Restate container image, if needed 66 | # Setup restate snapshot if necessary 67 | # Due to https://github.com/actions/upload-artifact/issues/53 68 | # We must use download-artifact to get artifacts created during *this* workflow run, ie by workflow call 69 | - name: Download restate snapshot from in-progress workflow 70 | if: ${{ inputs.restateCommit != '' && github.event_name != 'workflow_dispatch' }} 71 | uses: actions/download-artifact@v4 72 | with: 73 | name: restate.tar 74 | # In the workflow dispatch case where the artifact was created in a previous run, we can download as normal 75 | - name: Download restate snapshot from completed workflow 76 | if: ${{ inputs.restateCommit != '' && github.event_name == 'workflow_dispatch' }} 77 | uses: dawidd6/action-download-artifact@v3 78 | with: 79 | repo: restatedev/restate 80 | workflow: ci.yml 81 | commit: ${{ inputs.restateCommit }} 82 | name: restate.tar 83 | - name: Install restate snapshot 84 | if: ${{ inputs.restateCommit != '' }} 85 | run: | 86 | output=$(docker load --input restate.tar | head -n 1) 87 | 
docker tag "${output#*: }" "localhost/restatedev/restate-commit-download:latest" 88 | docker image ls -a 89 | 90 | - name: Set up QEMU 91 | uses: docker/setup-qemu-action@v3 92 | - name: Set up Docker Buildx 93 | uses: docker/setup-buildx-action@v3 94 | 95 | - name: Build Rust test-services image 96 | id: build 97 | uses: docker/build-push-action@v6 98 | with: 99 | context: . 100 | file: "test-services/Dockerfile" 101 | push: false 102 | load: true 103 | tags: restatedev/rust-test-services 104 | cache-from: type=gha,scope=${{ github.workflow }} 105 | cache-to: type=gha,mode=max,scope=${{ github.workflow }} 106 | 107 | - name: Run test tool 108 | uses: restatedev/sdk-test-suite@v3.0 109 | with: 110 | restateContainerImage: ${{ inputs.restateCommit != '' && 'localhost/restatedev/restate-commit-download:latest' || (inputs.restateImage != '' && inputs.restateImage || 'ghcr.io/restatedev/restate:main') }} 111 | serviceContainerImage: "restatedev/rust-test-services" 112 | exclusionsFile: "test-services/exclusions.yaml" 113 | testArtifactOutput: "sdk-rust-integration-test-report" 114 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | /target/ 4 | 5 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 6 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 7 | # Cargo.lock 8 | 9 | # These are backup files generated by rustfmt 10 | **/*.rs.bk 11 | 12 | # Jetbrains IDEs 13 | .idea 14 | *.iml 15 | 16 | Cargo.lock -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "restate-proto"] 2 | path = restate-proto 3 | url = ../proto.git 
-------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "restate-sdk" 3 | version = "0.5.0" 4 | edition = "2021" 5 | description = "Restate SDK for Rust" 6 | license = "MIT" 7 | repository = "https://github.com/restatedev/sdk-rust" 8 | rust-version = "1.76.0" 9 | 10 | [[example]] 11 | name = "tracing" 12 | path = "examples/tracing.rs" 13 | required-features = ["tracing-span-filter"] 14 | 15 | [[example]] 16 | name = "schema" 17 | path = "examples/schema.rs" 18 | required-features = ["schemars"] 19 | 20 | [features] 21 | default = ["http_server", "rand", "uuid", "tracing-span-filter"] 22 | hyper = ["dep:hyper", "http-body-util", "restate-sdk-shared-core/http"] 23 | http_server = ["hyper", "hyper/server", "hyper/http2", "hyper-util", "tokio/net", "tokio/signal", "tokio/macros"] 24 | tracing-span-filter = ["dep:tracing-subscriber"] 25 | 26 | [dependencies] 27 | bytes = "1.10" 28 | futures = "0.3" 29 | http = "1.3" 30 | http-body-util = { version = "0.1", optional = true } 31 | hyper = { version = "1.6", optional = true} 32 | hyper-util = { version = "0.1", features = ["tokio", "server", "server-graceful", "http2"], optional = true } 33 | pin-project-lite = "0.2" 34 | rand = { version = "0.9", optional = true } 35 | regress = "0.10" 36 | restate-sdk-macros = { version = "0.5", path = "macros" } 37 | restate-sdk-shared-core = { version = "0.3.0", features = ["request_identity", "sha2_random_seed", "http"] } 38 | schemars = { version = "1.0.0-alpha.17", optional = true } 39 | serde = "1.0" 40 | serde_json = "1.0" 41 | thiserror = "2.0" 42 | tokio = { version = "1.44", default-features = false, features = ["sync"] } 43 | tracing = "0.1" 44 | tracing-subscriber = { version = "0.3", features = ["registry"], optional = true } 45 | uuid = { version = "1.16.0", optional = true } 46 | 47 | [dev-dependencies] 48 | tokio = { version = 
"1", features = ["full"] } 49 | tracing-subscriber = { version = "0.3", features = ["env-filter", "registry"] } 50 | trybuild = "1.0" 51 | reqwest = { version = "0.12", features = ["json"] } 52 | rand = "0.9" 53 | schemars = "1.0.0-alpha.17" 54 | 55 | [build-dependencies] 56 | jsonptr = "0.5.1" 57 | prettyplease = "0.2" 58 | serde_json = { version = "1.0" } 59 | syn = "2.0" 60 | typify = { version = "0.1.0" } 61 | 62 | [workspace] 63 | members = ["macros", "test-services", "testcontainers"] 64 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 - Restate Software, Inc., Restate GmbH 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Documentation](https://img.shields.io/docsrs/restate-sdk)](https://docs.rs/restate-sdk) 2 | [![crates.io](https://img.shields.io/crates/v/restate-sdk.svg)](https://crates.io/crates/restate-sdk/) 3 | [![Examples](https://img.shields.io/badge/view-examples-blue)](https://github.com/restatedev/examples) 4 | [![Discord](https://img.shields.io/discord/1128210118216007792?logo=discord)](https://discord.gg/skW3AZ6uGd) 5 | [![Twitter](https://img.shields.io/twitter/follow/restatedev.svg?style=social&label=Follow)](https://twitter.com/intent/follow?screen_name=restatedev) 6 | 7 | # Restate Rust SDK 8 | 9 | [Restate](https://restate.dev/) is a system for easily building resilient applications using _distributed durable async/await_. This repository contains the Restate SDK for writing services using Rust. 10 | 11 | ## Community 12 | 13 | * 🤗️ [Join our online community](https://discord.gg/skW3AZ6uGd) for help, sharing feedback and talking to the community. 14 | * 📖 [Check out our documentation](https://docs.restate.dev) to get started quickly! 15 | * 📣 [Follow us on Twitter](https://twitter.com/restatedev) for staying up to date. 16 | * 🙋 [Create a GitHub issue](https://github.com/restatedev/sdk-rust/issues) for requesting a new feature or reporting a problem. 17 | * 🏠 [Visit our GitHub org](https://github.com/restatedev) for exploring other repositories.
18 | 19 | ## Using the SDK 20 | 21 | Add Restate and Tokio as dependencies: 22 | 23 | ```toml 24 | [dependencies] 25 | restate-sdk = "0.5" 26 | tokio = { version = "1", features = ["full"] } 27 | ``` 28 | 29 | Then you're ready to develop your Restate service using Rust: 30 | 31 | ```rust 32 | use restate_sdk::prelude::*; 33 | 34 | #[restate_sdk::service] 35 | trait Greeter { 36 | async fn greet(name: String) -> HandlerResult<String>; 37 | } 38 | 39 | struct GreeterImpl; 40 | 41 | impl Greeter for GreeterImpl { 42 | async fn greet(&self, _: Context<'_>, name: String) -> HandlerResult<String> { 43 | Ok(format!("Greetings {name}")) 44 | } 45 | } 46 | 47 | #[tokio::main] 48 | async fn main() { 49 | // To enable logging/tracing 50 | // tracing_subscriber::fmt::init(); 51 | HttpServer::new( 52 | Endpoint::builder() 53 | .bind(GreeterImpl.serve()) 54 | .build(), 55 | ) 56 | .listen_and_serve("0.0.0.0:9080".parse().unwrap()) 57 | .await; 58 | } 59 | ``` 60 | 61 | ### Logging 62 | 63 | The SDK uses tokio's [`tracing`](https://docs.rs/tracing/latest/tracing/) crate to generate logs. 64 | Just configure it as usual through [`tracing_subscriber`](https://docs.rs/tracing-subscriber/latest/tracing_subscriber/) to get your logs. 65 | 66 | ### Testing 67 | 68 | The SDK uses [Testcontainers](https://rust.testcontainers.org/) to support integration testing using a Docker-deployed Restate server. 69 | The `restate-sdk-testcontainers` crate provides a framework for initializing the test environment, and an integration test example in `testcontainers/tests/test_container.rs`.
70 | 71 | ```rust 72 | #[tokio::test] 73 | async fn test_container() { 74 | tracing_subscriber::fmt::fmt() 75 | .with_max_level(tracing::Level::INFO) // Set the maximum log level 76 | .init(); 77 | 78 | let endpoint = Endpoint::builder().bind(MyServiceImpl.serve()).build(); 79 | 80 | // simple test container initialization with default configuration 81 | //let test_container = TestContainer::default().start(endpoint).await.unwrap(); 82 | 83 | // custom test container initialization with builder 84 | let test_container = TestContainer::builder() 85 | // optional passthrough logging from the Restate server testcontainer 86 | // prints container logs to tracing::info level 87 | .with_container_logging() 88 | .with_container( 89 | "docker.io/restatedev/restate".to_string(), 90 | "latest".to_string(), 91 | ) 92 | .build() 93 | .start(endpoint) 94 | .await 95 | .unwrap(); 96 | 97 | let ingress_url = test_container.ingress_url(); 98 | 99 | // call container ingress url for /MyService/my_handler 100 | let response = reqwest::Client::new() 101 | .post(format!("{}/MyService/my_handler", ingress_url)) 102 | .header("Accept", "application/json") 103 | .header("Content-Type", "*/*") 104 | .header("idempotency-key", "abc") 105 | .send() 106 | .await 107 | .unwrap(); 108 | 109 | assert_eq!(response.status(), StatusCode::OK); 110 | 111 | info!( 112 | "/MyService/my_handler response: {:?}", 113 | response.text().await.unwrap() 114 | ); 115 | } 116 | ``` 117 | 118 | ## Versions 119 | 120 | The Rust SDK is currently in active development, and releases might introduce breaking changes.
121 | 122 | The compatibility with Restate is described in the following table: 123 | 124 | | Restate Server\sdk-rust | 0.0 - 0.2 | 0.3 | 0.4 - 0.5 | 125 | |-------------------------|-----------|-----|-----------| 126 | | 1.0 | ✅ | ❌ | ❌ | 127 | | 1.1 | ✅ | ✅ | ❌ | 128 | | 1.2 | ✅ | ✅ | ❌ | 129 | | 1.3 | ✅ | ✅ | ✅ | 130 | 131 | ## Contributing 132 | 133 | We’re excited if you join the Restate community and start contributing! 134 | Whether it is feature requests, bug reports, ideas & feedback or PRs, we appreciate any and all contributions. 135 | We know that your time is precious and, therefore, deeply value any effort to contribute! 136 | 137 | ### Building the SDK locally 138 | 139 | Prerequisites: 140 | 141 | - [Rust](https://rustup.rs/) 142 | - [Just](https://github.com/casey/just) 143 | 144 | To build and test the SDK: 145 | 146 | ```shell 147 | just verify 148 | ``` 149 | 150 | ### Releasing 151 | 152 | You need the [Rust toolchain](https://rustup.rs/). To verify: 153 | 154 | ``` 155 | just verify 156 | ``` 157 | 158 | To release we use [cargo-release](https://github.com/crate-ci/cargo-release): 159 | 160 | ``` 161 | cargo release --exclude test-services --workspace 162 | ``` 163 | -------------------------------------------------------------------------------- /build.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 - Restate Software, Inc., Restate GmbH. 2 | // All rights reserved. 3 | // 4 | // Use of this software is governed by the Business Source License 5 | // included in the LICENSE file. 6 | // 7 | // As of the Change Date specified in that file, in accordance with 8 | // the Business Source License, use of this software will be governed 9 | // by the Apache License, Version 2.0. 
10 | 11 | use jsonptr::Pointer; 12 | use std::env; 13 | use std::fs::File; 14 | use std::path::PathBuf; 15 | use typify::{TypeSpace, TypeSpaceSettings}; 16 | 17 | fn main() -> std::io::Result<()> { 18 | let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); 19 | 20 | let mut parsed_content: serde_json::Value = 21 | serde_json::from_reader(File::open("./endpoint_manifest_schema.json").unwrap()).unwrap(); 22 | 23 | // Patch schema for https://github.com/oxidecomputer/typify/issues/531 24 | // We can get rid of this once the issue in typify is solved. 25 | Pointer::from_static( 26 | "/properties/services/items/properties/handlers/items/properties/input/default", 27 | ) 28 | .delete(&mut parsed_content); 29 | Pointer::from_static( 30 | "/properties/services/items/properties/handlers/items/properties/input/examples", 31 | ) 32 | .delete(&mut parsed_content); 33 | Pointer::from_static( 34 | "/properties/services/items/properties/handlers/items/properties/output/default", 35 | ) 36 | .delete(&mut parsed_content); 37 | Pointer::from_static( 38 | "/properties/services/items/properties/handlers/items/properties/output/examples", 39 | ) 40 | .delete(&mut parsed_content); 41 | 42 | // Instantiate type space and run code-generation 43 | let mut type_space = 44 | TypeSpace::new(TypeSpaceSettings::default().with_derive("Clone".to_owned())); 45 | type_space 46 | .add_root_schema(serde_json::from_value(parsed_content).unwrap()) 47 | .unwrap(); 48 | 49 | let contents = format!( 50 | "{}\n{}", 51 | "use serde::{Deserialize, Serialize};", 52 | prettyplease::unparse(&syn::parse2::(type_space.to_stream()).unwrap()) 53 | ); 54 | 55 | std::fs::write(out_dir.join("endpoint_manifest.rs"), contents) 56 | } 57 | -------------------------------------------------------------------------------- /endpoint_manifest_schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "$id": "https://restate.dev/endpoint.manifest.json", 3 | "$schema": 
"https://json-schema.org/draft/2020-12/schema", 4 | "type": "object", 5 | "title": "Endpoint", 6 | "description": "Restate endpoint manifest v1", 7 | "properties": { 8 | "protocolMode": { 9 | "title": "ProtocolMode", 10 | "enum": ["BIDI_STREAM", "REQUEST_RESPONSE"] 11 | }, 12 | "minProtocolVersion": { 13 | "type": "integer", 14 | "minimum": 1, 15 | "maximum": 2147483647, 16 | "description": "Minimum supported protocol version" 17 | }, 18 | "maxProtocolVersion": { 19 | "type": "integer", 20 | "minimum": 1, 21 | "maximum": 2147483647, 22 | "description": "Maximum supported protocol version" 23 | }, 24 | "services": { 25 | "type": "array", 26 | "items": { 27 | "type": "object", 28 | "title": "Service", 29 | "properties": { 30 | "name": { 31 | "type": "string", 32 | "pattern": "^([a-zA-Z]|_[a-zA-Z0-9])[a-zA-Z0-9._-]*$" 33 | }, 34 | "ty": { 35 | "title": "ServiceType", 36 | "enum": ["VIRTUAL_OBJECT", "SERVICE", "WORKFLOW"] 37 | }, 38 | "handlers": { 39 | "type": "array", 40 | "items": { 41 | "type": "object", 42 | "title": "Handler", 43 | "properties": { 44 | "name": { 45 | "type": "string", 46 | "pattern": "^([a-zA-Z]|_[a-zA-Z0-9])[a-zA-Z0-9_]*$" 47 | }, 48 | "ty": { 49 | "title": "HandlerType", 50 | "enum": ["WORKFLOW", "EXCLUSIVE", "SHARED"], 51 | "description": "If unspecified, defaults to EXCLUSIVE for Virtual Object or WORKFLOW for Workflows. This should be unset for Services." 52 | }, 53 | "input": { 54 | "type": "object", 55 | "title": "InputPayload", 56 | "description": "Description of an input payload. This will be used by Restate to validate incoming requests.", 57 | "properties": { 58 | "required": { 59 | "type": "boolean", 60 | "description": "If true, a body MUST be sent with a content-type, even if the body length is zero." 61 | }, 62 | "contentType": { 63 | "type": "string", 64 | "description": "Content type of the input. It can accept wildcards, in the same format as the 'Accept' header. 
When this field is unset, it implies emptiness, meaning no content-type/body is expected." 65 | }, 66 | "jsonSchema": {} 67 | }, 68 | "additionalProperties": false, 69 | "default": { 70 | "contentType": "*/*", 71 | "required": false 72 | }, 73 | "examples": { 74 | "empty input": {}, 75 | "non empty json input": { 76 | "required": true, 77 | "contentType": "application/json", 78 | "jsonSchema": true 79 | }, 80 | "either empty or non empty json input": { 81 | "required": false, 82 | "contentType": "application/json", 83 | "jsonSchema": true 84 | }, 85 | "bytes input": { 86 | "required": true, 87 | "contentType": "application/octet-stream" 88 | } 89 | } 90 | }, 91 | "output": { 92 | "type": "object", 93 | "title": "OutputPayload", 94 | "description": "Description of an output payload.", 95 | "properties": { 96 | "contentType": { 97 | "type": "string", 98 | "description": "Content type set on output. This will be used by Restate to set the output content type at the ingress." 99 | }, 100 | "setContentTypeIfEmpty": { 101 | "type": "boolean", 102 | "description": "If true, the specified content-type is set even if the output is empty." 
103 | }, 104 | "jsonSchema": {} 105 | }, 106 | "additionalProperties": false, 107 | "default": { 108 | "contentType": "application/json", 109 | "setContentTypeIfEmpty": false 110 | }, 111 | "examples": { 112 | "empty output": { 113 | "setContentTypeIfEmpty": false 114 | }, 115 | "non-empty json output": { 116 | "contentType": "application/json", 117 | "setContentTypeIfEmpty": false, 118 | "jsonSchema": true 119 | }, 120 | "protobuf output": { 121 | "contentType": "application/proto", 122 | "setContentTypeIfEmpty": true 123 | } 124 | } 125 | } 126 | }, 127 | "required": ["name"], 128 | "additionalProperties": false 129 | } 130 | } 131 | }, 132 | "required": ["name", "ty", "handlers"], 133 | "additionalProperties": false 134 | } 135 | } 136 | }, 137 | "required": ["minProtocolVersion", "maxProtocolVersion", "services"], 138 | "additionalProperties": false 139 | } 140 | -------------------------------------------------------------------------------- /examples/counter.rs: -------------------------------------------------------------------------------- 1 | use restate_sdk::prelude::*; 2 | 3 | #[restate_sdk::object] 4 | trait Counter { 5 | #[shared] 6 | async fn get() -> Result; 7 | async fn add(val: u64) -> Result; 8 | async fn increment() -> Result; 9 | async fn reset() -> Result<(), TerminalError>; 10 | } 11 | 12 | struct CounterImpl; 13 | 14 | const COUNT: &str = "count"; 15 | 16 | impl Counter for CounterImpl { 17 | async fn get(&self, ctx: SharedObjectContext<'_>) -> Result { 18 | Ok(ctx.get::(COUNT).await?.unwrap_or(0)) 19 | } 20 | 21 | async fn add(&self, ctx: ObjectContext<'_>, val: u64) -> Result { 22 | let current = ctx.get::(COUNT).await?.unwrap_or(0); 23 | let new = current + val; 24 | ctx.set(COUNT, new); 25 | Ok(new) 26 | } 27 | 28 | async fn increment(&self, ctx: ObjectContext<'_>) -> Result { 29 | self.add(ctx, 1).await 30 | } 31 | 32 | async fn reset(&self, ctx: ObjectContext<'_>) -> Result<(), TerminalError> { 33 | ctx.clear(COUNT); 34 | Ok(()) 35 | } 
36 | } 37 | 38 | #[tokio::main] 39 | async fn main() { 40 | tracing_subscriber::fmt::init(); 41 | HttpServer::new(Endpoint::builder().bind(CounterImpl.serve()).build()) 42 | .listen_and_serve("0.0.0.0:9080".parse().unwrap()) 43 | .await; 44 | } 45 | -------------------------------------------------------------------------------- /examples/cron.rs: -------------------------------------------------------------------------------- 1 | use restate_sdk::prelude::*; 2 | use std::time::Duration; 3 | 4 | /// This example shows how to implement a periodic task, by invoking itself in a loop. 5 | /// 6 | /// The `start()` handler schedules the first call to `run()`, and then each `run()` will re-schedule itself. 7 | /// 8 | /// To "break" the loop, we use a flag we persist in state, which is removed when `stop()` is invoked. 9 | /// Its presence determines whether the task is active or not. 10 | /// 11 | /// To start it: 12 | /// 13 | /// ```shell 14 | /// $ curl -v http://localhost:8080/PeriodicTask/my-periodic-task/start 15 | /// ``` 16 | #[restate_sdk::object] 17 | trait PeriodicTask { 18 | /// Schedules the periodic task to start 19 | async fn start() -> Result<(), TerminalError>; 20 | /// Stops the periodic task 21 | async fn stop() -> Result<(), TerminalError>; 22 | /// Business logic of the periodic task 23 | async fn run() -> Result<(), TerminalError>; 24 | } 25 | 26 | struct PeriodicTaskImpl; 27 | 28 | const ACTIVE: &str = "active"; 29 | 30 | impl PeriodicTask for PeriodicTaskImpl { 31 | async fn start(&self, context: ObjectContext<'_>) -> Result<(), TerminalError> { 32 | if context 33 | .get::(ACTIVE) 34 | .await? 
35 | .is_some_and(|enabled| enabled) 36 | { 37 | // If it's already activated, just do nothing 38 | return Ok(()); 39 | } 40 | 41 | // Schedule the periodic task 42 | PeriodicTaskImpl::schedule_next(&context); 43 | 44 | // Mark the periodic task as active 45 | context.set(ACTIVE, true); 46 | 47 | Ok(()) 48 | } 49 | 50 | async fn stop(&self, context: ObjectContext<'_>) -> Result<(), TerminalError> { 51 | // Remove the active flag 52 | context.clear(ACTIVE); 53 | 54 | Ok(()) 55 | } 56 | 57 | async fn run(&self, context: ObjectContext<'_>) -> Result<(), TerminalError> { 58 | if context.get::(ACTIVE).await?.is_none() { 59 | // Task is inactive, do nothing 60 | return Ok(()); 61 | } 62 | 63 | // --- Periodic task business logic! 64 | println!("Triggered the periodic task!"); 65 | 66 | // Schedule the periodic task 67 | PeriodicTaskImpl::schedule_next(&context); 68 | 69 | Ok(()) 70 | } 71 | } 72 | 73 | impl PeriodicTaskImpl { 74 | fn schedule_next(context: &ObjectContext<'_>) { 75 | // To schedule, create a client to the callee handler (in this case, we're calling ourselves) 76 | context 77 | .object_client::(context.key()) 78 | .run() 79 | // And send with a delay 80 | .send_after(Duration::from_secs(10)); 81 | } 82 | } 83 | 84 | #[tokio::main] 85 | async fn main() { 86 | tracing_subscriber::fmt::init(); 87 | HttpServer::new(Endpoint::builder().bind(PeriodicTaskImpl.serve()).build()) 88 | .listen_and_serve("0.0.0.0:9080".parse().unwrap()) 89 | .await; 90 | } 91 | -------------------------------------------------------------------------------- /examples/failures.rs: -------------------------------------------------------------------------------- 1 | use rand::RngCore; 2 | use restate_sdk::prelude::*; 3 | 4 | #[restate_sdk::service] 5 | trait FailureExample { 6 | #[name = "doRun"] 7 | async fn do_run() -> Result<(), TerminalError>; 8 | } 9 | 10 | struct FailureExampleImpl; 11 | 12 | #[derive(Debug, thiserror::Error)] 13 | #[error("I'm very bad, retry me")] 14 | struct 
MyError; 15 | 16 | impl FailureExample for FailureExampleImpl { 17 | async fn do_run(&self, context: Context<'_>) -> Result<(), TerminalError> { 18 | context 19 | .run::<_, _, ()>(|| async move { 20 | if rand::rng().next_u32() % 4 == 0 { 21 | Err(TerminalError::new("Failed!!!"))? 22 | } 23 | 24 | Err(MyError)? 25 | }) 26 | .await?; 27 | 28 | Ok(()) 29 | } 30 | } 31 | 32 | #[tokio::main] 33 | async fn main() { 34 | tracing_subscriber::fmt::init(); 35 | HttpServer::new(Endpoint::builder().bind(FailureExampleImpl.serve()).build()) 36 | .listen_and_serve("0.0.0.0:9080".parse().unwrap()) 37 | .await; 38 | } 39 | -------------------------------------------------------------------------------- /examples/greeter.rs: -------------------------------------------------------------------------------- 1 | use restate_sdk::prelude::*; 2 | use std::convert::Infallible; 3 | 4 | #[restate_sdk::service] 5 | trait Greeter { 6 | async fn greet(name: String) -> Result; 7 | } 8 | 9 | struct GreeterImpl; 10 | 11 | impl Greeter for GreeterImpl { 12 | async fn greet(&self, _: Context<'_>, name: String) -> Result { 13 | Ok(format!("Greetings {name}")) 14 | } 15 | } 16 | 17 | #[tokio::main] 18 | async fn main() { 19 | tracing_subscriber::fmt::init(); 20 | HttpServer::new(Endpoint::builder().bind(GreeterImpl.serve()).build()) 21 | .listen_and_serve("0.0.0.0:9080".parse().unwrap()) 22 | .await; 23 | } 24 | -------------------------------------------------------------------------------- /examples/run.rs: -------------------------------------------------------------------------------- 1 | use restate_sdk::prelude::*; 2 | use std::collections::HashMap; 3 | 4 | #[restate_sdk::service] 5 | trait RunExample { 6 | async fn do_run() -> Result>, HandlerError>; 7 | } 8 | 9 | struct RunExampleImpl(reqwest::Client); 10 | 11 | impl RunExample for RunExampleImpl { 12 | async fn do_run( 13 | &self, 14 | context: Context<'_>, 15 | ) -> Result>, HandlerError> { 16 | let res = context 17 | .run(|| async move { 
18 | let req = self.0.get("https://httpbin.org/ip").build()?; 19 | 20 | let res = self 21 | .0 22 | .execute(req) 23 | .await? 24 | .json::>() 25 | .await?; 26 | 27 | Ok(Json::from(res)) 28 | }) 29 | .name("get_ip") 30 | .await? 31 | .into_inner(); 32 | 33 | Ok(res.into()) 34 | } 35 | } 36 | 37 | #[tokio::main] 38 | async fn main() { 39 | tracing_subscriber::fmt::init(); 40 | HttpServer::new( 41 | Endpoint::builder() 42 | .bind(RunExampleImpl(reqwest::Client::new()).serve()) 43 | .build(), 44 | ) 45 | .listen_and_serve("0.0.0.0:9080".parse().unwrap()) 46 | .await; 47 | } 48 | -------------------------------------------------------------------------------- /examples/schema.rs: -------------------------------------------------------------------------------- 1 | //! Run with auto-generated schemas for `Json` using `schemars`: 2 | //! cargo run --example schema --features schemars 3 | //! 4 | //! Run with primitive schemas only: 5 | //! cargo run --example schema 6 | 7 | use restate_sdk::prelude::*; 8 | use schemars::JsonSchema; 9 | use serde::{Deserialize, Serialize}; 10 | use std::time::Duration; 11 | 12 | #[derive(Serialize, Deserialize, JsonSchema)] 13 | struct Product { 14 | id: String, 15 | name: String, 16 | price_cents: u32, 17 | } 18 | 19 | #[restate_sdk::service] 20 | trait CatalogService { 21 | async fn get_product_by_id(product_id: String) -> Result, HandlerError>; 22 | async fn save_product(product: Json) -> Result; 23 | async fn is_in_stock(product_id: String) -> Result; 24 | } 25 | 26 | struct CatalogServiceImpl; 27 | 28 | impl CatalogService for CatalogServiceImpl { 29 | async fn get_product_by_id( 30 | &self, 31 | ctx: Context<'_>, 32 | product_id: String, 33 | ) -> Result, HandlerError> { 34 | ctx.sleep(Duration::from_millis(50)).await?; 35 | Ok(Json(Product { 36 | id: product_id, 37 | name: "Sample Product".to_string(), 38 | price_cents: 1995, 39 | })) 40 | } 41 | 42 | async fn save_product( 43 | &self, 44 | _ctx: Context<'_>, 45 | product: Json, 46 
| ) -> Result { 47 | Ok(product.0.id) 48 | } 49 | 50 | async fn is_in_stock( 51 | &self, 52 | _ctx: Context<'_>, 53 | product_id: String, 54 | ) -> Result { 55 | Ok(!product_id.contains("out-of-stock")) 56 | } 57 | } 58 | 59 | #[tokio::main] 60 | async fn main() { 61 | tracing_subscriber::fmt::init(); 62 | HttpServer::new(Endpoint::builder().bind(CatalogServiceImpl.serve()).build()) 63 | .listen_and_serve("0.0.0.0:9080".parse().unwrap()) 64 | .await; 65 | } 66 | -------------------------------------------------------------------------------- /examples/services/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod my_service; 2 | pub mod my_virtual_object; 3 | pub mod my_workflow; -------------------------------------------------------------------------------- /examples/services/my_service.rs: -------------------------------------------------------------------------------- 1 | use restate_sdk::prelude::*; 2 | 3 | #[restate_sdk::service] 4 | pub trait MyService { 5 | async fn my_handler(greeting: String) -> Result; 6 | } 7 | 8 | pub struct MyServiceImpl; 9 | 10 | impl MyService for MyServiceImpl { 11 | async fn my_handler(&self, _ctx: Context<'_>, greeting: String) -> Result { 12 | Ok(format!("{greeting}!")) 13 | } 14 | } 15 | 16 | #[tokio::main] 17 | async fn main() { 18 | tracing_subscriber::fmt::init(); 19 | HttpServer::new(Endpoint::builder().bind(MyServiceImpl.serve()).build()) 20 | .listen_and_serve("0.0.0.0:9080".parse().unwrap()) 21 | .await; 22 | } 23 | -------------------------------------------------------------------------------- /examples/services/my_virtual_object.rs: -------------------------------------------------------------------------------- 1 | use restate_sdk::prelude::*; 2 | 3 | #[restate_sdk::object] 4 | pub trait MyVirtualObject { 5 | async fn my_handler(name: String) -> Result; 6 | #[shared] 7 | async fn my_concurrent_handler(name: String) -> Result; 8 | } 9 | 10 | pub struct 
MyVirtualObjectImpl; 11 | 12 | impl MyVirtualObject for MyVirtualObjectImpl { 13 | async fn my_handler( 14 | &self, 15 | ctx: ObjectContext<'_>, 16 | greeting: String, 17 | ) -> Result { 18 | Ok(format!("Greetings {} {}", greeting, ctx.key())) 19 | } 20 | async fn my_concurrent_handler( 21 | &self, 22 | ctx: SharedObjectContext<'_>, 23 | greeting: String, 24 | ) -> Result { 25 | Ok(format!("Greetings {} {}", greeting, ctx.key())) 26 | } 27 | } 28 | 29 | #[tokio::main] 30 | async fn main() { 31 | tracing_subscriber::fmt::init(); 32 | HttpServer::new( 33 | Endpoint::builder() 34 | .bind(MyVirtualObjectImpl.serve()) 35 | .build(), 36 | ) 37 | .listen_and_serve("0.0.0.0:9080".parse().unwrap()) 38 | .await; 39 | } 40 | -------------------------------------------------------------------------------- /examples/services/my_workflow.rs: -------------------------------------------------------------------------------- 1 | use restate_sdk::prelude::*; 2 | 3 | #[restate_sdk::workflow] 4 | pub trait MyWorkflow { 5 | async fn run(req: String) -> Result; 6 | #[shared] 7 | async fn interact_with_workflow() -> Result<(), HandlerError>; 8 | } 9 | 10 | pub struct MyWorkflowImpl; 11 | 12 | impl MyWorkflow for MyWorkflowImpl { 13 | async fn run(&self, _ctx: WorkflowContext<'_>, _req: String) -> Result { 14 | // implement workflow logic here 15 | 16 | Ok(String::from("success")) 17 | } 18 | async fn interact_with_workflow( 19 | &self, 20 | _ctx: SharedWorkflowContext<'_>, 21 | ) -> Result<(), HandlerError> { 22 | // implement interaction logic here 23 | // e.g. 
resolve a promise that the workflow is waiting on 24 | 25 | Ok(()) 26 | } 27 | } 28 | 29 | #[tokio::main] 30 | async fn main() { 31 | tracing_subscriber::fmt::init(); 32 | HttpServer::new(Endpoint::builder().bind(MyWorkflowImpl.serve()).build()) 33 | .listen_and_serve("0.0.0.0:9080".parse().unwrap()) 34 | .await; 35 | } 36 | -------------------------------------------------------------------------------- /examples/tracing.rs: -------------------------------------------------------------------------------- 1 | use restate_sdk::prelude::*; 2 | use std::time::Duration; 3 | use tracing::info; 4 | use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer}; 5 | 6 | #[restate_sdk::service] 7 | trait Greeter { 8 | async fn greet(name: String) -> Result; 9 | } 10 | 11 | struct GreeterImpl; 12 | 13 | impl Greeter for GreeterImpl { 14 | async fn greet(&self, ctx: Context<'_>, name: String) -> Result { 15 | info!("Before sleep"); 16 | ctx.sleep(Duration::from_secs(61)).await?; // More than suspension timeout to trigger replay 17 | info!("After sleep"); 18 | Ok(format!("Greetings {name}")) 19 | } 20 | } 21 | 22 | #[tokio::main] 23 | async fn main() { 24 | let env_filter = tracing_subscriber::EnvFilter::try_from_default_env() 25 | .unwrap_or_else(|_| "info,restate_sdk=debug".into()); 26 | let replay_filter = restate_sdk::filter::ReplayAwareFilter; 27 | tracing_subscriber::registry() 28 | .with( 29 | tracing_subscriber::fmt::layer() 30 | .with_filter(env_filter) 31 | .with_filter(replay_filter), 32 | ) 33 | .init(); 34 | HttpServer::new(Endpoint::builder().bind(GreeterImpl.serve()).build()) 35 | .listen_and_serve("0.0.0.0:9080".parse().unwrap()) 36 | .await; 37 | } 38 | -------------------------------------------------------------------------------- /justfile: -------------------------------------------------------------------------------- 1 | features := "" 2 | libc := "gnu" 3 | arch := "" # use the default architecture 4 | os := "" # use the default os 5 | 6 
| _features := if features == "all" { 7 | "--all-features" 8 | } else if features != "" { 9 | "--features=" + features 10 | } else { "" } 11 | 12 | _arch := if arch == "" { 13 | arch() 14 | } else if arch == "amd64" { 15 | "x86_64" 16 | } else if arch == "x86_64" { 17 | "x86_64" 18 | } else if arch == "arm64" { 19 | "aarch64" 20 | } else if arch == "aarch64" { 21 | "aarch64" 22 | } else { 23 | error("unsupported arch=" + arch) 24 | } 25 | 26 | _os := if os == "" { 27 | os() 28 | } else { 29 | os 30 | } 31 | 32 | _os_target := if _os == "macos" { 33 | "apple-darwin" 34 | } else if _os == "linux" { 35 | "unknown-linux" 36 | } else { 37 | error("unsupported os=" + _os) 38 | } 39 | 40 | _default_target := `rustc -vV | sed -n 's|host: ||p'` 41 | target := _arch + "-" + _os_target + if _os == "linux" { "-" + libc } else { "" } 42 | _resolved_target := if target != _default_target { target } else { "" } 43 | _target-option := if _resolved_target != "" { "--target " + _resolved_target } else { "" } 44 | 45 | clean: 46 | cargo clean 47 | 48 | fmt: 49 | cargo fmt --all 50 | 51 | check-fmt: 52 | cargo fmt --all -- --check 53 | 54 | clippy: (_target-installed target) 55 | cargo clippy {{ _target-option }} --all-targets --workspace -- -D warnings 56 | 57 | # Runs all lints (fmt, clippy, deny) 58 | lint: check-fmt clippy 59 | 60 | build *flags: (_target-installed target) 61 | cargo build {{ _target-option }} {{ _features }} {{ flags }} 62 | 63 | print-target: 64 | @echo {{ _resolved_target }} 65 | 66 | test: (_target-installed target) 67 | cargo nextest run {{ _target-option }} --all-features --workspace 68 | 69 | doctest: 70 | cargo test --doc 71 | 72 | # Runs lints and tests 73 | verify: lint test doctest 74 | 75 | udeps *flags: 76 | RUSTC_BOOTSTRAP=1 cargo udeps --all-features --all-targets {{ flags }} 77 | 78 | _target-installed target: 79 | #!/usr/bin/env bash 80 | set -euo pipefail 81 | if ! 
rustup target list --installed |grep -qF '{{ target }}' 2>/dev/null ; then 82 | rustup target add '{{ target }}' 83 | fi 84 | -------------------------------------------------------------------------------- /macros/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "restate-sdk-macros" 3 | version = "0.5.0" 4 | edition = "2021" 5 | description = "Restate SDK for Rust macros" 6 | license = "MIT" 7 | repository = "https://github.com/restatedev/sdk-rust" 8 | 9 | [lib] 10 | proc-macro = true 11 | 12 | [dependencies] 13 | proc-macro2 = "1.0" 14 | quote = "1.0" 15 | syn = { version = "2.0", features = ["full"] } 16 | -------------------------------------------------------------------------------- /macros/src/ast.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 - Restate Software, Inc., Restate GmbH. 2 | // All rights reserved. 3 | // 4 | // Use of this software is governed by the Business Source License 5 | // included in the LICENSE file. 6 | // 7 | // As of the Change Date specified in that file, in accordance with 8 | // the Business Source License, use of this software will be governed 9 | // by the Apache License, Version 2.0. 10 | 11 | // Some parts copied from https://github.com/dtolnay/thiserror/blob/39aaeb00ff270a49e3c254d7b38b10e934d3c7a5/impl/src/ast.rs 12 | // License Apache-2.0 or MIT 13 | 14 | use syn::ext::IdentExt; 15 | use syn::parse::{Parse, ParseStream}; 16 | use syn::spanned::Spanned; 17 | use syn::token::Comma; 18 | use syn::{ 19 | braced, parenthesized, parse_quote, Attribute, Error, Expr, ExprLit, FnArg, GenericArgument, 20 | Ident, Lit, Pat, PatType, Path, PathArguments, Result, ReturnType, Token, Type, Visibility, 21 | }; 22 | 23 | /// Accumulates multiple errors into a result. 24 | /// Only use this for recoverable errors, i.e. non-parse errors. Fatal errors should early exit to 25 | /// avoid further complications. 
26 | macro_rules! extend_errors { 27 | ($errors: ident, $e: expr) => { 28 | match $errors { 29 | Ok(_) => $errors = Err($e), 30 | Err(ref mut errors) => errors.extend($e), 31 | } 32 | }; 33 | } 34 | 35 | #[derive(Clone, Copy, Debug, Eq, PartialEq)] 36 | pub(crate) enum ServiceType { 37 | Service, 38 | Object, 39 | Workflow, 40 | } 41 | 42 | pub(crate) struct Service(pub(crate) ServiceInner); 43 | 44 | impl Parse for Service { 45 | fn parse(input: ParseStream) -> Result { 46 | Ok(Service(ServiceInner::parse(ServiceType::Service, input)?)) 47 | } 48 | } 49 | 50 | pub(crate) struct Object(pub(crate) ServiceInner); 51 | 52 | impl Parse for Object { 53 | fn parse(input: ParseStream) -> Result { 54 | Ok(Object(ServiceInner::parse(ServiceType::Object, input)?)) 55 | } 56 | } 57 | 58 | pub(crate) struct Workflow(pub(crate) ServiceInner); 59 | 60 | impl Parse for Workflow { 61 | fn parse(input: ParseStream) -> Result { 62 | Ok(Workflow(ServiceInner::parse(ServiceType::Workflow, input)?)) 63 | } 64 | } 65 | 66 | pub(crate) struct ServiceInner { 67 | pub(crate) attrs: Vec, 68 | pub(crate) restate_name: String, 69 | pub(crate) vis: Visibility, 70 | pub(crate) ident: Ident, 71 | pub(crate) handlers: Vec, 72 | } 73 | 74 | impl ServiceInner { 75 | fn parse(service_type: ServiceType, input: ParseStream) -> Result { 76 | let parsed_attrs = input.call(Attribute::parse_outer)?; 77 | let vis = input.parse()?; 78 | input.parse::()?; 79 | let ident: Ident = input.parse()?; 80 | let content; 81 | braced!(content in input); 82 | let mut rpcs = Vec::::new(); 83 | while !content.is_empty() { 84 | let h: Handler = content.parse()?; 85 | 86 | if h.is_shared && service_type == ServiceType::Service { 87 | return Err(Error::new( 88 | h.ident.span(), 89 | "Service handlers cannot be annotated with #[shared]", 90 | )); 91 | } 92 | 93 | rpcs.push(h); 94 | } 95 | let mut ident_errors = Ok(()); 96 | for rpc in &rpcs { 97 | if rpc.ident == "new" { 98 | extend_errors!( 99 | ident_errors, 100 | 
Error::new( 101 | rpc.ident.span(), 102 | format!( 103 | "method name conflicts with generated fn `{}Client::new`", 104 | ident.unraw() 105 | ) 106 | ) 107 | ); 108 | } 109 | if rpc.ident == "serve" { 110 | extend_errors!( 111 | ident_errors, 112 | Error::new( 113 | rpc.ident.span(), 114 | format!("method name conflicts with generated fn `{ident}::serve`") 115 | ) 116 | ); 117 | } 118 | } 119 | ident_errors?; 120 | 121 | let mut attrs = vec![]; 122 | let mut restate_name = ident.to_string(); 123 | for attr in parsed_attrs { 124 | if let Some(name) = read_literal_attribute_name(&attr)? { 125 | restate_name = name; 126 | } else { 127 | // Just propagate 128 | attrs.push(attr); 129 | } 130 | } 131 | 132 | Ok(Self { 133 | attrs, 134 | restate_name, 135 | vis, 136 | ident, 137 | handlers: rpcs, 138 | }) 139 | } 140 | } 141 | 142 | pub(crate) struct Handler { 143 | pub(crate) attrs: Vec, 144 | pub(crate) is_shared: bool, 145 | pub(crate) restate_name: String, 146 | pub(crate) ident: Ident, 147 | pub(crate) arg: Option, 148 | pub(crate) output_ok: Type, 149 | pub(crate) output_err: Type, 150 | } 151 | 152 | impl Parse for Handler { 153 | fn parse(input: ParseStream) -> Result { 154 | let parsed_attrs = input.call(Attribute::parse_outer)?; 155 | 156 | input.parse::()?; 157 | input.parse::()?; 158 | let ident: Ident = input.parse()?; 159 | 160 | // Parse arguments 161 | let content; 162 | parenthesized!(content in input); 163 | let mut args = Vec::new(); 164 | let mut errors = Ok(()); 165 | for arg in content.parse_terminated(FnArg::parse, Comma)? 
{ 166 | match arg { 167 | FnArg::Typed(captured) if matches!(&*captured.pat, Pat::Ident(_)) => { 168 | args.push(captured); 169 | } 170 | FnArg::Typed(captured) => { 171 | extend_errors!( 172 | errors, 173 | Error::new(captured.pat.span(), "patterns aren't allowed in RPC args") 174 | ); 175 | } 176 | FnArg::Receiver(_) => { 177 | extend_errors!( 178 | errors, 179 | Error::new(arg.span(), "method args cannot start with self") 180 | ); 181 | } 182 | } 183 | } 184 | if args.len() > 1 { 185 | extend_errors!( 186 | errors, 187 | Error::new(content.span(), "Only one input argument is supported") 188 | ); 189 | } 190 | errors?; 191 | 192 | // Parse return type 193 | let return_type: ReturnType = input.parse()?; 194 | input.parse::()?; 195 | 196 | let (ok_ty, err_ty) = match &return_type { 197 | ReturnType::Default => return Err(Error::new( 198 | return_type.span(), 199 | "The return type cannot be empty, only Result or restate_sdk::prelude::HandlerResult is supported as return type", 200 | )), 201 | ReturnType::Type(_, ty) => { 202 | if let Some((ok_ty, err_ty)) = extract_handler_result_parameter(ty) { 203 | (ok_ty, err_ty) 204 | } else { 205 | return Err(Error::new( 206 | return_type.span(), 207 | "Only Result or restate_sdk::prelude::HandlerResult is supported as return type", 208 | )); 209 | } 210 | } 211 | }; 212 | 213 | // Process attributes 214 | let mut is_shared = false; 215 | let mut restate_name = ident.to_string(); 216 | let mut attrs = vec![]; 217 | for attr in parsed_attrs { 218 | if is_shared_attr(&attr) { 219 | is_shared = true; 220 | } else if let Some(name) = read_literal_attribute_name(&attr)? 
{ 221 | restate_name = name; 222 | } else { 223 | // Just propagate 224 | attrs.push(attr); 225 | } 226 | } 227 | 228 | Ok(Self { 229 | attrs, 230 | is_shared, 231 | restate_name, 232 | ident, 233 | arg: args.pop(), 234 | output_ok: ok_ty, 235 | output_err: err_ty, 236 | }) 237 | } 238 | } 239 | 240 | fn is_shared_attr(attr: &Attribute) -> bool { 241 | attr.meta 242 | .require_path_only() 243 | .and_then(Path::require_ident) 244 | .is_ok_and(|i| i == "shared") 245 | } 246 | 247 | fn read_literal_attribute_name(attr: &Attribute) -> Result> { 248 | attr.meta 249 | .require_name_value() 250 | .ok() 251 | .filter(|val| val.path.require_ident().is_ok_and(|i| i == "name")) 252 | .map(|val| { 253 | if let Expr::Lit(ExprLit { 254 | lit: Lit::Str(ref literal), 255 | .. 256 | }) = &val.value 257 | { 258 | Ok(literal.value()) 259 | } else { 260 | Err(Error::new( 261 | val.span(), 262 | "Only string literal is allowed for the 'name' attribute", 263 | )) 264 | } 265 | }) 266 | .transpose() 267 | } 268 | 269 | fn extract_handler_result_parameter(ty: &Type) -> Option<(Type, Type)> { 270 | let path = match ty { 271 | Type::Path(ty) => &ty.path, 272 | _ => return None, 273 | }; 274 | 275 | let last = path.segments.last().unwrap(); 276 | let is_result = last.ident == "Result"; 277 | let is_handler_result = last.ident == "HandlerResult"; 278 | if !is_result && !is_handler_result { 279 | return None; 280 | } 281 | 282 | let bracketed = match &last.arguments { 283 | PathArguments::AngleBracketed(bracketed) => bracketed, 284 | _ => return None, 285 | }; 286 | 287 | if is_handler_result && bracketed.args.len() == 1 { 288 | match &bracketed.args[0] { 289 | GenericArgument::Type(arg) => Some(( 290 | arg.clone(), 291 | parse_quote!(::restate_sdk::prelude::HandlerError), 292 | )), 293 | _ => None, 294 | } 295 | } else if is_result && bracketed.args.len() == 2 { 296 | match (&bracketed.args[0], &bracketed.args[1]) { 297 | (GenericArgument::Type(ok_arg), GenericArgument::Type(err_arg)) => { 
298 | Some((ok_arg.clone(), err_arg.clone())) 299 | } 300 | _ => None, 301 | } 302 | } else { 303 | None 304 | } 305 | } 306 | -------------------------------------------------------------------------------- /macros/src/gen.rs: -------------------------------------------------------------------------------- 1 | use crate::ast::{Handler, Object, Service, ServiceInner, ServiceType, Workflow}; 2 | use proc_macro2::TokenStream as TokenStream2; 3 | use proc_macro2::{Ident, Literal}; 4 | use quote::{format_ident, quote, ToTokens}; 5 | use syn::{Attribute, PatType, Visibility}; 6 | 7 | pub(crate) struct ServiceGenerator<'a> { 8 | pub(crate) service_ty: ServiceType, 9 | pub(crate) restate_name: &'a str, 10 | pub(crate) service_ident: &'a Ident, 11 | pub(crate) client_ident: Ident, 12 | pub(crate) serve_ident: Ident, 13 | pub(crate) vis: &'a Visibility, 14 | pub(crate) attrs: &'a [Attribute], 15 | pub(crate) handlers: &'a [Handler], 16 | } 17 | 18 | impl<'a> ServiceGenerator<'a> { 19 | fn new(service_ty: ServiceType, s: &'a ServiceInner) -> Self { 20 | ServiceGenerator { 21 | service_ty, 22 | restate_name: &s.restate_name, 23 | service_ident: &s.ident, 24 | client_ident: format_ident!("{}Client", s.ident), 25 | serve_ident: format_ident!("Serve{}", s.ident), 26 | vis: &s.vis, 27 | attrs: &s.attrs, 28 | handlers: &s.handlers, 29 | } 30 | } 31 | 32 | pub(crate) fn new_service(s: &'a Service) -> Self { 33 | Self::new(ServiceType::Service, &s.0) 34 | } 35 | 36 | pub(crate) fn new_object(s: &'a Object) -> Self { 37 | Self::new(ServiceType::Object, &s.0) 38 | } 39 | 40 | pub(crate) fn new_workflow(s: &'a Workflow) -> Self { 41 | Self::new(ServiceType::Workflow, &s.0) 42 | } 43 | 44 | fn trait_service(&self) -> TokenStream2 { 45 | let Self { 46 | attrs, 47 | handlers, 48 | vis, 49 | service_ident, 50 | service_ty, 51 | serve_ident, 52 | .. 
53 | } = self; 54 | 55 | let handler_fns = handlers 56 | .iter() 57 | .map( 58 | |Handler { attrs, ident, arg, is_shared, output_ok, output_err, .. }| { 59 | let args = arg.iter(); 60 | 61 | let ctx = match (&service_ty, is_shared) { 62 | (ServiceType::Service, _) => quote! { ::restate_sdk::prelude::Context }, 63 | (ServiceType::Object, true) => quote! { ::restate_sdk::prelude::SharedObjectContext }, 64 | (ServiceType::Object, false) => quote! { ::restate_sdk::prelude::ObjectContext }, 65 | (ServiceType::Workflow, true) => quote! { ::restate_sdk::prelude::SharedWorkflowContext }, 66 | (ServiceType::Workflow, false) => quote! { ::restate_sdk::prelude::WorkflowContext }, 67 | }; 68 | 69 | quote! { 70 | #( #attrs )* 71 | fn #ident(&self, context: #ctx, #( #args ),*) -> impl std::future::Future> + ::core::marker::Send; 72 | } 73 | }, 74 | ); 75 | 76 | quote! { 77 | #( #attrs )* 78 | #vis trait #service_ident: ::core::marker::Sized { 79 | #( #handler_fns )* 80 | 81 | /// Returns a serving function to use with [::restate_sdk::endpoint::Builder::with_service]. 82 | fn serve(self) -> #serve_ident { 83 | #serve_ident { service: ::std::sync::Arc::new(self) } 84 | } 85 | } 86 | } 87 | } 88 | 89 | fn struct_serve(&self) -> TokenStream2 { 90 | let &Self { 91 | vis, 92 | ref serve_ident, 93 | .. 94 | } = self; 95 | 96 | quote! { 97 | /// Struct implementing [::restate_sdk::service::Service], to be used with [::restate_sdk::endpoint::Builder::with_service]. 98 | #[derive(Clone)] 99 | #vis struct #serve_ident { 100 | service: ::std::sync::Arc, 101 | } 102 | } 103 | } 104 | 105 | fn impl_service_for_serve(&self) -> TokenStream2 { 106 | let Self { 107 | serve_ident, 108 | service_ident, 109 | handlers, 110 | .. 111 | } = self; 112 | 113 | let match_arms = handlers.iter().map(|handler| { 114 | let handler_ident = &handler.ident; 115 | 116 | let get_input_and_call = if handler.arg.is_some() { 117 | quote! 
{ 118 | let (input, metadata) = ctx.input().await; 119 | let fut = S::#handler_ident(&service_clone, (&ctx, metadata).into(), input); 120 | } 121 | } else { 122 | quote! { 123 | let (_, metadata) = ctx.input::<()>().await; 124 | let fut = S::#handler_ident(&service_clone, (&ctx, metadata).into()); 125 | } 126 | }; 127 | 128 | let handler_literal = Literal::string(&handler.restate_name); 129 | 130 | quote! { 131 | #handler_literal => { 132 | #get_input_and_call 133 | let res = fut.await.map_err(::restate_sdk::errors::HandlerError::from); 134 | ctx.handle_handler_result(res); 135 | ctx.end(); 136 | Ok(()) 137 | } 138 | } 139 | }); 140 | 141 | quote! { 142 | impl ::restate_sdk::service::Service for #serve_ident 143 | where S: #service_ident + Send + Sync + 'static, 144 | { 145 | type Future = ::restate_sdk::service::ServiceBoxFuture; 146 | 147 | fn handle(&self, ctx: ::restate_sdk::endpoint::ContextInternal) -> Self::Future { 148 | let service_clone = ::std::sync::Arc::clone(&self.service); 149 | Box::pin(async move { 150 | match ctx.handler_name() { 151 | #( #match_arms ),* 152 | _ => { 153 | return Err(::restate_sdk::endpoint::Error::unknown_handler( 154 | ctx.service_name(), 155 | ctx.handler_name(), 156 | )) 157 | } 158 | } 159 | }) 160 | } 161 | } 162 | } 163 | } 164 | 165 | fn impl_discoverable(&self) -> TokenStream2 { 166 | let Self { 167 | service_ty, 168 | serve_ident, 169 | service_ident, 170 | handlers, 171 | restate_name, 172 | .. 173 | } = self; 174 | 175 | let service_literal = Literal::string(restate_name); 176 | 177 | let service_ty_token = match service_ty { 178 | ServiceType::Service => quote! { ::restate_sdk::discovery::ServiceType::Service }, 179 | ServiceType::Object => { 180 | quote! { ::restate_sdk::discovery::ServiceType::VirtualObject } 181 | } 182 | ServiceType::Workflow => quote! 
{ ::restate_sdk::discovery::ServiceType::Workflow }, 183 | }; 184 | 185 | let handlers = handlers.iter().map(|handler| { 186 | let handler_literal = Literal::string(&handler.restate_name); 187 | 188 | let handler_ty = if handler.is_shared { 189 | quote! { Some(::restate_sdk::discovery::HandlerType::Shared) } 190 | } else if *service_ty == ServiceType::Workflow { 191 | quote! { Some(::restate_sdk::discovery::HandlerType::Workflow) } 192 | } else { 193 | // Macro has same defaulting rules of the discovery manifest 194 | quote! { None } 195 | }; 196 | 197 | let input_schema = match &handler.arg { 198 | Some(PatType { ty, .. }) => { 199 | quote! { 200 | Some(::restate_sdk::discovery::InputPayload::from_metadata::<#ty>()) 201 | } 202 | } 203 | None => quote! { 204 | Some(::restate_sdk::discovery::InputPayload::empty()) 205 | } 206 | }; 207 | 208 | let output_ty = &handler.output_ok; 209 | let output_schema = match output_ty { 210 | syn::Type::Tuple(tuple) if tuple.elems.is_empty() => quote! { 211 | Some(::restate_sdk::discovery::OutputPayload::empty()) 212 | }, 213 | _ => quote! { 214 | Some(::restate_sdk::discovery::OutputPayload::from_metadata::<#output_ty>()) 215 | } 216 | }; 217 | 218 | quote! { 219 | ::restate_sdk::discovery::Handler { 220 | name: ::restate_sdk::discovery::HandlerName::try_from(#handler_literal).expect("Handler name valid"), 221 | input: #input_schema, 222 | output: #output_schema, 223 | ty: #handler_ty, 224 | } 225 | } 226 | }); 227 | 228 | quote! 
{ 229 | impl ::restate_sdk::service::Discoverable for #serve_ident 230 | where S: #service_ident, 231 | { 232 | fn discover() -> ::restate_sdk::discovery::Service { 233 | ::restate_sdk::discovery::Service { 234 | ty: #service_ty_token, 235 | name: ::restate_sdk::discovery::ServiceName::try_from(#service_literal.to_string()) 236 | .expect("Service name valid"), 237 | handlers: vec![#( #handlers ),*], 238 | } 239 | } 240 | } 241 | } 242 | } 243 | 244 | fn struct_client(&self) -> TokenStream2 { 245 | let &Self { 246 | vis, 247 | ref client_ident, 248 | // service_ident, 249 | ref service_ty, 250 | .. 251 | } = self; 252 | 253 | let key_field = match service_ty { 254 | ServiceType::Service => quote! {}, 255 | ServiceType::Object | ServiceType::Workflow => quote! { 256 | key: String, 257 | }, 258 | }; 259 | 260 | let into_client_impl = match service_ty { 261 | ServiceType::Service => { 262 | quote! { 263 | impl<'ctx> ::restate_sdk::context::IntoServiceClient<'ctx> for #client_ident<'ctx> { 264 | fn create_client(ctx: &'ctx ::restate_sdk::endpoint::ContextInternal) -> Self { 265 | Self { ctx } 266 | } 267 | } 268 | } 269 | } 270 | ServiceType::Object => quote! { 271 | impl<'ctx> ::restate_sdk::context::IntoObjectClient<'ctx> for #client_ident<'ctx> { 272 | fn create_client(ctx: &'ctx ::restate_sdk::endpoint::ContextInternal, key: String) -> Self { 273 | Self { ctx, key } 274 | } 275 | } 276 | }, 277 | ServiceType::Workflow => quote! { 278 | impl<'ctx> ::restate_sdk::context::IntoWorkflowClient<'ctx> for #client_ident<'ctx> { 279 | fn create_client(ctx: &'ctx ::restate_sdk::endpoint::ContextInternal, key: String) -> Self { 280 | Self { ctx, key } 281 | } 282 | } 283 | }, 284 | }; 285 | 286 | quote! { 287 | /// Struct exposing the client to invoke [#service_ident] from another service. 
288 | #vis struct #client_ident<'ctx> { 289 | ctx: &'ctx ::restate_sdk::endpoint::ContextInternal, 290 | #key_field 291 | } 292 | 293 | #into_client_impl 294 | } 295 | } 296 | 297 | fn impl_client(&self) -> TokenStream2 { 298 | let &Self { 299 | vis, 300 | ref client_ident, 301 | service_ident, 302 | handlers, 303 | restate_name, 304 | service_ty, 305 | .. 306 | } = self; 307 | 308 | let service_literal = Literal::string(restate_name); 309 | 310 | let handlers_fns = handlers.iter().map(|handler| { 311 | let handler_ident = &handler.ident; 312 | let handler_literal = Literal::string(&handler.restate_name); 313 | 314 | let argument = match &handler.arg { 315 | None => quote! {}, 316 | Some(PatType { 317 | ty, .. 318 | }) => quote! { req: #ty } 319 | }; 320 | let argument_ty = match &handler.arg { 321 | None => quote! { () }, 322 | Some(PatType { 323 | ty, .. 324 | }) => quote! { #ty } 325 | }; 326 | let res_ty = &handler.output_ok; 327 | let input = match &handler.arg { 328 | None => quote! { () }, 329 | Some(_) => quote! { req } 330 | }; 331 | let request_target = match service_ty { 332 | ServiceType::Service => quote! { 333 | ::restate_sdk::context::RequestTarget::service(#service_literal, #handler_literal) 334 | }, 335 | ServiceType::Object => quote! { 336 | ::restate_sdk::context::RequestTarget::object(#service_literal, &self.key, #handler_literal) 337 | }, 338 | ServiceType::Workflow => quote! { 339 | ::restate_sdk::context::RequestTarget::workflow(#service_literal, &self.key, #handler_literal) 340 | } 341 | }; 342 | 343 | quote! { 344 | #vis fn #handler_ident(&self, #argument) -> ::restate_sdk::context::Request<'ctx, #argument_ty, #res_ty> { 345 | self.ctx.request(#request_target, #input) 346 | } 347 | } 348 | }); 349 | 350 | let doc_msg = format!( 351 | "Struct exposing the client to invoke [`{service_ident}`] from another service." 352 | ); 353 | quote! 
{ 354 | #[doc = #doc_msg] 355 | impl<'ctx> #client_ident<'ctx> { 356 | #( #handlers_fns )* 357 | } 358 | } 359 | } 360 | } 361 | 362 | impl<'a> ToTokens for ServiceGenerator<'a> { 363 | fn to_tokens(&self, output: &mut TokenStream2) { 364 | output.extend(vec![ 365 | self.trait_service(), 366 | self.struct_serve(), 367 | self.impl_service_for_serve(), 368 | self.impl_discoverable(), 369 | self.struct_client(), 370 | self.impl_client(), 371 | ]); 372 | } 373 | } 374 | -------------------------------------------------------------------------------- /macros/src/lib.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 - Restate Software, Inc., Restate GmbH. 2 | // All rights reserved. 3 | // 4 | // Use of this software is governed by the Business Source License 5 | // included in the LICENSE file. 6 | // 7 | // As of the Change Date specified in that file, in accordance with 8 | // the Business Source License, use of this software will be governed 9 | // by the Apache License, Version 2.0. 
10 | 11 | // Some parts of this codebase were taken from https://github.com/google/tarpc/blob/b826f332312d3702667880a464e247556ad7dbfe/plugins/src/lib.rs 12 | // License MIT 13 | 14 | extern crate proc_macro; 15 | 16 | mod ast; 17 | mod gen; 18 | 19 | use crate::ast::{Object, Service, Workflow}; 20 | use crate::gen::ServiceGenerator; 21 | use proc_macro::TokenStream; 22 | use quote::ToTokens; 23 | use syn::parse_macro_input; 24 | 25 | #[proc_macro_attribute] 26 | pub fn service(_: TokenStream, input: TokenStream) -> TokenStream { 27 | let svc = parse_macro_input!(input as Service); 28 | 29 | ServiceGenerator::new_service(&svc) 30 | .into_token_stream() 31 | .into() 32 | } 33 | 34 | #[proc_macro_attribute] 35 | pub fn object(_: TokenStream, input: TokenStream) -> TokenStream { 36 | let svc = parse_macro_input!(input as Object); 37 | 38 | ServiceGenerator::new_object(&svc) 39 | .into_token_stream() 40 | .into() 41 | } 42 | 43 | #[proc_macro_attribute] 44 | pub fn workflow(_: TokenStream, input: TokenStream) -> TokenStream { 45 | let svc = parse_macro_input!(input as Workflow); 46 | 47 | ServiceGenerator::new_workflow(&svc) 48 | .into_token_stream() 49 | .into() 50 | } 51 | -------------------------------------------------------------------------------- /rust-toolchain.toml: -------------------------------------------------------------------------------- 1 | [toolchain] 2 | channel = "1.81.0" 3 | profile = "minimal" 4 | components = ["rustfmt", "clippy"] 5 | -------------------------------------------------------------------------------- /src/context/macro_support.rs: -------------------------------------------------------------------------------- 1 | use crate::endpoint::ContextInternal; 2 | use restate_sdk_shared_core::NotificationHandle; 3 | 4 | // Sealed future trait, used by select statement 5 | #[doc(hidden)] 6 | pub trait SealedDurableFuture { 7 | fn inner_context(&self) -> ContextInternal; 8 | fn handle(&self) -> NotificationHandle; 9 | } 10 | 
-------------------------------------------------------------------------------- /src/context/request.rs: -------------------------------------------------------------------------------- 1 | use super::DurableFuture; 2 | 3 | use crate::endpoint::ContextInternal; 4 | use crate::errors::TerminalError; 5 | use crate::serde::{Deserialize, Serialize}; 6 | use std::fmt; 7 | use std::future::Future; 8 | use std::marker::PhantomData; 9 | use std::time::Duration; 10 | 11 | /// Target of a request to a Restate service. 12 | #[derive(Debug, Clone)] 13 | pub enum RequestTarget { 14 | Service { 15 | name: String, 16 | handler: String, 17 | }, 18 | Object { 19 | name: String, 20 | key: String, 21 | handler: String, 22 | }, 23 | Workflow { 24 | name: String, 25 | key: String, 26 | handler: String, 27 | }, 28 | } 29 | 30 | impl RequestTarget { 31 | pub fn service(name: impl Into, handler: impl Into) -> Self { 32 | Self::Service { 33 | name: name.into(), 34 | handler: handler.into(), 35 | } 36 | } 37 | 38 | pub fn object( 39 | name: impl Into, 40 | key: impl Into, 41 | handler: impl Into, 42 | ) -> Self { 43 | Self::Object { 44 | name: name.into(), 45 | key: key.into(), 46 | handler: handler.into(), 47 | } 48 | } 49 | 50 | pub fn workflow( 51 | name: impl Into, 52 | key: impl Into, 53 | handler: impl Into, 54 | ) -> Self { 55 | Self::Workflow { 56 | name: name.into(), 57 | key: key.into(), 58 | handler: handler.into(), 59 | } 60 | } 61 | } 62 | 63 | impl fmt::Display for RequestTarget { 64 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 65 | match self { 66 | RequestTarget::Service { name, handler } => write!(f, "{name}/{handler}"), 67 | RequestTarget::Object { name, key, handler } => write!(f, "{name}/{key}/{handler}"), 68 | RequestTarget::Workflow { name, key, handler } => write!(f, "{name}/{key}/{handler}"), 69 | } 70 | } 71 | } 72 | 73 | /// This struct encapsulates the parameters for a request to a service. 
74 | pub struct Request<'a, Req, Res = ()> { 75 | ctx: &'a ContextInternal, 76 | request_target: RequestTarget, 77 | idempotency_key: Option, 78 | headers: Vec<(String, String)>, 79 | req: Req, 80 | res: PhantomData, 81 | } 82 | 83 | impl<'a, Req, Res> Request<'a, Req, Res> { 84 | pub(crate) fn new(ctx: &'a ContextInternal, request_target: RequestTarget, req: Req) -> Self { 85 | Self { 86 | ctx, 87 | request_target, 88 | idempotency_key: None, 89 | headers: vec![], 90 | req, 91 | res: PhantomData, 92 | } 93 | } 94 | 95 | pub fn header(mut self, key: String, value: String) -> Self { 96 | self.headers.push((key, value)); 97 | self 98 | } 99 | 100 | /// Add idempotency key to the request 101 | pub fn idempotency_key(mut self, idempotency_key: impl Into) -> Self { 102 | self.idempotency_key = Some(idempotency_key.into()); 103 | self 104 | } 105 | 106 | /// Call a service. This returns a future encapsulating the response. 107 | pub fn call(self) -> impl CallFuture + Send 108 | where 109 | Req: Serialize + 'static, 110 | Res: Deserialize + 'static, 111 | { 112 | self.ctx.call( 113 | self.request_target, 114 | self.idempotency_key, 115 | self.headers, 116 | self.req, 117 | ) 118 | } 119 | 120 | /// Send the request to the service, without waiting for the response. 121 | pub fn send(self) -> impl InvocationHandle 122 | where 123 | Req: Serialize + 'static, 124 | { 125 | self.ctx.send( 126 | self.request_target, 127 | self.idempotency_key, 128 | self.headers, 129 | self.req, 130 | None, 131 | ) 132 | } 133 | 134 | /// Schedule the request to the service, without waiting for the response. 
135 | pub fn send_after(self, delay: Duration) -> impl InvocationHandle 136 | where 137 | Req: Serialize + 'static, 138 | { 139 | self.ctx.send( 140 | self.request_target, 141 | self.idempotency_key, 142 | self.headers, 143 | self.req, 144 | Some(delay), 145 | ) 146 | } 147 | } 148 | 149 | pub trait InvocationHandle { 150 | fn invocation_id(&self) -> impl Future> + Send; 151 | fn cancel(&self) -> impl Future> + Send; 152 | } 153 | 154 | pub trait CallFuture: 155 | DurableFuture> + InvocationHandle 156 | { 157 | type Response; 158 | } 159 | -------------------------------------------------------------------------------- /src/context/run.rs: -------------------------------------------------------------------------------- 1 | use crate::errors::HandlerResult; 2 | use crate::serde::{Deserialize, Serialize}; 3 | use std::future::Future; 4 | use std::time::Duration; 5 | 6 | /// Run closure trait 7 | pub trait RunClosure { 8 | type Output: Deserialize + Serialize + 'static; 9 | type Fut: Future>; 10 | 11 | fn run(self) -> Self::Fut; 12 | } 13 | 14 | impl RunClosure for F 15 | where 16 | F: FnOnce() -> Fut, 17 | Fut: Future>, 18 | O: Deserialize + Serialize + 'static, 19 | { 20 | type Output = O; 21 | type Fut = Fut; 22 | 23 | fn run(self) -> Self::Fut { 24 | self() 25 | } 26 | } 27 | 28 | /// Future created using [`ContextSideEffects::run`](super::ContextSideEffects::run). 29 | pub trait RunFuture: Future { 30 | /// Provide a custom retry policy for this `run` operation. 31 | /// 32 | /// If unspecified, the `run` will be retried using the [Restate invoker retry policy](https://docs.restate.dev/operate/configuration/server), 33 | /// which by default retries indefinitely. 34 | fn retry_policy(self, retry_policy: RunRetryPolicy) -> Self; 35 | 36 | /// Define a name for this `run` operation. 37 | /// 38 | /// This is used mainly for observability. 
39 | fn name(self, name: impl Into) -> Self; 40 | } 41 | 42 | /// This struct represents the policy to execute retries for run closures. 43 | #[derive(Debug, Clone)] 44 | pub struct RunRetryPolicy { 45 | pub(crate) initial_delay: Duration, 46 | pub(crate) factor: f32, 47 | pub(crate) max_delay: Option, 48 | pub(crate) max_attempts: Option, 49 | pub(crate) max_duration: Option, 50 | } 51 | 52 | impl Default for RunRetryPolicy { 53 | fn default() -> Self { 54 | Self { 55 | initial_delay: Duration::from_millis(100), 56 | factor: 2.0, 57 | max_delay: Some(Duration::from_secs(2)), 58 | max_attempts: None, 59 | max_duration: Some(Duration::from_secs(50)), 60 | } 61 | } 62 | } 63 | 64 | impl RunRetryPolicy { 65 | /// Create a new retry policy. 66 | pub fn new() -> Self { 67 | Self { 68 | initial_delay: Duration::from_millis(100), 69 | factor: 1.0, 70 | max_delay: None, 71 | max_attempts: None, 72 | max_duration: None, 73 | } 74 | } 75 | 76 | /// Initial retry delay for the first retry attempt. 77 | pub fn initial_delay(mut self, initial_interval: Duration) -> Self { 78 | self.initial_delay = initial_interval; 79 | self 80 | } 81 | 82 | /// Exponentiation factor to use when computing the next retry delay. 83 | pub fn exponentiation_factor(mut self, factor: f32) -> Self { 84 | self.factor = factor; 85 | self 86 | } 87 | 88 | /// Maximum delay between retries. 89 | pub fn max_delay(mut self, max_interval: Duration) -> Self { 90 | self.max_delay = Some(max_interval); 91 | self 92 | } 93 | 94 | /// Gives up retrying when either at least the given number of attempts is reached, 95 | /// or `max_duration` (if set) is reached first. 96 | /// 97 | /// **Note:** The number of actual retries may be higher than the provided value. 98 | /// This is due to the nature of the run operation, which executes the closure on the service and sends the result afterward to Restate. 99 | /// 100 | /// Infinite retries if this field and `max_duration` are unset. 
101 | pub fn max_attempts(mut self, max_attempts: u32) -> Self { 102 | self.max_attempts = Some(max_attempts); 103 | self 104 | } 105 | 106 | /// Gives up retrying when either the retry loop lasted at least for this given max duration, 107 | /// or `max_attempts` (if set) is reached first. 108 | /// 109 | /// **Note:** The real retry loop duration may be higher than the given duration. 110 | /// This is due to the nature of the run operation, which executes the closure on the service and sends the result afterward to Restate. 111 | /// 112 | /// Infinite retries if this field and `max_attempts` are unset. 113 | pub fn max_duration(mut self, max_duration: Duration) -> Self { 114 | self.max_duration = Some(max_duration); 115 | self 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /src/discovery.rs: -------------------------------------------------------------------------------- 1 | //! This module contains the generated data structures from the [service protocol manifest schema](https://github.com/restatedev/service-protocol/blob/main/endpoint_manifest_schema.json). 
// ---------------------------------------------------------------------------
// Review notes for the next part of this dump: the rest of discovery.rs and
// src/endpoint/futures/{async_result_poll, durable_future_impl,
// handler_state_aware, intercept_error, mod, select_poll}.rs.
//
// NOTE(review): angle-bracketed generic arguments appear to have been stripped
// from this dump. Examples:
// - `from_metadata()` calls `T::input_metadata()` although no `T` is declared;
// - `state: Option,`, `ctx: Arc>`, `-> Poll`, `oneshot::Receiver,`.
// Restore them from the real sources before treating this text as code.
//
// discovery.rs
// - `include!`s the build-generated `endpoint_manifest.rs` from `OUT_DIR`
//   inside `mod generated`, allowing two clippy lints for the generated code.
//   Everything from it is re-exported.
// - `InputPayload::empty()` leaves every field `None`.
// - `OutputPayload::empty()` leaves content type and schema `None`, but sets
//   `set_content_type_if_empty: Some(false)`.
// - `from_metadata()` copies the content type, `T::json_schema()` and the
//   required / set-content-type-if-empty flag from `T`'s `PayloadMetadata`.
//
// async_result_poll.rs -- `VmAsyncResultPollFuture`
// - Drives the VM until a single `NotificationHandle` is completed, then
//   resolves with the taken notification `Value`.
// - State handling:
//   - The state lives in an `Option` that `poll` takes out on every loop
//     iteration and puts back.
//   - Polling after `Poll::Ready` panics through that `expect`.
// - Locking:
//   - `must_lock!` uses `try_lock`. If the context mutex is already held it
//     panics instead of blocking; the message mentions awaiting two futures
//     at the same time.
//   - Each state drops its guard before reassigning `self.state`.
// - Init:
//   - Takes VM output once and writes it.
//   - A failed write resolves to `ErrorInner::Suspended`.
//   - `TakeOutputResult::EOF` resolves to `ErrorInner::UnexpectedOutputClosed`.
//   - Then moves to PollProgress.
// - WaitingInput:
//   - Polls the input receiver; Pending restores the state and returns Pending.
//   - Bytes go to `vm.notify_input`.
//   - A read error goes to `vm.notify_error` with code 500.
//   - End of input goes to `vm.notify_input_closed`.
//   - Then moves to PollProgress.
// - PollProgress: `vm.do_progress(vec![handle])`, then:
//   - ReadFromInput -> WaitingInput.
//   - CancelSignalReceived -> `Ok(Value::Failure)`, code 409, message "cancelled".
//   - VM errors -> `Err`.
//   - AnyCompleted -> `maybe_flip_span_replaying_field()`, then take the
//     notification. A missing notification at that point panics.
// - NOTE(review): `ExecuteRun` and `WaitingPendingRun` hit `unimplemented!()`,
//   i.e. panic, if the VM ever returns them here. Presumably run closures are
//   driven elsewhere -- confirm.
//
// durable_future_impl.rs / intercept_error.rs
// - Both `DurableFutureImpl` and `InterceptErrorFuture` wrap an inner future
//   whose output is a `Result`:
//   - `Ok(r)` resolves to `r`.
//   - `Err(e)` calls `ctx.fail(e)`, wakes the task's own waker and returns
//     Pending. This relies on `HandlerStateAwareFuture` observing the failure
//     on the next poll.
// - `DurableFutureImpl` also keeps the `NotificationHandle`. It exposes the
//   handle, plus a clone of the context, through `SealedDurableFuture`.
//   NOTE(review): its doc comment is a copy of `InterceptErrorFuture`'s and
//   does not mention the handle.
// - `InterceptErrorFuture` forwards `RunFuture::{retry_policy, name}` and
//   `InvocationHandle::{invocation_id, cancel}` to the wrapped future.
//
// handler_state_aware.rs -- `HandlerStateAwareFuture`
// - Before polling the wrapped future it checks `handler_state_rx.try_recv()`:
//   - A received `Error` is logged with `warn!` (rpc.system / rpc.service /
//     rpc.method fields). `consume_to_end()` is called and `Err(e)` is
//     returned without polling the inner future.
//   - `Empty` polls the inner future. Once it is Ready, `consume_to_end()` is
//     called and `Ok` is returned.
//   - `Closed` panics.
// - NOTE(review): that panic message says `Poll:ready()` where
//   `Poll::Ready` is meant. It is runtime text, so it is left unchanged.
//
// select_poll.rs -- `VmSelectAsyncResultPollFuture`
// - The same Init / WaitingInput / PollProgress machine, over a list of handles.
// - AnyCompleted resolves to `Ok(Ok(idx))` for the first handle, in the given
//   order, that `vm.is_completed` reports. If none is completed it panics.
// - CancelSignalReceived resolves to `Ok(Err(..))` built from a 409
//   "cancelled" `TerminalFailure`.
// - `do_progress(handles.clone())` clones the handle list on every progress
//   attempt.
// - NOTE(review): Init / WaitingInput and `must_lock!` are duplicated almost
//   verbatim from async_result_poll.rs. A shared helper would keep the two
//   from drifting apart.
// ---------------------------------------------------------------------------
2 | 3 | mod generated { 4 | #![allow(clippy::clone_on_copy)] 5 | #![allow(clippy::to_string_trait_impl)] 6 | 7 | include!(concat!(env!("OUT_DIR"), "/endpoint_manifest.rs")); 8 | } 9 | 10 | pub use generated::*; 11 | 12 | use crate::serde::PayloadMetadata; 13 | 14 | impl InputPayload { 15 | pub fn empty() -> Self { 16 | Self { 17 | content_type: None, 18 | json_schema: None, 19 | required: None, 20 | } 21 | } 22 | 23 | pub fn from_metadata() -> Self { 24 | let input_metadata = T::input_metadata(); 25 | Self { 26 | content_type: Some(input_metadata.accept_content_type.to_owned()), 27 | json_schema: T::json_schema(), 28 | required: Some(input_metadata.is_required), 29 | } 30 | } 31 | } 32 | 33 | impl OutputPayload { 34 | pub fn empty() -> Self { 35 | Self { 36 | content_type: None, 37 | json_schema: None, 38 | set_content_type_if_empty: Some(false), 39 | } 40 | } 41 | 42 | pub fn from_metadata() -> Self { 43 | let output_metadata = T::output_metadata(); 44 | Self { 45 | content_type: Some(output_metadata.content_type.to_owned()), 46 | json_schema: T::json_schema(), 47 | set_content_type_if_empty: Some(output_metadata.set_content_type_if_empty), 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/endpoint/futures/async_result_poll.rs: -------------------------------------------------------------------------------- 1 | use crate::endpoint::context::ContextInternalInner; 2 | use crate::endpoint::ErrorInner; 3 | use restate_sdk_shared_core::{ 4 | DoProgressResponse, Error as CoreError, NotificationHandle, TakeOutputResult, TerminalFailure, 5 | Value, VM, 6 | }; 7 | use std::future::Future; 8 | use std::pin::Pin; 9 | use std::sync::{Arc, Mutex}; 10 | use std::task::Poll; 11 | 12 | pub(crate) struct VmAsyncResultPollFuture { 13 | state: Option, 14 | } 15 | 16 | impl VmAsyncResultPollFuture { 17 | pub fn new(ctx: Arc>, handle: NotificationHandle) -> Self { 18 | VmAsyncResultPollFuture { 19 | state:
Some(AsyncResultPollState::Init { ctx, handle }), 20 | } 21 | } 22 | } 23 | 24 | enum AsyncResultPollState { 25 | Init { 26 | ctx: Arc>, 27 | handle: NotificationHandle, 28 | }, 29 | PollProgress { 30 | ctx: Arc>, 31 | handle: NotificationHandle, 32 | }, 33 | WaitingInput { 34 | ctx: Arc>, 35 | handle: NotificationHandle, 36 | }, 37 | } 38 | 39 | macro_rules! must_lock { 40 | ($mutex:expr) => { 41 | $mutex.try_lock().expect("You're trying to await two futures at the same time and/or trying to perform some operation on the restate context while awaiting a future. This is not supported!") 42 | }; 43 | } 44 | 45 | impl Future for VmAsyncResultPollFuture { 46 | type Output = Result; 47 | 48 | fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { 49 | loop { 50 | match self 51 | .state 52 | .take() 53 | .expect("Future should not be polled after Poll::Ready") 54 | { 55 | AsyncResultPollState::Init { ctx, handle } => { 56 | let mut inner_lock = must_lock!(ctx); 57 | 58 | // Let's consume some output to begin with 59 | let out = inner_lock.vm.take_output(); 60 | match out { 61 | TakeOutputResult::Buffer(b) => { 62 | if !inner_lock.write.send(b) { 63 | return Poll::Ready(Err(ErrorInner::Suspended)); 64 | } 65 | } 66 | TakeOutputResult::EOF => { 67 | return Poll::Ready(Err(ErrorInner::UnexpectedOutputClosed)) 68 | } 69 | } 70 | 71 | // We can now start polling 72 | drop(inner_lock); 73 | self.state = Some(AsyncResultPollState::PollProgress { ctx, handle }); 74 | } 75 | AsyncResultPollState::WaitingInput { ctx, handle } => { 76 | let mut inner_lock = must_lock!(ctx); 77 | 78 | let read_result = match inner_lock.read.poll_recv(cx) { 79 | Poll::Ready(t) => t, 80 | Poll::Pending => { 81 | // Still need to wait for input 82 | drop(inner_lock); 83 | self.state = Some(AsyncResultPollState::WaitingInput { ctx, handle }); 84 | return Poll::Pending; 85 | } 86 | }; 87 | 88 | // Pass read result to VM 89 | match read_result { 90 | Some(Ok(b)) =>
inner_lock.vm.notify_input(b), 91 | Some(Err(e)) => inner_lock.vm.notify_error( 92 | CoreError::new(500u16, format!("Error when reading the body {e:?}",)), 93 | None, 94 | ), 95 | None => inner_lock.vm.notify_input_closed(), 96 | } 97 | 98 | // It's time to poll progress again 99 | drop(inner_lock); 100 | self.state = Some(AsyncResultPollState::PollProgress { ctx, handle }); 101 | } 102 | AsyncResultPollState::PollProgress { ctx, handle } => { 103 | let mut inner_lock = must_lock!(ctx); 104 | 105 | match inner_lock.vm.do_progress(vec![handle]) { 106 | Ok(DoProgressResponse::AnyCompleted) => { 107 | // We're good, we got the response 108 | } 109 | Ok(DoProgressResponse::ReadFromInput) => { 110 | drop(inner_lock); 111 | self.state = Some(AsyncResultPollState::WaitingInput { ctx, handle }); 112 | continue; 113 | } 114 | Ok(DoProgressResponse::ExecuteRun(_)) => { 115 | unimplemented!() 116 | } 117 | Ok(DoProgressResponse::WaitingPendingRun) => { 118 | unimplemented!() 119 | } 120 | Ok(DoProgressResponse::CancelSignalReceived) => { 121 | return Poll::Ready(Ok(Value::Failure(TerminalFailure { 122 | code: 409, 123 | message: "cancelled".to_string(), 124 | }))) 125 | } 126 | Err(e) => { 127 | return Poll::Ready(Err(e.into())); 128 | } 129 | }; 130 | 131 | // DoProgress might cause a flip of the replaying state 132 | inner_lock.maybe_flip_span_replaying_field(); 133 | 134 | // At this point let's try to take the notification 135 | match inner_lock.vm.take_notification(handle) { 136 | Ok(Some(v)) => return Poll::Ready(Ok(v)), 137 | Ok(None) => { 138 | panic!( 139 | "This is not supposed to happen, handle was flagged as completed" 140 | ) 141 | } 142 | Err(e) => return Poll::Ready(Err(e.into())), 143 | } 144 | } 145 | } 146 | } 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /src/endpoint/futures/durable_future_impl.rs: -------------------------------------------------------------------------------- 1 | use
crate::context::DurableFuture; 2 | use crate::endpoint::{ContextInternal, Error}; 3 | use pin_project_lite::pin_project; 4 | use restate_sdk_shared_core::NotificationHandle; 5 | use std::future::Future; 6 | use std::pin::Pin; 7 | use std::task::{ready, Context, Poll}; 8 | 9 | pin_project! { 10 | /// Future that intercepts errors of inner future, and passes them to ContextInternal 11 | pub struct DurableFutureImpl{ 12 | #[pin] 13 | fut: F, 14 | handle: NotificationHandle, 15 | ctx: ContextInternal 16 | } 17 | } 18 | 19 | impl DurableFutureImpl { 20 | pub fn new(ctx: ContextInternal, handle: NotificationHandle, fut: F) -> Self { 21 | Self { fut, handle, ctx } 22 | } 23 | } 24 | 25 | impl Future for DurableFutureImpl 26 | where 27 | F: Future>, 28 | { 29 | type Output = R; 30 | 31 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 32 | let this = self.project(); 33 | let result = ready!(this.fut.poll(cx)); 34 | 35 | match result { 36 | Ok(r) => Poll::Ready(r), 37 | Err(e) => { 38 | this.ctx.fail(e); 39 | 40 | // Here is the secret sauce.
This will immediately cause the whole future chain to be polled, 41 | // but the poll here will be intercepted by HandlerStateAwareFuture 42 | cx.waker().wake_by_ref(); 43 | Poll::Pending 44 | } 45 | } 46 | } 47 | } 48 | 49 | impl crate::context::macro_support::SealedDurableFuture for DurableFutureImpl { 50 | fn inner_context(&self) -> ContextInternal { 51 | self.ctx.clone() 52 | } 53 | 54 | fn handle(&self) -> NotificationHandle { 55 | self.handle 56 | } 57 | } 58 | 59 | impl DurableFuture for DurableFutureImpl where F: Future> {} 60 | -------------------------------------------------------------------------------- /src/endpoint/futures/handler_state_aware.rs: -------------------------------------------------------------------------------- 1 | use crate::endpoint::{ContextInternal, Error}; 2 | use pin_project_lite::pin_project; 3 | use std::future::Future; 4 | use std::pin::Pin; 5 | use std::task::{Context, Poll}; 6 | use tokio::sync::oneshot; 7 | use tracing::warn; 8 | 9 | pin_project!
{ 10 | /// Future that will stop polling when handler is suspended/failed 11 | pub struct HandlerStateAwareFuture { 12 | #[pin] 13 | fut: F, 14 | handler_state_rx: oneshot::Receiver, 15 | handler_context: ContextInternal, 16 | } 17 | } 18 | 19 | impl HandlerStateAwareFuture { 20 | pub fn new( 21 | handler_context: ContextInternal, 22 | handler_state_rx: oneshot::Receiver, 23 | fut: F, 24 | ) -> HandlerStateAwareFuture { 25 | HandlerStateAwareFuture { 26 | fut, 27 | handler_state_rx, 28 | handler_context, 29 | } 30 | } 31 | } 32 | 33 | impl Future for HandlerStateAwareFuture 34 | where 35 | F: Future, 36 | { 37 | type Output = Result; 38 | 39 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 40 | let this = self.project(); 41 | 42 | match this.handler_state_rx.try_recv() { 43 | Ok(e) => { 44 | warn!( 45 | rpc.system = "restate", 46 | rpc.service = %this.handler_context.service_name(), 47 | rpc.method = %this.handler_context.handler_name(), 48 | "Error while processing handler {e:#}" 49 | ); 50 | this.handler_context.consume_to_end(); 51 | Poll::Ready(Err(e)) 52 | } 53 | Err(oneshot::error::TryRecvError::Empty) => match this.fut.poll(cx) { 54 | Poll::Ready(out) => { 55 | this.handler_context.consume_to_end(); 56 | Poll::Ready(Ok(out)) 57 | } 58 | Poll::Pending => Poll::Pending, 59 | }, 60 | Err(oneshot::error::TryRecvError::Closed) => { 61 | panic!("This is unexpected, this future is still being polled although the sender side was dropped.
This should not be possible, because the sender is dropped when this future returns Poll:ready().") 62 | } 63 | } 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /src/endpoint/futures/intercept_error.rs: -------------------------------------------------------------------------------- 1 | use crate::context::{InvocationHandle, RunFuture, RunRetryPolicy}; 2 | use crate::endpoint::{ContextInternal, Error}; 3 | use crate::errors::TerminalError; 4 | use pin_project_lite::pin_project; 5 | use std::future::Future; 6 | use std::pin::Pin; 7 | use std::task::{ready, Context, Poll}; 8 | 9 | pin_project! { 10 | /// Future that intercepts errors of inner future, and passes them to ContextInternal 11 | pub struct InterceptErrorFuture{ 12 | #[pin] 13 | fut: F, 14 | ctx: ContextInternal 15 | } 16 | } 17 | 18 | impl InterceptErrorFuture { 19 | pub fn new(ctx: ContextInternal, fut: F) -> Self { 20 | Self { fut, ctx } 21 | } 22 | } 23 | 24 | impl Future for InterceptErrorFuture 25 | where 26 | F: Future>, 27 | { 28 | type Output = R; 29 | 30 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 31 | let this = self.project(); 32 | let result = ready!(this.fut.poll(cx)); 33 | 34 | match result { 35 | Ok(r) => Poll::Ready(r), 36 | Err(e) => { 37 | this.ctx.fail(e); 38 | 39 | // Here is the secret sauce.
This will immediately cause the whole future chain to be polled, 40 | // but the poll here will be intercepted by HandlerStateAwareFuture 41 | cx.waker().wake_by_ref(); 42 | Poll::Pending 43 | } 44 | } 45 | } 46 | } 47 | 48 | impl RunFuture for InterceptErrorFuture 49 | where 50 | F: RunFuture>, 51 | { 52 | fn retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self { 53 | self.fut = self.fut.retry_policy(retry_policy); 54 | self 55 | } 56 | 57 | fn name(mut self, name: impl Into) -> Self { 58 | self.fut = self.fut.name(name); 59 | self 60 | } 61 | } 62 | 63 | impl InvocationHandle for InterceptErrorFuture { 64 | fn invocation_id(&self) -> impl Future> + Send { 65 | self.fut.invocation_id() 66 | } 67 | 68 | fn cancel(&self) -> impl Future> + Send { 69 | self.fut.cancel() 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/endpoint/futures/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod async_result_poll; 2 | pub mod durable_future_impl; 3 | pub mod handler_state_aware; 4 | pub mod intercept_error; 5 | pub mod select_poll; 6 | pub mod trap; 7 | -------------------------------------------------------------------------------- /src/endpoint/futures/select_poll.rs: -------------------------------------------------------------------------------- 1 | use crate::endpoint::context::ContextInternalInner; 2 | use crate::endpoint::ErrorInner; 3 | use crate::errors::TerminalError; 4 | use restate_sdk_shared_core::{ 5 | DoProgressResponse, Error as CoreError, NotificationHandle, TakeOutputResult, TerminalFailure, 6 | VM, 7 | }; 8 | use std::future::Future; 9 | use std::pin::Pin; 10 | use std::sync::{Arc, Mutex}; 11 | use std::task::Poll; 12 | 13 | pub(crate) struct VmSelectAsyncResultPollFuture { 14 | state: Option, 15 | } 16 | 17 | impl VmSelectAsyncResultPollFuture { 18 | pub fn new(ctx: Arc>, handles: Vec) -> Self { 19 | VmSelectAsyncResultPollFuture { 20 | state:
Some(VmSelectAsyncResultPollState::Init { ctx, handles }), 21 | } 22 | } 23 | } 24 | 25 | enum VmSelectAsyncResultPollState { 26 | Init { 27 | ctx: Arc>, 28 | handles: Vec, 29 | }, 30 | PollProgress { 31 | ctx: Arc>, 32 | handles: Vec, 33 | }, 34 | WaitingInput { 35 | ctx: Arc>, 36 | handles: Vec, 37 | }, 38 | } 39 | 40 | macro_rules! must_lock { 41 | ($mutex:expr) => { 42 | $mutex.try_lock().expect("You're trying to await two futures at the same time and/or trying to perform some operation on the restate context while awaiting a future. This is not supported!") 43 | }; 44 | } 45 | 46 | impl Future for VmSelectAsyncResultPollFuture { 47 | type Output = Result, ErrorInner>; 48 | 49 | fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { 50 | loop { 51 | match self 52 | .state 53 | .take() 54 | .expect("Future should not be polled after Poll::Ready") 55 | { 56 | VmSelectAsyncResultPollState::Init { ctx, handles } => { 57 | let mut inner_lock = must_lock!(ctx); 58 | 59 | // Let's consume some output to begin with 60 | let out = inner_lock.vm.take_output(); 61 | match out { 62 | TakeOutputResult::Buffer(b) => { 63 | if !inner_lock.write.send(b) { 64 | return Poll::Ready(Err(ErrorInner::Suspended)); 65 | } 66 | } 67 | TakeOutputResult::EOF => { 68 | return Poll::Ready(Err(ErrorInner::UnexpectedOutputClosed)) 69 | } 70 | } 71 | 72 | // We can now start polling 73 | drop(inner_lock); 74 | self.state = Some(VmSelectAsyncResultPollState::PollProgress { ctx, handles }); 75 | } 76 | VmSelectAsyncResultPollState::WaitingInput { ctx, handles } => { 77 | let mut inner_lock = must_lock!(ctx); 78 | 79 | let read_result = match inner_lock.read.poll_recv(cx) { 80 | Poll::Ready(t) => t, 81 | Poll::Pending => { 82 | // Still need to wait for input 83 | drop(inner_lock); 84 | self.state = 85 | Some(VmSelectAsyncResultPollState::WaitingInput { ctx, handles }); 86 | return Poll::Pending; 87 | } 88 | }; 89 | 90 | // Pass read result to VM 91 | match read_result { 92
| Some(Ok(b)) => inner_lock.vm.notify_input(b), 93 | Some(Err(e)) => inner_lock.vm.notify_error( 94 | CoreError::new(500u16, format!("Error when reading the body {e:?}",)), 95 | None, 96 | ), 97 | None => inner_lock.vm.notify_input_closed(), 98 | } 99 | 100 | // It's time to poll progress again 101 | drop(inner_lock); 102 | self.state = Some(VmSelectAsyncResultPollState::PollProgress { ctx, handles }); 103 | } 104 | VmSelectAsyncResultPollState::PollProgress { ctx, handles } => { 105 | let mut inner_lock = must_lock!(ctx); 106 | 107 | match inner_lock.vm.do_progress(handles.clone()) { 108 | Ok(DoProgressResponse::AnyCompleted) => { 109 | // We're good, we got the response 110 | } 111 | Ok(DoProgressResponse::ReadFromInput) => { 112 | drop(inner_lock); 113 | self.state = 114 | Some(VmSelectAsyncResultPollState::WaitingInput { ctx, handles }); 115 | continue; 116 | } 117 | Ok(DoProgressResponse::ExecuteRun(_)) => { 118 | unimplemented!() 119 | } 120 | Ok(DoProgressResponse::WaitingPendingRun) => { 121 | unimplemented!() 122 | } 123 | Ok(DoProgressResponse::CancelSignalReceived) => { 124 | return Poll::Ready(Ok(Err(TerminalFailure { 125 | code: 409, 126 | message: "cancelled".to_string(), 127 | } 128 | .into()))) 129 | } 130 | Err(e) => { 131 | return Poll::Ready(Err(e.into())); 132 | } 133 | }; 134 | 135 | // DoProgress might cause a flip of the replaying state 136 | inner_lock.maybe_flip_span_replaying_field(); 137 | 138 | // At this point let's try to take the notification 139 | for (idx, handle) in handles.iter().enumerate() { 140 | if inner_lock.vm.is_completed(*handle) { 141 | return Poll::Ready(Ok(Ok(idx))); 142 | } 143 | } 144 | panic!( 145 | "This is not supposed to happen, none of the given handles were completed even though poll progress completed with AnyCompleted" 146 | ) 147 | } 148 | } 149 | } 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /src/endpoint/futures/trap.rs:
-------------------------------------------------------------------------------- 1 | use crate::context::InvocationHandle; 2 | use crate::errors::TerminalError; 3 | use std::future::Future; 4 | use std::marker::PhantomData; 5 | use std::pin::Pin; 6 | use std::task::{Context, Poll}; 7 | 8 | /// Future that traps the execution at this point, but keeps waking up the waker 9 | pub struct TrapFuture(PhantomData T>); 10 | 11 | impl Default for TrapFuture { 12 | fn default() -> Self { 13 | Self(PhantomData) 14 | } 15 | } 16 | 17 | impl Future for TrapFuture { 18 | type Output = T; 19 | 20 | fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { 21 | ctx.waker().wake_by_ref(); 22 | Poll::Pending 23 | } 24 | } 25 | 26 | impl InvocationHandle for TrapFuture { 27 | fn invocation_id(&self) -> impl Future> + Send { 28 | TrapFuture::default() 29 | } 30 | 31 | fn cancel(&self) -> impl Future> + Send { 32 | TrapFuture::default() 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/endpoint/handler_state.rs: -------------------------------------------------------------------------------- 1 | use crate::endpoint::Error; 2 | use tokio::sync::oneshot; 3 | 4 | pub(super) struct HandlerStateNotifier { 5 | tx: Option>, 6 | } 7 | 8 | impl HandlerStateNotifier { 9 | pub(crate) fn new() -> (Self, oneshot::Receiver) { 10 | let (tx, rx) = oneshot::channel(); 11 | (Self { tx: Some(tx) }, rx) 12 | } 13 | 14 | pub(super) fn mark_error(&mut self, err: Error) { 15 | if let Some(tx) = self.tx.take() { 16 | let _ = tx.send(err); 17 | } 18 | // Some other operation already marked this handler as errored.
// ---------------------------------------------------------------------------
// Review notes for this part of the dump:
// - src/endpoint/futures/trap.rs, on the line above;
// - src/endpoint/handler_state.rs, around this note;
// - the beginning of src/endpoint/mod.rs, below, up to the middle of
//   `Endpoint::resolve`.
//
// NOTE(review): as elsewhere in this dump, angle-bracketed generic arguments
// seem to be missing (e.g. `tx: Option>,`, `type BoxError = Box;`, bare
// `Result` return types). Restore them from the real sources.
//
// trap.rs
// - `TrapFuture<T>` never resolves: every poll wakes its own waker and returns
//   Pending. Its `InvocationHandle` impl returns further `TrapFuture`s.
// - NOTE(review): because it re-wakes itself, the executor keeps re-polling
//   it. This presumably relies on an enclosing `HandlerStateAwareFuture`
//   having already received an error and stopped polling -- confirm at the
//   call sites.
//
// handler_state.rs
// - `HandlerStateNotifier::new()` returns the notifier together with the
//   receiving end of its oneshot channel.
// - `mark_error` delivers only the first error, because it `take()`s the
//   sender, and it ignores a failed send.
//
// endpoint/mod.rs (visible part)
// - `OutputSender::send` returns `false` when sending on the unbounded channel
//   fails. `InputReceiver` reads either from an unbounded channel or from a
//   boxed `Stream`.
// - `Error::status_code` mapping:
//   - VM errors keep their own code;
//   - unknown service/handler -> 404;
//   - suspension, closed output, unexpected value variant, (de)serialization
//     and handler failures -> 500;
//   - bad discovery `accept` header -> 415;
//   - header and path errors -> 400;
//   - identity verification failures -> 401.
// - NOTE(review): the `Deserialization`, `Serialization` and `HandlerResult`
//   error messages end with a stray `'` after `{err:?}`. It looks like a typo,
//   but it is runtime text, so it is left untouched here.
// - `Builder::default()` advertises protocol versions 5..=5 in `BidiStream`
//   mode.
// - `bind` stores the boxed service under its discovered name and pushes its
//   metadata to the manifest.
//   NOTE(review): binding two services with the same name overwrites the map
//   entry but adds a second manifest entry, so discovery and dispatch
//   disagree. Consider rejecting duplicates.
// - `resolve` continues past this chunk. In the visible part it:
//   - verifies identity first;
//   - answers any path whose last segment is `health` with an empty 200;
//   - for `discover`, checks the optional `accept` header, then replies with
//     the JSON manifest.
// - NOTE(review) on `resolve`:
//   - The accept check repeats the content-type literal instead of using
//     `DISCOVERY_CONTENT_TYPE`.
//   - `is_some()` + `unwrap()` reads better as `if let Some(accept)`.
//   - `parts.len() - 3` underflows when the path has fewer than three
//     '/'-separated segments (e.g. "/" splits into two). That panics in debug
//     builds; `checked_sub` would avoid it -- confirm against the rest of
//     `resolve`.
// ---------------------------------------------------------------------------
19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/endpoint/mod.rs: -------------------------------------------------------------------------------- 1 | mod context; 2 | mod futures; 3 | mod handler_state; 4 | 5 | use crate::endpoint::futures::handler_state_aware::HandlerStateAwareFuture; 6 | use crate::endpoint::futures::intercept_error::InterceptErrorFuture; 7 | use crate::endpoint::handler_state::HandlerStateNotifier; 8 | use crate::service::{Discoverable, Service}; 9 | use ::futures::future::BoxFuture; 10 | use ::futures::{Stream, StreamExt}; 11 | use bytes::Bytes; 12 | pub use context::{ContextInternal, InputMetadata}; 13 | use restate_sdk_shared_core::{ 14 | CoreVM, Error as CoreError, Header, HeaderMap, IdentityVerifier, KeyError, VerifyError, VM, 15 | }; 16 | use std::collections::HashMap; 17 | use std::future::poll_fn; 18 | use std::pin::Pin; 19 | use std::sync::Arc; 20 | use std::task::{Context, Poll}; 21 | use tracing::{info_span, Instrument}; 22 | 23 | const DISCOVERY_CONTENT_TYPE: &str = "application/vnd.restate.endpointmanifest.v1+json"; 24 | 25 | type BoxError = Box; 26 | 27 | pub struct OutputSender(tokio::sync::mpsc::UnboundedSender); 28 | 29 | impl OutputSender { 30 | pub fn from_channel(tx: tokio::sync::mpsc::UnboundedSender) -> Self { 31 | Self(tx) 32 | } 33 | 34 | fn send(&self, b: Bytes) -> bool { 35 | self.0.send(b).is_ok() 36 | } 37 | } 38 | 39 | pub struct InputReceiver(InputReceiverInner); 40 | 41 | enum InputReceiverInner { 42 | Channel(tokio::sync::mpsc::UnboundedReceiver>), 43 | BoxedStream(Pin> + Send + 'static>>), 44 | } 45 | 46 | impl InputReceiver { 47 | pub fn from_stream> + Send + 'static>(s: S) -> Self { 48 | Self(InputReceiverInner::BoxedStream(Box::pin(s))) 49 | } 50 | 51 | pub fn from_channel(rx: tokio::sync::mpsc::UnboundedReceiver>) -> Self { 52 | Self(InputReceiverInner::Channel(rx)) 53 | } 54 | 55 | async fn recv(&mut self) -> Option> { 56 | poll_fn(|cx|
self.poll_recv(cx)).await 57 | } 58 | 59 | fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll>> { 60 | match &mut self.0 { 61 | InputReceiverInner::Channel(ch) => ch.poll_recv(cx), 62 | InputReceiverInner::BoxedStream(s) => s.poll_next_unpin(cx), 63 | } 64 | } 65 | } 66 | 67 | // TODO can we have the backtrace here? 68 | /// Endpoint error. This encapsulates any error that happens within the SDK while processing a request. 69 | #[derive(Debug, thiserror::Error)] 70 | #[error(transparent)] 71 | pub struct Error(#[from] ErrorInner); 72 | 73 | impl Error { 74 | /// New error for unknown handler 75 | pub fn unknown_handler(service_name: &str, handler_name: &str) -> Self { 76 | Self(ErrorInner::UnknownServiceHandler( 77 | service_name.to_owned(), 78 | handler_name.to_owned(), 79 | )) 80 | } 81 | } 82 | 83 | impl Error { 84 | /// Returns the HTTP status code for this error. 85 | pub fn status_code(&self) -> u16 { 86 | match &self.0 { 87 | ErrorInner::VM(e) => e.code(), 88 | ErrorInner::UnknownService(_) | ErrorInner::UnknownServiceHandler(_, _) => 404, 89 | ErrorInner::Suspended 90 | | ErrorInner::UnexpectedOutputClosed 91 | | ErrorInner::UnexpectedValueVariantForSyscall { .. } 92 | | ErrorInner::Deserialization { .. } 93 | | ErrorInner::Serialization { .. } 94 | | ErrorInner::HandlerResult { .. } => 500, 95 | ErrorInner::BadDiscovery(_) => 415, 96 | ErrorInner::Header { .. } | ErrorInner::BadPath { ..
} => 400, 97 | ErrorInner::IdentityVerification(_) => 401, 98 | } 99 | } 100 | } 101 | 102 | #[derive(Debug, thiserror::Error)] 103 | pub(crate) enum ErrorInner { 104 | #[error("Received a request for unknown service '{0}'")] 105 | UnknownService(String), 106 | #[error("Received a request for unknown service handler '{0}/{1}'")] 107 | UnknownServiceHandler(String, String), 108 | #[error("Error when processing the request: {0:?}")] 109 | VM(#[from] CoreError), 110 | #[error("Error when verifying identity: {0:?}")] 111 | IdentityVerification(#[from] VerifyError), 112 | #[error("Cannot convert header '{0}', reason: {1}")] 113 | Header(String, #[source] BoxError), 114 | #[error("Cannot reply to discovery, got accept header '{0}' but currently supported discovery is {DISCOVERY_CONTENT_TYPE}")] 115 | BadDiscovery(String), 116 | #[error("Bad path '{0}', expected either '/discover' or '/invoke/service/handler'")] 117 | BadPath(String), 118 | #[error("Suspended")] 119 | Suspended, 120 | #[error("Unexpected output closed")] 121 | UnexpectedOutputClosed, 122 | #[error("Unexpected value variant {variant} for syscall '{syscall}'")] 123 | UnexpectedValueVariantForSyscall { 124 | variant: &'static str, 125 | syscall: &'static str, 126 | }, 127 | #[error("Failed to deserialize with '{syscall}': {err:?}'")] 128 | Deserialization { 129 | syscall: &'static str, 130 | #[source] 131 | err: BoxError, 132 | }, 133 | #[error("Failed to serialize with '{syscall}': {err:?}'")] 134 | Serialization { 135 | syscall: &'static str, 136 | #[source] 137 | err: BoxError, 138 | }, 139 | #[error("Handler failed with retryable error: {err:?}'")] 140 | HandlerResult { 141 | #[source] 142 | err: BoxError, 143 | }, 144 | } 145 | 146 | impl From for ErrorInner { 147 | fn from(_: restate_sdk_shared_core::SuspendedError) -> Self { 148 | Self::Suspended 149 | } 150 | } 151 | 152 | impl From for ErrorInner { 153 | fn from(value: restate_sdk_shared_core::SuspendedOrVMError) -> Self { 154 | match value { 155 |
restate_sdk_shared_core::SuspendedOrVMError::Suspended(e) => e.into(), 156 | restate_sdk_shared_core::SuspendedOrVMError::VM(e) => e.into(), 157 | } 158 | } 159 | } 160 | 161 | impl From for Error { 162 | fn from(e: CoreError) -> Self { 163 | ErrorInner::from(e).into() 164 | } 165 | } 166 | 167 | struct BoxedService( 168 | Box>> + Send + Sync + 'static>, 169 | ); 170 | 171 | impl BoxedService { 172 | pub fn new< 173 | S: Service>> + Send + Sync + 'static, 174 | >( 175 | service: S, 176 | ) -> Self { 177 | Self(Box::new(service)) 178 | } 179 | } 180 | 181 | impl Service for BoxedService { 182 | type Future = BoxFuture<'static, Result<(), Error>>; 183 | 184 | fn handle(&self, req: ContextInternal) -> Self::Future { 185 | self.0.handle(req) 186 | } 187 | } 188 | 189 | /// Builder for [`Endpoint`] 190 | pub struct Builder { 191 | svcs: HashMap, 192 | discovery: crate::discovery::Endpoint, 193 | identity_verifier: IdentityVerifier, 194 | } 195 | 196 | impl Default for Builder { 197 | fn default() -> Self { 198 | Self { 199 | svcs: Default::default(), 200 | discovery: crate::discovery::Endpoint { 201 | max_protocol_version: 5, 202 | min_protocol_version: 5, 203 | protocol_mode: Some(crate::discovery::ProtocolMode::BidiStream), 204 | services: vec![], 205 | }, 206 | identity_verifier: Default::default(), 207 | } 208 | } 209 | } 210 | 211 | impl Builder { 212 | /// Create a new builder for [`Endpoint`]. 213 | pub fn new() -> Self { 214 | Self::default() 215 | } 216 | 217 | /// Add a [`Service`] to this endpoint. 218 | /// 219 | /// When using the [`service`](macro@crate::service), [`object`](macro@crate::object) or [`workflow`](macro@crate::workflow) macros, 220 | /// you need to pass the result of the `serve` method.
221 | pub fn bind< 222 | S: Service>> 223 | + Discoverable 224 | + Send 225 | + Sync 226 | + 'static, 227 | >( 228 | mut self, 229 | s: S, 230 | ) -> Self { 231 | let service_metadata = S::discover(); 232 | let boxed_service = BoxedService::new(s); 233 | self.svcs 234 | .insert(service_metadata.name.to_string(), boxed_service); 235 | self.discovery.services.push(service_metadata); 236 | self 237 | } 238 | 239 | /// Add identity key, e.g. `publickeyv1_ChjENKeMvCtRnqG2mrBK1HmPKufgFUc98K8B3ononQvp`. 240 | pub fn identity_key(mut self, key: &str) -> Result { 241 | self.identity_verifier = self.identity_verifier.with_key(key)?; 242 | Ok(self) 243 | } 244 | 245 | /// Build the [`Endpoint`]. 246 | pub fn build(self) -> Endpoint { 247 | Endpoint(Arc::new(EndpointInner { 248 | svcs: self.svcs, 249 | discovery: self.discovery, 250 | identity_verifier: self.identity_verifier, 251 | })) 252 | } 253 | } 254 | 255 | /// This struct encapsulates all the business logic to handle incoming requests to the SDK, 256 | /// including service discovery, invocations and identity verification. 257 | /// 258 | /// It internally wraps the provided services. This structure is cheaply cloneable. 259 | #[derive(Clone)] 260 | pub struct Endpoint(Arc); 261 | 262 | impl Endpoint { 263 | /// Create a new builder for [`Endpoint`].
264 | pub fn builder() -> Builder { 265 | Builder::new() 266 | } 267 | } 268 | 269 | pub struct EndpointInner { 270 | svcs: HashMap, 271 | discovery: crate::discovery::Endpoint, 272 | identity_verifier: IdentityVerifier, 273 | } 274 | 275 | impl Endpoint { 276 | pub fn resolve(&self, path: &str, headers: H) -> Result 277 | where 278 | H: HeaderMap, 279 | ::Error: std::error::Error + Send + Sync + 'static, 280 | { 281 | if let Err(e) = self.0.identity_verifier.verify_identity(&headers, path) { 282 | return Err(ErrorInner::IdentityVerification(e).into()); 283 | } 284 | 285 | let parts: Vec<&str> = path.split('/').collect(); 286 | 287 | if parts.last() == Some(&"health") { 288 | return Ok(Response::ReplyNow { 289 | status_code: 200, 290 | headers: vec![], 291 | body: Bytes::new(), 292 | }); 293 | } 294 | 295 | if parts.last() == Some(&"discover") { 296 | let accept_header = headers 297 | .extract("accept") 298 | .map_err(|e| ErrorInner::Header("accept".to_owned(), Box::new(e)))?; 299 | if accept_header.is_some() { 300 | let accept = accept_header.unwrap(); 301 | if !accept.contains("application/vnd.restate.endpointmanifest.v1+json") { 302 | return Err(Error(ErrorInner::BadDiscovery(accept.to_owned()))); 303 | } 304 | } 305 | 306 | return Ok(Response::ReplyNow { 307 | status_code: 200, 308 | headers: vec![Header { 309 | key: "content-type".into(), 310 | value: DISCOVERY_CONTENT_TYPE.into(), 311 | }], 312 | body: Bytes::from( 313 | serde_json::to_string(&self.0.discovery) 314 | .expect("Discovery should be serializable"), 315 | ), 316 | }); 317 | } 318 | 319 | let (svc_name, handler_name) = match parts.get(parts.len() - 3..)
{ 320 | None => return Err(Error(ErrorInner::BadPath(path.to_owned()))), 321 | Some(last_elements) if last_elements[0] != "invoke" => { 322 | return Err(Error(ErrorInner::BadPath(path.to_owned()))) 323 | } 324 | Some(last_elements) => (last_elements[1].to_owned(), last_elements[2].to_owned()), 325 | }; 326 | 327 | let vm = CoreVM::new(headers, Default::default()).map_err(ErrorInner::VM)?; 328 | if !self.0.svcs.contains_key(&svc_name) { 329 | return Err(ErrorInner::UnknownService(svc_name.to_owned()).into()); 330 | } 331 | 332 | let response_head = vm.get_response_head(); 333 | 334 | Ok(Response::BidiStream { 335 | status_code: response_head.status_code, 336 | headers: response_head.headers, 337 | handler: BidiStreamRunner { 338 | svc_name, 339 | handler_name, 340 | vm, 341 | endpoint: Arc::clone(&self.0), 342 | }, 343 | }) 344 | } 345 | } 346 | 347 | pub enum Response { 348 | ReplyNow { 349 | status_code: u16, 350 | headers: Vec
, 351 | body: Bytes, 352 | }, 353 | BidiStream { 354 | status_code: u16, 355 | headers: Vec
, 356 | handler: BidiStreamRunner, 357 | }, 358 | } 359 | 360 | pub struct BidiStreamRunner { 361 | svc_name: String, 362 | handler_name: String, 363 | vm: CoreVM, 364 | endpoint: Arc, 365 | } 366 | 367 | impl BidiStreamRunner { 368 | pub async fn handle( 369 | self, 370 | input_rx: InputReceiver, 371 | output_tx: OutputSender, 372 | ) -> Result<(), Error> { 373 | // Retrieve the service from the Arc 374 | let svc = self 375 | .endpoint 376 | .svcs 377 | .get(&self.svc_name) 378 | .expect("service must exist at this point"); 379 | 380 | let span = info_span!( 381 | "restate_sdk_endpoint_handle", 382 | "rpc.system" = "restate", 383 | "rpc.service" = self.svc_name, 384 | "rpc.method" = self.handler_name, 385 | "restate.sdk.is_replaying" = false 386 | ); 387 | handle( 388 | input_rx, 389 | output_tx, 390 | self.vm, 391 | self.svc_name, 392 | self.handler_name, 393 | svc, 394 | ) 395 | .instrument(span) 396 | .await 397 | } 398 | } 399 | 400 | #[doc(hidden)] 401 | pub async fn handle>> + Send + Sync>( 402 | mut input_rx: InputReceiver, 403 | output_tx: OutputSender, 404 | vm: CoreVM, 405 | svc_name: String, 406 | handler_name: String, 407 | svc: &S, 408 | ) -> Result<(), Error> { 409 | let mut vm = vm; 410 | init_loop_vm(&mut vm, &mut input_rx).await?; 411 | 412 | // Initialize handler context 413 | let (handler_state_tx, handler_state_rx) = HandlerStateNotifier::new(); 414 | let ctx = ContextInternal::new( 415 | vm, 416 | svc_name, 417 | handler_name, 418 | input_rx, 419 | output_tx, 420 | handler_state_tx, 421 | ); 422 | 423 | // Start user code 424 | let user_code_fut = InterceptErrorFuture::new(ctx.clone(), svc.handle(ctx.clone())); 425 | 426 | // Wrap it in handler state aware future 427 | HandlerStateAwareFuture::new(ctx.clone(), handler_state_rx, user_code_fut).await 428 | } 429 | 430 | async fn init_loop_vm(vm: &mut CoreVM, input_rx: &mut InputReceiver) -> Result<(), ErrorInner> { 431 | while !vm.is_ready_to_execute().map_err(ErrorInner::VM)? 
{ 432 | match input_rx.recv().await { 433 | Some(Ok(b)) => vm.notify_input(b), 434 | Some(Err(e)) => vm.notify_error( 435 | CoreError::new(500u16, format!("Error when reading the body: {e}")), 436 | None, 437 | ), 438 | None => vm.notify_input_closed(), 439 | } 440 | } 441 | Ok(()) 442 | } 443 | -------------------------------------------------------------------------------- /src/errors.rs: -------------------------------------------------------------------------------- 1 | //! # Error Handling 2 | //! 3 | //! Restate handles retries for failed invocations. 4 | //! By default, Restate does infinite retries with an exponential backoff strategy. 5 | //! 6 | //! For failures for which you do not want retries, but instead want the invocation to end and the error message 7 | //! to be propagated back to the caller, you can return a [`TerminalError`]. 8 | //! 9 | //! You can return a [`TerminalError`] with an optional HTTP status code and a message anywhere in your handler, as follows: 10 | //! 11 | //! ```rust,no_run 12 | //! # use restate_sdk::prelude::*; 13 | //! # async fn handle() -> Result<(), HandlerError> { 14 | //! Err(TerminalError::new("This is a terminal error").into()) 15 | //! # } 16 | //! ``` 17 | //! 18 | //! You can catch terminal exceptions. For example, you can catch the terminal exception that comes out of a [call to another service][crate::context::ContextClient], and build your control flow around it. 
19 | use restate_sdk_shared_core::TerminalFailure; 20 | use std::error::Error as StdError; 21 | use std::fmt; 22 | 23 | #[derive(Debug)] 24 | pub(crate) enum HandlerErrorInner { 25 | Retryable(Box), 26 | Terminal(TerminalErrorInner), 27 | } 28 | 29 | impl fmt::Display for HandlerErrorInner { 30 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 31 | match self { 32 | HandlerErrorInner::Retryable(e) => { 33 | write!(f, "Retryable error: {}", e) 34 | } 35 | HandlerErrorInner::Terminal(e) => fmt::Display::fmt(e, f), 36 | } 37 | } 38 | } 39 | 40 | impl StdError for HandlerErrorInner { 41 | fn source(&self) -> Option<&(dyn StdError + 'static)> { 42 | match self { 43 | HandlerErrorInner::Retryable(e) => Some(e.as_ref()), 44 | HandlerErrorInner::Terminal(e) => Some(e), 45 | } 46 | } 47 | } 48 | 49 | /// This error can contain either a [`TerminalError`], or any other Rust's [`StdError`]. 50 | /// For the latter, the error is considered "retryable", and the execution will be retried. 51 | #[derive(Debug)] 52 | pub struct HandlerError(pub(crate) HandlerErrorInner); 53 | 54 | impl>> From for HandlerError { 55 | fn from(value: E) -> Self { 56 | Self(HandlerErrorInner::Retryable(value.into())) 57 | } 58 | } 59 | 60 | impl From for HandlerError { 61 | fn from(value: TerminalError) -> Self { 62 | Self(HandlerErrorInner::Terminal(value.0)) 63 | } 64 | } 65 | 66 | // Took from anyhow 67 | impl AsRef for HandlerError { 68 | fn as_ref(&self) -> &(dyn StdError + Send + Sync + 'static) { 69 | &self.0 70 | } 71 | } 72 | 73 | impl AsRef for HandlerError { 74 | fn as_ref(&self) -> &(dyn StdError + 'static) { 75 | &self.0 76 | } 77 | } 78 | 79 | #[derive(Debug, Clone)] 80 | pub(crate) struct TerminalErrorInner { 81 | code: u16, 82 | message: String, 83 | } 84 | 85 | impl fmt::Display for TerminalErrorInner { 86 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 87 | write!(f, "Terminal error [{}]: {}", self.code, self.message) 88 | } 89 | } 90 | 91 | impl StdError for 
TerminalErrorInner {} 92 | 93 | /// Error representing the result of an operation recorded in the journal. 94 | /// 95 | /// When returned inside a [`crate::context::ContextSideEffects::run`] closure, or in a handler, it completes the operation with a failure value. 96 | #[derive(Debug, Clone)] 97 | pub struct TerminalError(pub(crate) TerminalErrorInner); 98 | 99 | impl TerminalError { 100 | /// Create a new [`TerminalError`]. 101 | pub fn new(message: impl Into) -> Self { 102 | Self::new_with_code(500, message) 103 | } 104 | 105 | /// Create a new [`TerminalError`] with a status code. 106 | pub fn new_with_code(code: u16, message: impl Into) -> Self { 107 | Self(TerminalErrorInner { 108 | code, 109 | message: message.into(), 110 | }) 111 | } 112 | 113 | pub fn code(&self) -> u16 { 114 | self.0.code 115 | } 116 | 117 | pub fn message(&self) -> &str { 118 | &self.0.message 119 | } 120 | 121 | pub fn from_error(e: E) -> Self { 122 | Self::new(e.to_string()) 123 | } 124 | } 125 | 126 | impl fmt::Display for TerminalError { 127 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 128 | fmt::Display::fmt(&self.0, f) 129 | } 130 | } 131 | 132 | impl AsRef for TerminalError { 133 | fn as_ref(&self) -> &(dyn StdError + Send + Sync + 'static) { 134 | &self.0 135 | } 136 | } 137 | 138 | impl AsRef for TerminalError { 139 | fn as_ref(&self) -> &(dyn StdError + 'static) { 140 | &self.0 141 | } 142 | } 143 | 144 | impl From for TerminalError { 145 | fn from(value: TerminalFailure) -> Self { 146 | Self(TerminalErrorInner { 147 | code: value.code, 148 | message: value.message, 149 | }) 150 | } 151 | } 152 | 153 | impl From for TerminalFailure { 154 | fn from(value: TerminalError) -> Self { 155 | Self { 156 | code: value.0.code, 157 | message: value.0.message, 158 | } 159 | } 160 | } 161 | 162 | /// Result type for a Restate handler. 
163 | pub type HandlerResult = Result; 164 | -------------------------------------------------------------------------------- /src/filter.rs: -------------------------------------------------------------------------------- 1 | //! Replay aware tracing filter. 2 | 3 | use std::fmt::Debug; 4 | use tracing::{ 5 | field::{Field, Visit}, 6 | span::{Attributes, Record}, 7 | Event, Id, Metadata, Subscriber, 8 | }; 9 | use tracing_subscriber::{ 10 | layer::{Context, Filter}, 11 | registry::LookupSpan, 12 | Layer, 13 | }; 14 | 15 | #[derive(Debug)] 16 | struct ReplayField(bool); 17 | 18 | struct ReplayFieldVisitor(bool); 19 | 20 | impl Visit for ReplayFieldVisitor { 21 | fn record_bool(&mut self, field: &Field, value: bool) { 22 | if field.name().eq("restate.sdk.is_replaying") { 23 | self.0 = value; 24 | } 25 | } 26 | 27 | fn record_debug(&mut self, _field: &Field, _value: &dyn Debug) {} 28 | } 29 | 30 | /// Replay aware tracing filter. 31 | /// 32 | /// Use this filter to skip tracing events in the service while replaying: 33 | /// 34 | /// ```rust,no_run 35 | /// use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer}; 36 | /// tracing_subscriber::registry() 37 | /// .with( 38 | /// tracing_subscriber::fmt::layer() 39 | /// // Default Env filter to read RUST_LOG 40 | /// .with_filter(tracing_subscriber::EnvFilter::from_default_env()) 41 | /// // Replay aware filter 42 | /// .with_filter(restate_sdk::filter::ReplayAwareFilter) 43 | /// ) 44 | /// .init(); 45 | /// ``` 46 | pub struct ReplayAwareFilter; 47 | 48 | impl LookupSpan<'lookup>> Filter for ReplayAwareFilter { 49 | fn enabled(&self, _meta: &Metadata<'_>, _cx: &Context<'_, S>) -> bool { 50 | true 51 | } 52 | 53 | fn event_enabled(&self, event: &Event<'_>, cx: &Context<'_, S>) -> bool { 54 | if let Some(scope) = cx.event_scope(event) { 55 | let iterator = scope.from_root(); 56 | for span in iterator { 57 | if span.name() == "restate_sdk_endpoint_handle" { 58 | if let Some(replay) = 
span.extensions().get::() { 59 | return !replay.0; 60 | } 61 | } 62 | } 63 | } 64 | true 65 | } 66 | 67 | fn on_new_span(&self, attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) { 68 | if let Some(span) = ctx.span(id) { 69 | if span.name() == "restate_sdk_endpoint_handle" { 70 | let mut visitor = ReplayFieldVisitor(false); 71 | attrs.record(&mut visitor); 72 | let mut extensions = span.extensions_mut(); 73 | extensions.replace::(ReplayField(visitor.0)); 74 | } 75 | } 76 | } 77 | 78 | fn on_record(&self, id: &Id, values: &Record<'_>, ctx: Context<'_, S>) { 79 | if let Some(span) = ctx.span(id) { 80 | if span.name() == "restate_sdk_endpoint_handle" { 81 | let mut visitor = ReplayFieldVisitor(false); 82 | values.record(&mut visitor); 83 | let mut extensions = span.extensions_mut(); 84 | extensions.replace::(ReplayField(visitor.0)); 85 | } 86 | } 87 | } 88 | } 89 | 90 | impl Layer for ReplayAwareFilter {} 91 | -------------------------------------------------------------------------------- /src/http_server.rs: -------------------------------------------------------------------------------- 1 | //! # Serving 2 | //! Restate services run as an HTTP endpoint. 3 | //! 4 | //! ## Creating an HTTP endpoint 5 | //! 1. Create the endpoint 6 | //! 2. Bind one or multiple services to it. 7 | //! 3. Listen on the specified port (default `9080`) for connections and requests. 8 | //! 9 | //! ```rust,no_run 10 | //! # #[path = "../examples/services/mod.rs"] 11 | //! # mod services; 12 | //! # use services::my_service::{MyService, MyServiceImpl}; 13 | //! # use services::my_virtual_object::{MyVirtualObject, MyVirtualObjectImpl}; 14 | //! # use services::my_workflow::{MyWorkflow, MyWorkflowImpl}; 15 | //! use restate_sdk::endpoint::Endpoint; 16 | //! use restate_sdk::http_server::HttpServer; 17 | //! 18 | //! #[tokio::main] 19 | //! async fn main() { 20 | //! HttpServer::new( 21 | //! Endpoint::builder() 22 | //! .bind(MyServiceImpl.serve()) 23 | //! 
.bind(MyVirtualObjectImpl.serve()) 24 | //! .bind(MyWorkflowImpl.serve()) 25 | //! .build(), 26 | //! ) 27 | //! .listen_and_serve("0.0.0.0:9080".parse().unwrap()) 28 | //! .await; 29 | //! } 30 | //! ``` 31 | //! 32 | //! 33 | //! ## Validating request identity 34 | //! 35 | //! SDKs can validate that incoming requests come from a particular Restate 36 | //! instance. You can find out more about request identity in the [Security docs](https://docs.restate.dev/operate/security#locking-down-service-access). 37 | //! Add the identity key to your endpoint as follows: 38 | //! 39 | //! ```rust,no_run 40 | //! # #[path = "../examples/services/mod.rs"] 41 | //! # mod services; 42 | //! # use services::my_service::{MyService, MyServiceImpl}; 43 | //! # use restate_sdk::endpoint::Endpoint; 44 | //! # use restate_sdk::http_server::HttpServer; 45 | //! # 46 | //! # #[tokio::main] 47 | //! # async fn main() { 48 | //! HttpServer::new( 49 | //! Endpoint::builder() 50 | //! .bind(MyServiceImpl.serve()) 51 | //! .identity_key("publickeyv1_w7YHemBctH5Ck2nQRQ47iBBqhNHy4FV7t2Usbye2A6f") 52 | //! .unwrap() 53 | //! .build(), 54 | //! ) 55 | //! .listen_and_serve("0.0.0.0:9080".parse().unwrap()) 56 | //! .await; 57 | //! # } 58 | //! ``` 59 | 60 | use crate::endpoint::Endpoint; 61 | use crate::hyper::HyperEndpoint; 62 | use futures::FutureExt; 63 | use hyper::server::conn::http2; 64 | use hyper_util::rt::{TokioExecutor, TokioIo}; 65 | use std::future::Future; 66 | use std::net::SocketAddr; 67 | use std::time::Duration; 68 | use tokio::net::TcpListener; 69 | use tracing::{info, warn}; 70 | 71 | /// Http server to expose your Restate services. 72 | pub struct HttpServer { 73 | endpoint: Endpoint, 74 | } 75 | 76 | impl From for HttpServer { 77 | fn from(endpoint: Endpoint) -> Self { 78 | Self { endpoint } 79 | } 80 | } 81 | 82 | impl HttpServer { 83 | /// Create new [`HttpServer`] from an [`Endpoint`]. 
84 | pub fn new(endpoint: Endpoint) -> Self { 85 | Self { endpoint } 86 | } 87 | 88 | /// Listen on the given address and serve. 89 | /// 90 | /// The future will be completed once `SIGTERM` is sent to the process. 91 | pub async fn listen_and_serve(self, addr: SocketAddr) { 92 | let listener = TcpListener::bind(addr).await.expect("listener can bind"); 93 | self.serve(listener).await; 94 | } 95 | 96 | /// Serve on the given listener. 97 | /// 98 | /// The future will be completed once `SIGTERM` is sent to the process. 99 | pub async fn serve(self, listener: TcpListener) { 100 | self.serve_with_cancel(listener, tokio::signal::ctrl_c().map(|_| ())) 101 | .await; 102 | } 103 | 104 | /// Serve on the given listener, and cancel the execution with the given future. 105 | pub async fn serve_with_cancel(self, listener: TcpListener, cancel_signal_future: impl Future) { 106 | let endpoint = HyperEndpoint::new(self.endpoint); 107 | let graceful = hyper_util::server::graceful::GracefulShutdown::new(); 108 | 109 | // when this signal completes, start shutdown 110 | let mut signal = std::pin::pin!(cancel_signal_future); 111 | 112 | info!("Starting listening on {}", listener.local_addr().unwrap()); 113 | 114 | // Our server accept loop 115 | loop { 116 | tokio::select! { 117 | Ok((stream, remote)) = listener.accept() => { 118 | let endpoint = endpoint.clone(); 119 | 120 | let conn = http2::Builder::new(TokioExecutor::default()) 121 | .serve_connection(TokioIo::new(stream), endpoint); 122 | 123 | let fut = graceful.watch(conn); 124 | 125 | tokio::spawn(async move { 126 | if let Err(e) = fut.await { 127 | warn!("Error serving connection {remote}: {:?}", e); 128 | } 129 | }); 130 | }, 131 | _ = &mut signal => { 132 | info!("Shutting down"); 133 | // stop the accept loop 134 | break; 135 | } 136 | } 137 | } 138 | 139 | // Wait graceful shutdown 140 | tokio::select! 
{ 141 | _ = graceful.shutdown() => {}, 142 | _ = tokio::time::sleep(Duration::from_secs(10)) => { 143 | warn!("Timed out waiting for all connections to close"); 144 | } 145 | } 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /src/hyper.rs: -------------------------------------------------------------------------------- 1 | //! Hyper integration. 2 | 3 | use crate::endpoint; 4 | use crate::endpoint::{Endpoint, InputReceiver, OutputSender}; 5 | use bytes::Bytes; 6 | use futures::future::BoxFuture; 7 | use futures::{FutureExt, TryStreamExt}; 8 | use http::header::CONTENT_TYPE; 9 | use http::{response, HeaderName, HeaderValue, Request, Response}; 10 | use http_body_util::{BodyExt, Either, Full}; 11 | use hyper::body::{Body, Frame, Incoming}; 12 | use hyper::service::Service; 13 | use restate_sdk_shared_core::Header; 14 | use std::convert::Infallible; 15 | use std::future::{ready, Ready}; 16 | use std::ops::Deref; 17 | use std::pin::Pin; 18 | use std::task::{ready, Context, Poll}; 19 | use tokio::sync::mpsc; 20 | use tracing::{debug, warn}; 21 | 22 | #[allow(clippy::declare_interior_mutable_const)] 23 | const X_RESTATE_SERVER: HeaderName = HeaderName::from_static("x-restate-server"); 24 | const X_RESTATE_SERVER_VALUE: HeaderValue = 25 | HeaderValue::from_static(concat!("restate-sdk-rust/", env!("CARGO_PKG_VERSION"))); 26 | 27 | /// Wraps [`Endpoint`] to implement hyper [`Service`]. 
28 | #[derive(Clone)] 29 | pub struct HyperEndpoint(Endpoint); 30 | 31 | impl HyperEndpoint { 32 | pub fn new(endpoint: Endpoint) -> Self { 33 | Self(endpoint) 34 | } 35 | } 36 | 37 | impl Service> for HyperEndpoint { 38 | type Response = Response, BidiStreamRunner>>; 39 | type Error = endpoint::Error; 40 | type Future = Ready>; 41 | 42 | fn call(&self, req: Request) -> Self::Future { 43 | let (parts, body) = req.into_parts(); 44 | let endpoint_response = match self.0.resolve(parts.uri.path(), parts.headers) { 45 | Ok(res) => res, 46 | Err(err) => { 47 | debug!("Error when trying to handle incoming request: {err}"); 48 | return ready(Ok(Response::builder() 49 | .status(err.status_code()) 50 | .header(CONTENT_TYPE, "text/plain") 51 | .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE) 52 | .body(Either::Left(Full::new(Bytes::from(err.to_string())))) 53 | .expect("Headers should be valid"))); 54 | } 55 | }; 56 | 57 | match endpoint_response { 58 | endpoint::Response::ReplyNow { 59 | status_code, 60 | headers, 61 | body, 62 | } => ready(Ok(response_builder_from_response_head( 63 | status_code, 64 | headers, 65 | ) 66 | .body(Either::Left(Full::new(body))) 67 | .expect("Headers should be valid"))), 68 | endpoint::Response::BidiStream { 69 | status_code, 70 | headers, 71 | handler, 72 | } => { 73 | let input_receiver = 74 | InputReceiver::from_stream(body.into_data_stream().map_err(|e| e.into())); 75 | 76 | let (output_tx, output_rx) = mpsc::unbounded_channel(); 77 | let output_sender = OutputSender::from_channel(output_tx); 78 | 79 | let handler_fut = Box::pin(handler.handle(input_receiver, output_sender)); 80 | 81 | ready(Ok(response_builder_from_response_head( 82 | status_code, 83 | headers, 84 | ) 85 | .body(Either::Right(BidiStreamRunner { 86 | fut: Some(handler_fut), 87 | output_rx, 88 | end_stream: false, 89 | })) 90 | .expect("Headers should be valid"))) 91 | } 92 | } 93 | } 94 | } 95 | 96 | fn response_builder_from_response_head( 97 | status_code: u16, 98 | 
headers: Vec
, 99 | ) -> response::Builder { 100 | let mut response_builder = Response::builder() 101 | .status(status_code) 102 | .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE); 103 | 104 | for header in headers { 105 | response_builder = response_builder.header(header.key.deref(), header.value.deref()); 106 | } 107 | 108 | response_builder 109 | } 110 | 111 | pub struct BidiStreamRunner { 112 | fut: Option>>, 113 | output_rx: mpsc::UnboundedReceiver, 114 | end_stream: bool, 115 | } 116 | 117 | impl Body for BidiStreamRunner { 118 | type Data = Bytes; 119 | type Error = Infallible; 120 | 121 | fn poll_frame( 122 | mut self: Pin<&mut Self>, 123 | cx: &mut Context<'_>, 124 | ) -> Poll, Self::Error>>> { 125 | // First try to consume the runner future 126 | if let Some(mut fut) = self.fut.take() { 127 | match fut.poll_unpin(cx) { 128 | Poll::Ready(res) => { 129 | if let Err(e) = res { 130 | warn!("Handler failure: {e:?}") 131 | } 132 | self.output_rx.close(); 133 | } 134 | Poll::Pending => { 135 | self.fut = Some(fut); 136 | } 137 | } 138 | } 139 | 140 | if let Some(out) = ready!(self.output_rx.poll_recv(cx)) { 141 | Poll::Ready(Some(Ok(Frame::data(out)))) 142 | } else { 143 | self.end_stream = true; 144 | Poll::Ready(None) 145 | } 146 | } 147 | 148 | fn is_end_stream(&self) -> bool { 149 | self.end_stream 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /src/serde.rs: -------------------------------------------------------------------------------- 1 | //! # Serialization 2 | //! 3 | //! Restate sends data over the network for storing state, journaling actions, awakeables, etc. 4 | //! 5 | //! Therefore, the types of the values that are stored, need to either: 6 | //! - be a primitive type 7 | //! - use a wrapper type [`Json`] for using [`serde-json`](https://serde.rs/). To enable JSON schema generation, you'll need to enable the `schemars` feature. See [PayloadMetadata] for more details. 8 | //! 
- have the [`Serialize`] and [`Deserialize`] trait implemented. If you need to use a type for the handler input/output, you'll also need to implement [PayloadMetadata] to reply with correct content type and enable **JSON schema generation**. 9 | //! 10 | 11 | use bytes::Bytes; 12 | use std::convert::Infallible; 13 | 14 | const APPLICATION_JSON: &str = "application/json"; 15 | const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; 16 | 17 | /// Serialize trait for Restate services. 18 | /// 19 | /// Default implementations are provided for primitives, and you can use the wrapper type [`Json`] to serialize using [`serde_json`]. 20 | /// 21 | /// This looks similar to [`serde::Serialize`], but allows to plug-in non-serde serialization formats (e.g. like Protobuf using `prost`). 22 | pub trait Serialize { 23 | type Error: std::error::Error + Send + Sync + 'static; 24 | 25 | fn serialize(&self) -> Result; 26 | } 27 | 28 | // TODO perhaps figure out how to add a lifetime here so we can deserialize to borrowed types 29 | /// Deserialize trait for Restate services. 30 | /// 31 | /// Default implementations are provided for primitives, and you can use the wrapper type [`Json`] to serialize using [`serde_json`]. 32 | /// 33 | /// This looks similar to [`serde::Deserialize`], but allows to plug-in non-serde serialization formats (e.g. like Protobuf using `prost`). 34 | pub trait Deserialize 35 | where 36 | Self: Sized, 37 | { 38 | type Error: std::error::Error + Send + Sync + 'static; 39 | 40 | fn deserialize(bytes: &mut Bytes) -> Result; 41 | } 42 | 43 | /// ## Payload metadata and Json Schemas 44 | /// 45 | /// The SDK propagates during discovery some metadata to restate-server service catalog. This includes: 46 | /// 47 | /// * The JSON schema of the payload. See below for more details. 48 | /// * The [InputMetadata] used to instruct restate how to accept requests. 49 | /// * The [OutputMetadata] used to instruct restate how to send responses out. 
50 | /// 51 | /// There are three approaches for generating JSON Schemas for handler inputs and outputs: 52 | /// 53 | /// ### 1. Primitive Types 54 | /// 55 | /// Primitive types (like `String`, `u32`, `bool`) have built-in schema implementations 56 | /// that work automatically without additional code: 57 | /// 58 | /// ```rust 59 | /// use restate_sdk::prelude::*; 60 | /// 61 | /// #[restate_sdk::service] 62 | /// trait SimpleService { 63 | /// async fn greet(name: String) -> HandlerResult; 64 | /// } 65 | /// ``` 66 | /// 67 | /// ### 2. Using `Json` with schemars 68 | /// 69 | /// For complex types wrapped in `Json`, you need to add the `schemars` feature and derive `JsonSchema`: 70 | /// 71 | /// ```rust 72 | /// use restate_sdk::prelude::*; 73 | /// 74 | /// #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] 75 | /// struct User { 76 | /// name: String, 77 | /// age: u32, 78 | /// } 79 | /// 80 | /// #[restate_sdk::service] 81 | /// trait UserService { 82 | /// async fn register(user: Json) -> HandlerResult>; 83 | /// } 84 | /// ``` 85 | /// 86 | /// To enable rich schema generation with `Json`, add the `schemars` feature to your dependency: 87 | /// 88 | /// ```toml 89 | /// [dependencies] 90 | /// restate-sdk = { version = "0.3", features = ["schemars"] } 91 | /// schemars = "1.0.0-alpha.17" 92 | /// ``` 93 | /// 94 | /// ### 3. 
Custom Implementation 95 | /// 96 | /// You can also implement the [PayloadMetadata] trait directly for your types to provide 97 | /// custom schemas without relying on the `schemars` feature: 98 | /// 99 | /// ```rust 100 | /// use restate_sdk::serde::{PayloadMetadata, Serialize, Deserialize}; 101 | /// 102 | /// #[derive(serde::Serialize, serde::Deserialize)] 103 | /// struct User { 104 | /// name: String, 105 | /// age: u32, 106 | /// } 107 | /// 108 | /// // Implement PayloadMetadata directly and override the json_schema implementation 109 | /// impl PayloadMetadata for User { 110 | /// fn json_schema() -> Option { 111 | /// Some(serde_json::json!({ 112 | /// "type": "object", 113 | /// "properties": { 114 | /// "name": {"type": "string"}, 115 | /// "age": {"type": "integer", "minimum": 0} 116 | /// }, 117 | /// "required": ["name", "age"] 118 | /// })) 119 | /// } 120 | /// } 121 | /// ``` 122 | /// 123 | /// Trait encapsulating JSON Schema information for the given serializer/deserializer. 124 | /// 125 | /// This trait allows types to provide JSON Schema information that can be used for 126 | /// documentation, validation, and client generation. 127 | /// 128 | /// ## Behavior with `schemars` Feature Flag 129 | /// 130 | /// When the `schemars` feature is enabled, implementations for complex types use 131 | /// the `schemars` crate to automatically generate rich, JSON Schema 2020-12 conforming schemas. 132 | /// When the feature is disabled, primitive types still provide basic schemas, 133 | /// but complex types return empty schemas, unless manually implemented. 134 | pub trait PayloadMetadata { 135 | /// Generate a JSON Schema for this type. 136 | /// 137 | /// Returns a JSON value representing the schema for this type. When the `schemars` 138 | /// feature is enabled, this returns an auto-generated JSON Schema 2020-12 conforming schema. 
When the feature is disabled, 139 | /// this returns an empty schema for complex types, but basic schemas for primitives. 140 | /// 141 | /// If returns none, no schema is provided. This should be used when the payload is not expected to be json 142 | fn json_schema() -> Option { 143 | Some(serde_json::Value::Object(serde_json::Map::default())) 144 | } 145 | 146 | /// Returns the [InputMetadata]. The default implementation returns metadata suitable for JSON payloads. 147 | fn input_metadata() -> InputMetadata { 148 | InputMetadata::default() 149 | } 150 | 151 | /// Returns the [OutputMetadata]. The default implementation returns metadata suitable for JSON payloads. 152 | fn output_metadata() -> OutputMetadata { 153 | OutputMetadata::default() 154 | } 155 | } 156 | 157 | /// This struct encapsulates input payload metadata used by discovery. 158 | /// 159 | /// The default implementation works well with Json payloads. 160 | pub struct InputMetadata { 161 | /// Content type of the input. It can accept wildcards, in the same format as the 'Accept' header. 162 | /// 163 | /// By default, is `application/json`. 164 | pub accept_content_type: &'static str, 165 | /// If true, Restate itself will reject requests **without content-types**. 166 | pub is_required: bool, 167 | } 168 | 169 | impl Default for InputMetadata { 170 | fn default() -> Self { 171 | Self { 172 | accept_content_type: APPLICATION_JSON, 173 | is_required: true, 174 | } 175 | } 176 | } 177 | 178 | /// This struct encapsulates output payload metadata used by discovery. 179 | /// 180 | /// The default implementation works for Json payloads. 181 | pub struct OutputMetadata { 182 | /// Content type of the output. 183 | /// 184 | /// By default, is `application/json`. 185 | pub content_type: &'static str, 186 | /// If true, the specified content-type is set even if the output is empty. This should be set to `true` only for encodings that can return a serialized empty byte array (e.g. Protobuf). 
187 | pub set_content_type_if_empty: bool, 188 | } 189 | 190 | impl Default for OutputMetadata { 191 | fn default() -> Self { 192 | Self { 193 | content_type: APPLICATION_JSON, 194 | set_content_type_if_empty: false, 195 | } 196 | } 197 | } 198 | 199 | // --- Default implementation for Unit type 200 | 201 | impl Serialize for () { 202 | type Error = Infallible; 203 | 204 | fn serialize(&self) -> Result { 205 | Ok(Bytes::new()) 206 | } 207 | } 208 | 209 | impl Deserialize for () { 210 | type Error = Infallible; 211 | 212 | fn deserialize(_: &mut Bytes) -> Result { 213 | Ok(()) 214 | } 215 | } 216 | 217 | // --- Passthrough implementation 218 | 219 | impl Serialize for Vec { 220 | type Error = Infallible; 221 | 222 | fn serialize(&self) -> Result { 223 | Ok(Bytes::copy_from_slice(self)) 224 | } 225 | } 226 | 227 | impl Deserialize for Vec { 228 | type Error = Infallible; 229 | 230 | fn deserialize(b: &mut Bytes) -> Result { 231 | Ok(b.to_vec()) 232 | } 233 | } 234 | 235 | impl PayloadMetadata for Vec { 236 | fn json_schema() -> Option { 237 | None 238 | } 239 | 240 | fn input_metadata() -> InputMetadata { 241 | InputMetadata { 242 | accept_content_type: "*/*", 243 | is_required: true, 244 | } 245 | } 246 | 247 | fn output_metadata() -> OutputMetadata { 248 | OutputMetadata { 249 | content_type: APPLICATION_OCTET_STREAM, 250 | set_content_type_if_empty: false, 251 | } 252 | } 253 | } 254 | 255 | impl Serialize for Bytes { 256 | type Error = Infallible; 257 | 258 | fn serialize(&self) -> Result { 259 | Ok(self.clone()) 260 | } 261 | } 262 | 263 | impl Deserialize for Bytes { 264 | type Error = Infallible; 265 | 266 | fn deserialize(b: &mut Bytes) -> Result { 267 | Ok(b.clone()) 268 | } 269 | } 270 | 271 | impl PayloadMetadata for Bytes { 272 | fn json_schema() -> Option { 273 | None 274 | } 275 | 276 | fn input_metadata() -> InputMetadata { 277 | InputMetadata { 278 | accept_content_type: "*/*", 279 | is_required: true, 280 | } 281 | } 282 | 283 | fn output_metadata() 
-> OutputMetadata { 284 | OutputMetadata { 285 | content_type: APPLICATION_OCTET_STREAM, 286 | set_content_type_if_empty: false, 287 | } 288 | } 289 | } 290 | // --- Option implementation 291 | 292 | impl Serialize for Option { 293 | type Error = T::Error; 294 | 295 | fn serialize(&self) -> Result { 296 | if self.is_none() { 297 | return Ok(Bytes::new()); 298 | } 299 | T::serialize(self.as_ref().unwrap()) 300 | } 301 | } 302 | 303 | impl Deserialize for Option { 304 | type Error = T::Error; 305 | 306 | fn deserialize(b: &mut Bytes) -> Result { 307 | if b.is_empty() { 308 | return Ok(None); 309 | } 310 | T::deserialize(b).map(Some) 311 | } 312 | } 313 | 314 | impl PayloadMetadata for Option { 315 | fn input_metadata() -> InputMetadata { 316 | InputMetadata { 317 | accept_content_type: T::input_metadata().accept_content_type, 318 | is_required: false, 319 | } 320 | } 321 | 322 | fn output_metadata() -> OutputMetadata { 323 | OutputMetadata { 324 | content_type: T::output_metadata().content_type, 325 | set_content_type_if_empty: false, 326 | } 327 | } 328 | } 329 | 330 | // --- Primitives 331 | 332 | macro_rules! 
impl_integer_primitives { 333 | ($ty:ty) => { 334 | impl Serialize for $ty { 335 | type Error = serde_json::Error; 336 | 337 | fn serialize(&self) -> Result { 338 | serde_json::to_vec(&self).map(Bytes::from) 339 | } 340 | } 341 | 342 | impl Deserialize for $ty { 343 | type Error = serde_json::Error; 344 | 345 | fn deserialize(bytes: &mut Bytes) -> Result { 346 | serde_json::from_slice(&bytes) 347 | } 348 | } 349 | 350 | impl PayloadMetadata for $ty { 351 | fn json_schema() -> Option { 352 | let min = <$ty>::MIN; 353 | let max = <$ty>::MAX; 354 | Some(serde_json::json!({ "type": "integer", "minimum": min, "maximum": max })) 355 | } 356 | } 357 | }; 358 | } 359 | 360 | impl_integer_primitives!(u8); 361 | impl_integer_primitives!(u16); 362 | impl_integer_primitives!(u32); 363 | impl_integer_primitives!(u64); 364 | impl_integer_primitives!(u128); 365 | impl_integer_primitives!(i8); 366 | impl_integer_primitives!(i16); 367 | impl_integer_primitives!(i32); 368 | impl_integer_primitives!(i64); 369 | impl_integer_primitives!(i128); 370 | 371 | macro_rules! 
impl_serde_primitives { 372 | ($ty:ty) => { 373 | impl Serialize for $ty { 374 | type Error = serde_json::Error; 375 | 376 | fn serialize(&self) -> Result { 377 | serde_json::to_vec(&self).map(Bytes::from) 378 | } 379 | } 380 | 381 | impl Deserialize for $ty { 382 | type Error = serde_json::Error; 383 | 384 | fn deserialize(bytes: &mut Bytes) -> Result { 385 | serde_json::from_slice(&bytes) 386 | } 387 | } 388 | }; 389 | } 390 | 391 | impl_serde_primitives!(String); 392 | impl_serde_primitives!(bool); 393 | impl_serde_primitives!(f32); 394 | impl_serde_primitives!(f64); 395 | 396 | impl PayloadMetadata for String { 397 | fn json_schema() -> Option { 398 | Some(serde_json::json!({ "type": "string" })) 399 | } 400 | } 401 | 402 | impl PayloadMetadata for bool { 403 | fn json_schema() -> Option { 404 | Some(serde_json::json!({ "type": "boolean" })) 405 | } 406 | } 407 | 408 | impl PayloadMetadata for f32 { 409 | fn json_schema() -> Option { 410 | Some(serde_json::json!({ "type": "number" })) 411 | } 412 | } 413 | 414 | impl PayloadMetadata for f64 { 415 | fn json_schema() -> Option { 416 | Some(serde_json::json!({ "type": "number" })) 417 | } 418 | } 419 | 420 | // --- Json wrapper 421 | 422 | /// Wrapper type to use [`serde_json`] with Restate's [`Serialize`]/[`Deserialize`] traits. 
423 | pub struct Json(pub T); 424 | 425 | impl Json { 426 | pub fn into_inner(self) -> T { 427 | self.0 428 | } 429 | } 430 | 431 | impl From for Json { 432 | fn from(value: T) -> Self { 433 | Self(value) 434 | } 435 | } 436 | 437 | impl Serialize for Json 438 | where 439 | T: serde::Serialize, 440 | { 441 | type Error = serde_json::Error; 442 | 443 | fn serialize(&self) -> Result { 444 | serde_json::to_vec(&self.0).map(Bytes::from) 445 | } 446 | } 447 | 448 | impl Deserialize for Json 449 | where 450 | for<'a> T: serde::Deserialize<'a>, 451 | { 452 | type Error = serde_json::Error; 453 | 454 | fn deserialize(bytes: &mut Bytes) -> Result { 455 | serde_json::from_slice(bytes).map(Json) 456 | } 457 | } 458 | 459 | impl Default for Json { 460 | fn default() -> Self { 461 | Self(T::default()) 462 | } 463 | } 464 | 465 | // When schemars is disabled - works with any T 466 | #[cfg(not(feature = "schemars"))] 467 | impl PayloadMetadata for Json { 468 | fn json_schema() -> Option { 469 | Some(serde_json::json!({})) 470 | } 471 | } 472 | 473 | // When schemars is enabled - requires T: JsonSchema 474 | #[cfg(feature = "schemars")] 475 | impl PayloadMetadata for Json { 476 | fn json_schema() -> Option { 477 | Some(schemars::schema_for!(T).to_value()) 478 | } 479 | } 480 | -------------------------------------------------------------------------------- /src/service.rs: -------------------------------------------------------------------------------- 1 | use crate::endpoint; 2 | use futures::future::BoxFuture; 3 | use std::future::Future; 4 | 5 | /// Trait representing a Restate service. 6 | /// 7 | /// This is used by codegen. 8 | pub trait Service { 9 | type Future: Future> + Send + 'static; 10 | 11 | /// Handle an incoming request. 12 | fn handle(&self, req: endpoint::ContextInternal) -> Self::Future; 13 | } 14 | 15 | /// Trait representing a discoverable Restate service. 16 | /// 17 | /// This is used by codegen. 
18 | pub trait Discoverable { 19 | fn discover() -> crate::discovery::Service; 20 | } 21 | 22 | /// Used by codegen 23 | #[doc(hidden)] 24 | pub type ServiceBoxFuture = BoxFuture<'static, Result<(), endpoint::Error>>; 25 | -------------------------------------------------------------------------------- /test-services/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "test-services" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [dependencies] 8 | anyhow = "1.0" 9 | bytes = "1.10.1" 10 | tokio = { version = "1", features = ["full"] } 11 | tracing-subscriber = "0.3" 12 | futures = "0.3" 13 | restate-sdk = { path = "..", features = ["schemars"] } 14 | schemars = "1.0.0-alpha.17" 15 | serde = { version = "1", features = ["derive"] } 16 | tracing = "0.1.40" 17 | -------------------------------------------------------------------------------- /test-services/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM rust:1.81 2 | 3 | WORKDIR /app 4 | 5 | COPY . . 6 | RUN cargo build -p test-services 7 | RUN cp ./target/debug/test-services /bin/server 8 | 9 | ENV RUST_LOG="debug,restate_shared_core=trace" 10 | ENV RUST_BACKTRACE=1 11 | 12 | CMD ["/bin/server"] -------------------------------------------------------------------------------- /test-services/README.md: -------------------------------------------------------------------------------- 1 | # Test services 2 | 3 | To build (from the repo root): 4 | 5 | ```shell 6 | $ podman build -f test-services/Dockerfile -t restatedev/rust-test-services . 
7 | ``` 8 | 9 | To run (download the [sdk-test-suite](https://github.com/restatedev/sdk-test-suite) first): 10 | 11 | ```shell 12 | $ java -jar restate-sdk-test-suite.jar run localhost/restatedev/rust-test-services:latest 13 | ``` -------------------------------------------------------------------------------- /test-services/exclusions.yaml: -------------------------------------------------------------------------------- 1 | exclusions: 2 | "alwaysSuspending": 3 | - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" 4 | - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" 5 | "default": 6 | - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" 7 | - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" 8 | "singleThreadSinglePartition": 9 | - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" 10 | - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" 11 | "threeNodes": 12 | - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" 13 | - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" 14 | "threeNodesAlwaysSuspending": 15 | - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" 16 | - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" 17 | -------------------------------------------------------------------------------- /test-services/src/awakeable_holder.rs: -------------------------------------------------------------------------------- 1 | use restate_sdk::prelude::*; 2 | 3 | #[restate_sdk::object] 4 | #[name = "AwakeableHolder"] 5 | pub(crate) trait AwakeableHolder { 6 | #[name = "hold"] 7 | async fn hold(id: String) -> HandlerResult<()>; 8 | #[name = "hasAwakeable"] 9 | #[shared] 10 | async fn has_awakeable() -> HandlerResult; 11 | #[name = "unlock"] 12 | async fn unlock(payload: String) -> HandlerResult<()>; 13 | } 14 | 15 
| pub(crate) struct AwakeableHolderImpl; 16 | 17 | const ID: &str = "id"; 18 | 19 | impl AwakeableHolder for AwakeableHolderImpl { 20 | async fn hold(&self, context: ObjectContext<'_>, id: String) -> HandlerResult<()> { 21 | context.set(ID, id); 22 | Ok(()) 23 | } 24 | 25 | async fn has_awakeable(&self, context: SharedObjectContext<'_>) -> HandlerResult { 26 | Ok(context.get::(ID).await?.is_some()) 27 | } 28 | 29 | async fn unlock(&self, context: ObjectContext<'_>, payload: String) -> HandlerResult<()> { 30 | let k: String = context.get(ID).await?.ok_or_else(|| { 31 | TerminalError::new(format!( 32 | "No awakeable stored for awakeable holder {}", 33 | context.key() 34 | )) 35 | })?; 36 | context.resolve_awakeable(&k, payload); 37 | Ok(()) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /test-services/src/block_and_wait_workflow.rs: -------------------------------------------------------------------------------- 1 | use restate_sdk::prelude::*; 2 | 3 | #[restate_sdk::workflow] 4 | #[name = "BlockAndWaitWorkflow"] 5 | pub(crate) trait BlockAndWaitWorkflow { 6 | #[name = "run"] 7 | async fn run(input: String) -> HandlerResult; 8 | #[name = "unblock"] 9 | #[shared] 10 | async fn unblock(output: String) -> HandlerResult<()>; 11 | #[name = "getState"] 12 | #[shared] 13 | async fn get_state() -> HandlerResult>>; 14 | } 15 | 16 | pub(crate) struct BlockAndWaitWorkflowImpl; 17 | 18 | const MY_PROMISE: &str = "my-promise"; 19 | const MY_STATE: &str = "my-state"; 20 | 21 | impl BlockAndWaitWorkflow for BlockAndWaitWorkflowImpl { 22 | async fn run(&self, context: WorkflowContext<'_>, input: String) -> HandlerResult { 23 | context.set(MY_STATE, input); 24 | 25 | let promise: String = context.promise(MY_PROMISE).await?; 26 | 27 | if context.peek_promise::(MY_PROMISE).await?.is_none() { 28 | return Err(TerminalError::new("Durable promise should be completed").into()); 29 | } 30 | 31 | Ok(promise) 32 | } 33 | 34 | async fn 
unblock( 35 | &self, 36 | context: SharedWorkflowContext<'_>, 37 | output: String, 38 | ) -> HandlerResult<()> { 39 | context.resolve_promise(MY_PROMISE, output); 40 | Ok(()) 41 | } 42 | 43 | async fn get_state( 44 | &self, 45 | context: SharedWorkflowContext<'_>, 46 | ) -> HandlerResult>> { 47 | Ok(Json(context.get::(MY_STATE).await?)) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /test-services/src/cancel_test.rs: -------------------------------------------------------------------------------- 1 | use crate::awakeable_holder; 2 | use anyhow::anyhow; 3 | use restate_sdk::prelude::*; 4 | use schemars::JsonSchema; 5 | use serde::{Deserialize, Serialize}; 6 | use std::time::Duration; 7 | 8 | #[derive(Serialize, Deserialize, JsonSchema)] 9 | #[serde(rename_all = "SCREAMING_SNAKE_CASE")] 10 | pub(crate) enum BlockingOperation { 11 | Call, 12 | Sleep, 13 | Awakeable, 14 | } 15 | 16 | #[restate_sdk::object] 17 | #[name = "CancelTestRunner"] 18 | pub(crate) trait CancelTestRunner { 19 | #[name = "startTest"] 20 | async fn start_test(op: Json) -> HandlerResult<()>; 21 | #[name = "verifyTest"] 22 | async fn verify_test() -> HandlerResult; 23 | } 24 | 25 | pub(crate) struct CancelTestRunnerImpl; 26 | 27 | const CANCELED: &str = "canceled"; 28 | 29 | impl CancelTestRunner for CancelTestRunnerImpl { 30 | async fn start_test( 31 | &self, 32 | context: ObjectContext<'_>, 33 | op: Json, 34 | ) -> HandlerResult<()> { 35 | let this = context.object_client::(context.key()); 36 | 37 | match this.block(op).call().await { 38 | Ok(_) => Err(anyhow!("Block succeeded, this is unexpected").into()), 39 | Err(e) if e.code() == 409 => { 40 | context.set(CANCELED, true); 41 | Ok(()) 42 | } 43 | Err(e) => Err(e.into()), 44 | } 45 | } 46 | 47 | async fn verify_test(&self, context: ObjectContext<'_>) -> HandlerResult { 48 | Ok(context.get::(CANCELED).await?.unwrap_or(false)) 49 | } 50 | } 51 | 52 | #[restate_sdk::object] 53 | #[name = 
"CancelTestBlockingService"] 54 | pub(crate) trait CancelTestBlockingService { 55 | #[name = "block"] 56 | async fn block(op: Json) -> HandlerResult<()>; 57 | #[name = "isUnlocked"] 58 | async fn is_unlocked() -> HandlerResult<()>; 59 | } 60 | 61 | pub(crate) struct CancelTestBlockingServiceImpl; 62 | 63 | impl CancelTestBlockingService for CancelTestBlockingServiceImpl { 64 | async fn block( 65 | &self, 66 | context: ObjectContext<'_>, 67 | op: Json, 68 | ) -> HandlerResult<()> { 69 | let this = context.object_client::(context.key()); 70 | let awakeable_holder_client = 71 | context.object_client::(context.key()); 72 | 73 | let (awk_id, awakeable) = context.awakeable::(); 74 | awakeable_holder_client.hold(awk_id).call().await?; 75 | awakeable.await?; 76 | 77 | match &op.0 { 78 | BlockingOperation::Call => { 79 | this.block(op).call().await?; 80 | } 81 | BlockingOperation::Sleep => { 82 | context 83 | .sleep(Duration::from_secs(60 * 60 * 24 * 1024)) 84 | .await?; 85 | } 86 | BlockingOperation::Awakeable => { 87 | let (_, uncompletable) = context.awakeable::(); 88 | uncompletable.await?; 89 | } 90 | } 91 | 92 | Ok(()) 93 | } 94 | 95 | async fn is_unlocked(&self, _: ObjectContext<'_>) -> HandlerResult<()> { 96 | // no-op 97 | Ok(()) 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /test-services/src/counter.rs: -------------------------------------------------------------------------------- 1 | use restate_sdk::prelude::*; 2 | use schemars::JsonSchema; 3 | use serde::{Deserialize, Serialize}; 4 | use tracing::info; 5 | 6 | #[derive(Serialize, Deserialize, JsonSchema)] 7 | #[serde(rename_all = "camelCase")] 8 | pub(crate) struct CounterUpdateResponse { 9 | old_value: u64, 10 | new_value: u64, 11 | } 12 | 13 | #[restate_sdk::object] 14 | #[name = "Counter"] 15 | pub(crate) trait Counter { 16 | #[name = "add"] 17 | async fn add(val: u64) -> HandlerResult>; 18 | #[name = "addThenFail"] 19 | async fn add_then_fail(val: 
u64) -> HandlerResult<()>; 20 | #[shared] 21 | #[name = "get"] 22 | async fn get() -> HandlerResult; 23 | #[name = "reset"] 24 | async fn reset() -> HandlerResult<()>; 25 | } 26 | 27 | pub(crate) struct CounterImpl; 28 | 29 | const COUNT: &str = "counter"; 30 | 31 | impl Counter for CounterImpl { 32 | async fn get(&self, ctx: SharedObjectContext<'_>) -> HandlerResult { 33 | Ok(ctx.get::(COUNT).await?.unwrap_or(0)) 34 | } 35 | 36 | async fn add( 37 | &self, 38 | ctx: ObjectContext<'_>, 39 | val: u64, 40 | ) -> HandlerResult> { 41 | let current = ctx.get::(COUNT).await?.unwrap_or(0); 42 | let new = current + val; 43 | ctx.set(COUNT, new); 44 | 45 | info!("Old count {}, new count {}", current, new); 46 | 47 | Ok(CounterUpdateResponse { 48 | old_value: current, 49 | new_value: new, 50 | } 51 | .into()) 52 | } 53 | 54 | async fn reset(&self, ctx: ObjectContext<'_>) -> HandlerResult<()> { 55 | ctx.clear(COUNT); 56 | Ok(()) 57 | } 58 | 59 | async fn add_then_fail(&self, ctx: ObjectContext<'_>, val: u64) -> HandlerResult<()> { 60 | let current = ctx.get::(COUNT).await?.unwrap_or(0); 61 | let new = current + val; 62 | ctx.set(COUNT, new); 63 | 64 | info!("Old count {}, new count {}", current, new); 65 | 66 | Err(TerminalError::new(ctx.key()).into()) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /test-services/src/failing.rs: -------------------------------------------------------------------------------- 1 | use anyhow::anyhow; 2 | use restate_sdk::prelude::*; 3 | use std::sync::atomic::{AtomicI32, Ordering}; 4 | use std::sync::Arc; 5 | use std::time::Duration; 6 | 7 | #[restate_sdk::object] 8 | #[name = "Failing"] 9 | pub(crate) trait Failing { 10 | #[name = "terminallyFailingCall"] 11 | async fn terminally_failing_call(error_message: String) -> HandlerResult<()>; 12 | #[name = "callTerminallyFailingCall"] 13 | async fn call_terminally_failing_call(error_message: String) -> HandlerResult; 14 | #[name = 
"failingCallWithEventualSuccess"] 15 | async fn failing_call_with_eventual_success() -> HandlerResult; 16 | #[name = "terminallyFailingSideEffect"] 17 | async fn terminally_failing_side_effect(error_message: String) -> HandlerResult<()>; 18 | #[name = "sideEffectSucceedsAfterGivenAttempts"] 19 | async fn side_effect_succeeds_after_given_attempts(minimum_attempts: i32) 20 | -> HandlerResult; 21 | #[name = "sideEffectFailsAfterGivenAttempts"] 22 | async fn side_effect_fails_after_given_attempts( 23 | retry_policy_max_retry_count: i32, 24 | ) -> HandlerResult; 25 | } 26 | 27 | #[derive(Clone, Default)] 28 | pub(crate) struct FailingImpl { 29 | eventual_success_calls: Arc, 30 | eventual_success_side_effects: Arc, 31 | eventual_failure_side_effects: Arc, 32 | } 33 | 34 | impl Failing for FailingImpl { 35 | async fn terminally_failing_call( 36 | &self, 37 | _: ObjectContext<'_>, 38 | error_message: String, 39 | ) -> HandlerResult<()> { 40 | Err(TerminalError::new(error_message).into()) 41 | } 42 | 43 | async fn call_terminally_failing_call( 44 | &self, 45 | mut context: ObjectContext<'_>, 46 | error_message: String, 47 | ) -> HandlerResult { 48 | let uuid = context.rand_uuid().to_string(); 49 | context 50 | .object_client::(uuid) 51 | .terminally_failing_call(error_message) 52 | .call() 53 | .await?; 54 | 55 | unreachable!("This should be unreachable") 56 | } 57 | 58 | async fn failing_call_with_eventual_success(&self, _: ObjectContext<'_>) -> HandlerResult { 59 | let current_attempt = self.eventual_success_calls.fetch_add(1, Ordering::SeqCst) + 1; 60 | 61 | if current_attempt >= 4 { 62 | self.eventual_success_calls.store(0, Ordering::SeqCst); 63 | Ok(current_attempt) 64 | } else { 65 | Err(anyhow!("Failed at attempt ${current_attempt}").into()) 66 | } 67 | } 68 | 69 | async fn terminally_failing_side_effect( 70 | &self, 71 | context: ObjectContext<'_>, 72 | error_message: String, 73 | ) -> HandlerResult<()> { 74 | context 75 | .run::<_, _, ()>(|| async move { 
Err(TerminalError::new(error_message))? }) 76 | .await?; 77 | 78 | unreachable!("This should be unreachable") 79 | } 80 | 81 | async fn side_effect_succeeds_after_given_attempts( 82 | &self, 83 | context: ObjectContext<'_>, 84 | minimum_attempts: i32, 85 | ) -> HandlerResult { 86 | let cloned_counter = Arc::clone(&self.eventual_success_side_effects); 87 | let success_attempt = context 88 | .run(|| async move { 89 | let current_attempt = cloned_counter.fetch_add(1, Ordering::SeqCst) + 1; 90 | 91 | if current_attempt >= minimum_attempts { 92 | cloned_counter.store(0, Ordering::SeqCst); 93 | Ok(current_attempt) 94 | } else { 95 | Err(anyhow!("Failed at attempt {current_attempt}"))? 96 | } 97 | }) 98 | .retry_policy( 99 | RunRetryPolicy::new() 100 | .initial_delay(Duration::from_millis(10)) 101 | .exponentiation_factor(1.0), 102 | ) 103 | .name("failing_side_effect") 104 | .await?; 105 | 106 | Ok(success_attempt) 107 | } 108 | 109 | async fn side_effect_fails_after_given_attempts( 110 | &self, 111 | context: ObjectContext<'_>, 112 | retry_policy_max_retry_count: i32, 113 | ) -> HandlerResult { 114 | let cloned_counter = Arc::clone(&self.eventual_failure_side_effects); 115 | if context 116 | .run(|| async move { 117 | let current_attempt = cloned_counter.fetch_add(1, Ordering::SeqCst) + 1; 118 | Err::<(), _>(anyhow!("Failed at attempt {current_attempt}").into()) 119 | }) 120 | .retry_policy( 121 | RunRetryPolicy::new() 122 | .initial_delay(Duration::from_millis(10)) 123 | .exponentiation_factor(1.0) 124 | .max_attempts(retry_policy_max_retry_count as u32), 125 | ) 126 | .await 127 | .is_err() 128 | { 129 | Ok(self.eventual_failure_side_effects.load(Ordering::SeqCst)) 130 | } else { 131 | Err(TerminalError::new("Expecting the side effect to fail!"))? 
132 | } 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /test-services/src/kill_test.rs: -------------------------------------------------------------------------------- 1 | use crate::awakeable_holder; 2 | use restate_sdk::prelude::*; 3 | 4 | #[restate_sdk::object] 5 | #[name = "KillTestRunner"] 6 | pub(crate) trait KillTestRunner { 7 | #[name = "startCallTree"] 8 | async fn start_call_tree() -> HandlerResult<()>; 9 | } 10 | 11 | pub(crate) struct KillTestRunnerImpl; 12 | 13 | impl KillTestRunner for KillTestRunnerImpl { 14 | async fn start_call_tree(&self, context: ObjectContext<'_>) -> HandlerResult<()> { 15 | context 16 | .object_client::(context.key()) 17 | .recursive_call() 18 | .call() 19 | .await?; 20 | Ok(()) 21 | } 22 | } 23 | 24 | #[restate_sdk::object] 25 | #[name = "KillTestSingleton"] 26 | pub(crate) trait KillTestSingleton { 27 | #[name = "recursiveCall"] 28 | async fn recursive_call() -> HandlerResult<()>; 29 | #[name = "isUnlocked"] 30 | async fn is_unlocked() -> HandlerResult<()>; 31 | } 32 | 33 | pub(crate) struct KillTestSingletonImpl; 34 | 35 | impl KillTestSingleton for KillTestSingletonImpl { 36 | async fn recursive_call(&self, context: ObjectContext<'_>) -> HandlerResult<()> { 37 | let awakeable_holder_client = 38 | context.object_client::(context.key()); 39 | 40 | let (awk_id, awakeable) = context.awakeable::<()>(); 41 | awakeable_holder_client.hold(awk_id).send(); 42 | awakeable.await?; 43 | 44 | context 45 | .object_client::(context.key()) 46 | .recursive_call() 47 | .call() 48 | .await?; 49 | 50 | Ok(()) 51 | } 52 | 53 | async fn is_unlocked(&self, _: ObjectContext<'_>) -> HandlerResult<()> { 54 | // no-op 55 | Ok(()) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /test-services/src/list_object.rs: -------------------------------------------------------------------------------- 1 | use restate_sdk::prelude::*; 2 | 3 | 
#[restate_sdk::object] 4 | #[name = "ListObject"] 5 | pub(crate) trait ListObject { 6 | #[name = "append"] 7 | async fn append(value: String) -> HandlerResult<()>; 8 | #[name = "get"] 9 | async fn get() -> HandlerResult>>; 10 | #[name = "clear"] 11 | async fn clear() -> HandlerResult>>; 12 | } 13 | 14 | pub(crate) struct ListObjectImpl; 15 | 16 | const LIST: &str = "list"; 17 | 18 | impl ListObject for ListObjectImpl { 19 | async fn append(&self, ctx: ObjectContext<'_>, value: String) -> HandlerResult<()> { 20 | let mut list = ctx 21 | .get::>>(LIST) 22 | .await? 23 | .unwrap_or_default() 24 | .into_inner(); 25 | list.push(value); 26 | ctx.set(LIST, Json(list)); 27 | Ok(()) 28 | } 29 | 30 | async fn get(&self, ctx: ObjectContext<'_>) -> HandlerResult>> { 31 | Ok(ctx 32 | .get::>>(LIST) 33 | .await? 34 | .unwrap_or_default()) 35 | } 36 | 37 | async fn clear(&self, ctx: ObjectContext<'_>) -> HandlerResult>> { 38 | let get = ctx 39 | .get::>>(LIST) 40 | .await? 41 | .unwrap_or_default(); 42 | ctx.clear(LIST); 43 | Ok(get) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /test-services/src/main.rs: -------------------------------------------------------------------------------- 1 | mod awakeable_holder; 2 | mod block_and_wait_workflow; 3 | mod cancel_test; 4 | mod counter; 5 | mod failing; 6 | mod kill_test; 7 | mod list_object; 8 | mod map_object; 9 | mod non_deterministic; 10 | mod proxy; 11 | mod test_utils_service; 12 | mod virtual_object_command_interpreter; 13 | 14 | use restate_sdk::prelude::{Endpoint, HttpServer}; 15 | use std::env; 16 | 17 | #[tokio::main] 18 | async fn main() { 19 | tracing_subscriber::fmt::init(); 20 | let port = env::var("PORT").ok().unwrap_or("9080".to_string()); 21 | let services = env::var("SERVICES").ok().unwrap_or("*".to_string()); 22 | 23 | let mut builder = Endpoint::builder(); 24 | 25 | if services == "*" || services.contains("Counter") { 26 | builder = 
builder.bind(counter::Counter::serve(counter::CounterImpl)) 27 | } 28 | if services == "*" || services.contains("Proxy") { 29 | builder = builder.bind(proxy::Proxy::serve(proxy::ProxyImpl)) 30 | } 31 | if services == "*" || services.contains("MapObject") { 32 | builder = builder.bind(map_object::MapObject::serve(map_object::MapObjectImpl)) 33 | } 34 | if services == "*" || services.contains("ListObject") { 35 | builder = builder.bind(list_object::ListObject::serve(list_object::ListObjectImpl)) 36 | } 37 | if services == "*" || services.contains("AwakeableHolder") { 38 | builder = builder.bind(awakeable_holder::AwakeableHolder::serve( 39 | awakeable_holder::AwakeableHolderImpl, 40 | )) 41 | } 42 | if services == "*" || services.contains("BlockAndWaitWorkflow") { 43 | builder = builder.bind(block_and_wait_workflow::BlockAndWaitWorkflow::serve( 44 | block_and_wait_workflow::BlockAndWaitWorkflowImpl, 45 | )) 46 | } 47 | if services == "*" || services.contains("CancelTestRunner") { 48 | builder = builder.bind(cancel_test::CancelTestRunner::serve( 49 | cancel_test::CancelTestRunnerImpl, 50 | )) 51 | } 52 | if services == "*" || services.contains("CancelTestBlockingService") { 53 | builder = builder.bind(cancel_test::CancelTestBlockingService::serve( 54 | cancel_test::CancelTestBlockingServiceImpl, 55 | )) 56 | } 57 | if services == "*" || services.contains("Failing") { 58 | builder = builder.bind(failing::Failing::serve(failing::FailingImpl::default())) 59 | } 60 | if services == "*" || services.contains("KillTestRunner") { 61 | builder = builder.bind(kill_test::KillTestRunner::serve( 62 | kill_test::KillTestRunnerImpl, 63 | )) 64 | } 65 | if services == "*" || services.contains("KillTestSingleton") { 66 | builder = builder.bind(kill_test::KillTestSingleton::serve( 67 | kill_test::KillTestSingletonImpl, 68 | )) 69 | } 70 | if services == "*" || services.contains("NonDeterministic") { 71 | builder = builder.bind(non_deterministic::NonDeterministic::serve( 72 | 
non_deterministic::NonDeterministicImpl::default(), 73 | )) 74 | } 75 | if services == "*" || services.contains("TestUtilsService") { 76 | builder = builder.bind(test_utils_service::TestUtilsService::serve( 77 | test_utils_service::TestUtilsServiceImpl, 78 | )) 79 | } 80 | if services == "*" || services.contains("VirtualObjectCommandInterpreter") { 81 | builder = builder.bind( 82 | virtual_object_command_interpreter::VirtualObjectCommandInterpreter::serve( 83 | virtual_object_command_interpreter::VirtualObjectCommandInterpreterImpl, 84 | ), 85 | ) 86 | } 87 | 88 | if let Ok(key) = env::var("E2E_REQUEST_SIGNING_ENV") { 89 | builder = builder.identity_key(&key).unwrap() 90 | } 91 | 92 | HttpServer::new(builder.build()) 93 | .listen_and_serve(format!("0.0.0.0:{port}").parse().unwrap()) 94 | .await; 95 | } 96 | -------------------------------------------------------------------------------- /test-services/src/map_object.rs: -------------------------------------------------------------------------------- 1 | use anyhow::anyhow; 2 | use restate_sdk::prelude::*; 3 | use schemars::JsonSchema; 4 | use serde::{Deserialize, Serialize}; 5 | 6 | #[derive(Serialize, Deserialize, JsonSchema)] 7 | #[serde(rename_all = "camelCase")] 8 | pub(crate) struct Entry { 9 | key: String, 10 | value: String, 11 | } 12 | 13 | #[restate_sdk::object] 14 | #[name = "MapObject"] 15 | pub(crate) trait MapObject { 16 | #[name = "set"] 17 | async fn set(entry: Json) -> HandlerResult<()>; 18 | #[name = "get"] 19 | async fn get(key: String) -> HandlerResult; 20 | #[name = "clearAll"] 21 | async fn clear_all() -> HandlerResult>>; 22 | } 23 | 24 | pub(crate) struct MapObjectImpl; 25 | 26 | impl MapObject for MapObjectImpl { 27 | async fn set( 28 | &self, 29 | ctx: ObjectContext<'_>, 30 | Json(Entry { key, value }): Json, 31 | ) -> HandlerResult<()> { 32 | ctx.set(&key, value); 33 | Ok(()) 34 | } 35 | 36 | async fn get(&self, ctx: ObjectContext<'_>, key: String) -> HandlerResult { 37 | 
Ok(ctx.get(&key).await?.unwrap_or_default()) 38 | } 39 | 40 | async fn clear_all(&self, ctx: ObjectContext<'_>) -> HandlerResult>> { 41 | let keys = ctx.get_keys().await?; 42 | 43 | let mut entries = vec![]; 44 | for k in keys { 45 | let value = ctx 46 | .get(&k) 47 | .await? 48 | .ok_or_else(|| anyhow!("Missing key {k}"))?; 49 | entries.push(Entry { key: k, value }) 50 | } 51 | 52 | ctx.clear_all(); 53 | 54 | Ok(entries.into()) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /test-services/src/non_deterministic.rs: -------------------------------------------------------------------------------- 1 | use crate::counter::CounterClient; 2 | use restate_sdk::prelude::*; 3 | use std::collections::HashMap; 4 | use std::sync::Arc; 5 | use std::time::Duration; 6 | use tokio::sync::Mutex; 7 | 8 | #[restate_sdk::object] 9 | #[name = "NonDeterministic"] 10 | pub(crate) trait NonDeterministic { 11 | #[name = "eitherSleepOrCall"] 12 | async fn either_sleep_or_call() -> HandlerResult<()>; 13 | #[name = "callDifferentMethod"] 14 | async fn call_different_method() -> HandlerResult<()>; 15 | #[name = "backgroundInvokeWithDifferentTargets"] 16 | async fn background_invoke_with_different_targets() -> HandlerResult<()>; 17 | #[name = "setDifferentKey"] 18 | async fn set_different_key() -> HandlerResult<()>; 19 | } 20 | 21 | #[derive(Clone, Default)] 22 | pub(crate) struct NonDeterministicImpl(Arc>>); 23 | 24 | const STATE_A: &str = "a"; 25 | const STATE_B: &str = "b"; 26 | 27 | impl NonDeterministic for NonDeterministicImpl { 28 | async fn either_sleep_or_call(&self, context: ObjectContext<'_>) -> HandlerResult<()> { 29 | if self.do_left_action(&context).await { 30 | context.sleep(Duration::from_millis(100)).await?; 31 | } else { 32 | context 33 | .object_client::("abc") 34 | .get() 35 | .call() 36 | .await?; 37 | } 38 | Self::sleep_then_increment_counter(&context).await 39 | } 40 | 41 | async fn call_different_method(&self, context: 
ObjectContext<'_>) -> HandlerResult<()> { 42 | if self.do_left_action(&context).await { 43 | context 44 | .object_client::("abc") 45 | .get() 46 | .call() 47 | .await?; 48 | } else { 49 | context 50 | .object_client::("abc") 51 | .reset() 52 | .call() 53 | .await?; 54 | } 55 | Self::sleep_then_increment_counter(&context).await 56 | } 57 | 58 | async fn background_invoke_with_different_targets( 59 | &self, 60 | context: ObjectContext<'_>, 61 | ) -> HandlerResult<()> { 62 | if self.do_left_action(&context).await { 63 | context.object_client::("abc").get().send(); 64 | } else { 65 | context.object_client::("abc").reset().send(); 66 | } 67 | Self::sleep_then_increment_counter(&context).await 68 | } 69 | 70 | async fn set_different_key(&self, context: ObjectContext<'_>) -> HandlerResult<()> { 71 | if self.do_left_action(&context).await { 72 | context.set(STATE_A, "my-state".to_owned()); 73 | } else { 74 | context.set(STATE_B, "my-state".to_owned()); 75 | } 76 | Self::sleep_then_increment_counter(&context).await 77 | } 78 | } 79 | 80 | impl NonDeterministicImpl { 81 | async fn do_left_action(&self, ctx: &ObjectContext<'_>) -> bool { 82 | let mut counts = self.0.lock().await; 83 | *(counts 84 | .entry(ctx.key().to_owned()) 85 | .and_modify(|i| *i += 1) 86 | .or_default()) 87 | % 2 88 | == 1 89 | } 90 | 91 | async fn sleep_then_increment_counter(ctx: &ObjectContext<'_>) -> HandlerResult<()> { 92 | ctx.sleep(Duration::from_millis(100)).await?; 93 | ctx.object_client::(ctx.key()).add(1).send(); 94 | Ok(()) 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /test-services/src/proxy.rs: -------------------------------------------------------------------------------- 1 | use futures::future::BoxFuture; 2 | use futures::FutureExt; 3 | use restate_sdk::context::RequestTarget; 4 | use restate_sdk::prelude::*; 5 | use schemars::JsonSchema; 6 | use serde::{Deserialize, Serialize}; 7 | use std::time::Duration; 8 | 9 | 
#[derive(Serialize, Deserialize, JsonSchema)] 10 | #[serde(rename_all = "camelCase")] 11 | pub(crate) struct ProxyRequest { 12 | service_name: String, 13 | virtual_object_key: Option, 14 | handler_name: String, 15 | idempotency_key: Option, 16 | message: Vec, 17 | delay_millis: Option, 18 | } 19 | 20 | impl ProxyRequest { 21 | fn to_target(&self) -> RequestTarget { 22 | if let Some(key) = &self.virtual_object_key { 23 | RequestTarget::Object { 24 | name: self.service_name.clone(), 25 | key: key.clone(), 26 | handler: self.handler_name.clone(), 27 | } 28 | } else { 29 | RequestTarget::Service { 30 | name: self.service_name.clone(), 31 | handler: self.handler_name.clone(), 32 | } 33 | } 34 | } 35 | } 36 | 37 | #[derive(Serialize, Deserialize, JsonSchema)] 38 | #[serde(rename_all = "camelCase")] 39 | pub(crate) struct ManyCallRequest { 40 | proxy_request: ProxyRequest, 41 | one_way_call: bool, 42 | await_at_the_end: bool, 43 | } 44 | 45 | #[restate_sdk::service] 46 | #[name = "Proxy"] 47 | pub(crate) trait Proxy { 48 | #[name = "call"] 49 | async fn call(req: Json) -> HandlerResult>>; 50 | #[name = "oneWayCall"] 51 | async fn one_way_call(req: Json) -> HandlerResult; 52 | #[name = "manyCalls"] 53 | async fn many_calls(req: Json>) -> HandlerResult<()>; 54 | } 55 | 56 | pub(crate) struct ProxyImpl; 57 | 58 | impl Proxy for ProxyImpl { 59 | async fn call( 60 | &self, 61 | ctx: Context<'_>, 62 | Json(req): Json, 63 | ) -> HandlerResult>> { 64 | let mut request = ctx.request::, Vec>(req.to_target(), req.message); 65 | if let Some(idempotency_key) = req.idempotency_key { 66 | request = request.idempotency_key(idempotency_key); 67 | } 68 | Ok(request.call().await?.into()) 69 | } 70 | 71 | async fn one_way_call( 72 | &self, 73 | ctx: Context<'_>, 74 | Json(req): Json, 75 | ) -> HandlerResult { 76 | let mut request = ctx.request::<_, ()>(req.to_target(), req.message); 77 | if let Some(idempotency_key) = req.idempotency_key { 78 | request = 
request.idempotency_key(idempotency_key); 79 | } 80 | 81 | let invocation_id = if let Some(delay_millis) = req.delay_millis { 82 | request 83 | .send_after(Duration::from_millis(delay_millis)) 84 | .invocation_id() 85 | .await? 86 | } else { 87 | request.send().invocation_id().await? 88 | }; 89 | 90 | Ok(invocation_id) 91 | } 92 | 93 | async fn many_calls( 94 | &self, 95 | ctx: Context<'_>, 96 | Json(requests): Json>, 97 | ) -> HandlerResult<()> { 98 | let mut futures: Vec, TerminalError>>> = vec![]; 99 | 100 | for req in requests { 101 | let mut restate_req = 102 | ctx.request::<_, Vec>(req.proxy_request.to_target(), req.proxy_request.message); 103 | if let Some(idempotency_key) = req.proxy_request.idempotency_key { 104 | restate_req = restate_req.idempotency_key(idempotency_key); 105 | } 106 | if req.one_way_call { 107 | if let Some(delay_millis) = req.proxy_request.delay_millis { 108 | restate_req.send_after(Duration::from_millis(delay_millis)); 109 | } else { 110 | restate_req.send(); 111 | } 112 | } else { 113 | let fut = restate_req.call(); 114 | if req.await_at_the_end { 115 | futures.push(fut.boxed()) 116 | } 117 | } 118 | } 119 | 120 | for fut in futures { 121 | fut.await?; 122 | } 123 | 124 | Ok(()) 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /test-services/src/test_utils_service.rs: -------------------------------------------------------------------------------- 1 | use futures::future::BoxFuture; 2 | use futures::FutureExt; 3 | use restate_sdk::prelude::*; 4 | use std::collections::HashMap; 5 | use std::convert::Infallible; 6 | use std::sync::atomic::{AtomicU8, Ordering}; 7 | use std::sync::Arc; 8 | use std::time::Duration; 9 | 10 | #[restate_sdk::service] 11 | #[name = "TestUtilsService"] 12 | pub(crate) trait TestUtilsService { 13 | #[name = "echo"] 14 | async fn echo(input: String) -> HandlerResult; 15 | #[name = "uppercaseEcho"] 16 | async fn uppercase_echo(input: String) -> HandlerResult; 
17 | #[name = "rawEcho"] 18 | async fn raw_echo(input: bytes::Bytes) -> Result, Infallible>; 19 | #[name = "echoHeaders"] 20 | async fn echo_headers() -> HandlerResult>>; 21 | #[name = "sleepConcurrently"] 22 | async fn sleep_concurrently(millis_durations: Json>) -> HandlerResult<()>; 23 | #[name = "countExecutedSideEffects"] 24 | async fn count_executed_side_effects(increments: u32) -> HandlerResult; 25 | #[name = "cancelInvocation"] 26 | async fn cancel_invocation(invocation_id: String) -> Result<(), TerminalError>; 27 | } 28 | 29 | pub(crate) struct TestUtilsServiceImpl; 30 | 31 | impl TestUtilsService for TestUtilsServiceImpl { 32 | async fn echo(&self, _: Context<'_>, input: String) -> HandlerResult { 33 | Ok(input) 34 | } 35 | 36 | async fn uppercase_echo(&self, _: Context<'_>, input: String) -> HandlerResult { 37 | Ok(input.to_ascii_uppercase()) 38 | } 39 | 40 | async fn raw_echo(&self, _: Context<'_>, input: bytes::Bytes) -> Result, Infallible> { 41 | Ok(input.to_vec()) 42 | } 43 | 44 | async fn echo_headers( 45 | &self, 46 | context: Context<'_>, 47 | ) -> HandlerResult>> { 48 | let mut headers = HashMap::new(); 49 | for k in context.headers().keys() { 50 | headers.insert( 51 | k.as_str().to_owned(), 52 | context.headers().get(k).unwrap().clone(), 53 | ); 54 | } 55 | 56 | Ok(headers.into()) 57 | } 58 | 59 | async fn sleep_concurrently( 60 | &self, 61 | context: Context<'_>, 62 | millis_durations: Json>, 63 | ) -> HandlerResult<()> { 64 | let mut futures: Vec>> = vec![]; 65 | 66 | for duration in millis_durations.into_inner() { 67 | futures.push(context.sleep(Duration::from_millis(duration)).boxed()); 68 | } 69 | 70 | for fut in futures { 71 | fut.await?; 72 | } 73 | 74 | Ok(()) 75 | } 76 | 77 | async fn count_executed_side_effects( 78 | &self, 79 | context: Context<'_>, 80 | increments: u32, 81 | ) -> HandlerResult { 82 | let counter: Arc = Default::default(); 83 | 84 | for _ in 0..increments { 85 | let counter_clone = Arc::clone(&counter); 86 | context 87 
| .run(|| async { 88 | counter_clone.fetch_add(1, Ordering::SeqCst); 89 | Ok(()) 90 | }) 91 | .await?; 92 | } 93 | 94 | Ok(counter.load(Ordering::SeqCst) as u32) 95 | } 96 | 97 | async fn cancel_invocation( 98 | &self, 99 | ctx: Context<'_>, 100 | invocation_id: String, 101 | ) -> Result<(), TerminalError> { 102 | ctx.invocation_handle(invocation_id).cancel().await?; 103 | Ok(()) 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /test-services/src/virtual_object_command_interpreter.rs: -------------------------------------------------------------------------------- 1 | use anyhow::anyhow; 2 | use futures::TryFutureExt; 3 | use restate_sdk::prelude::*; 4 | use schemars::JsonSchema; 5 | use serde::{Deserialize, Serialize}; 6 | use std::time::Duration; 7 | 8 | #[derive(Serialize, Deserialize, JsonSchema)] 9 | #[serde(rename_all = "camelCase")] 10 | pub(crate) struct InterpretRequest { 11 | commands: Vec, 12 | } 13 | 14 | #[derive(Serialize, Deserialize, JsonSchema)] 15 | #[serde(tag = "type")] 16 | #[serde(rename_all_fields = "camelCase")] 17 | pub(crate) enum Command { 18 | #[serde(rename = "awaitAnySuccessful")] 19 | AwaitAnySuccessful { commands: Vec }, 20 | #[serde(rename = "awaitAny")] 21 | AwaitAny { commands: Vec }, 22 | #[serde(rename = "awaitOne")] 23 | AwaitOne { command: AwaitableCommand }, 24 | #[serde(rename = "awaitAwakeableOrTimeout")] 25 | AwaitAwakeableOrTimeout { 26 | awakeable_key: String, 27 | timeout_millis: u64, 28 | }, 29 | #[serde(rename = "resolveAwakeable")] 30 | ResolveAwakeable { 31 | awakeable_key: String, 32 | value: String, 33 | }, 34 | #[serde(rename = "rejectAwakeable")] 35 | RejectAwakeable { 36 | awakeable_key: String, 37 | reason: String, 38 | }, 39 | #[serde(rename = "getEnvVariable")] 40 | GetEnvVariable { env_name: String }, 41 | } 42 | 43 | #[derive(Serialize, Deserialize, JsonSchema)] 44 | #[serde(tag = "type")] 45 | #[serde(rename_all_fields = "camelCase")] 46 | pub(crate) 
enum AwaitableCommand { 47 | #[serde(rename = "createAwakeable")] 48 | CreateAwakeable { awakeable_key: String }, 49 | #[serde(rename = "sleep")] 50 | Sleep { timeout_millis: u64 }, 51 | #[serde(rename = "runThrowTerminalException")] 52 | RunThrowTerminalException { reason: String }, 53 | } 54 | 55 | #[derive(Serialize, Deserialize, JsonSchema)] 56 | #[serde(rename_all = "camelCase")] 57 | pub(crate) struct ResolveAwakeable { 58 | awakeable_key: String, 59 | value: String, 60 | } 61 | 62 | #[derive(Serialize, Deserialize, JsonSchema)] 63 | #[serde(rename_all = "camelCase")] 64 | pub(crate) struct RejectAwakeable { 65 | awakeable_key: String, 66 | reason: String, 67 | } 68 | 69 | #[restate_sdk::object] 70 | #[name = "VirtualObjectCommandInterpreter"] 71 | pub(crate) trait VirtualObjectCommandInterpreter { 72 | #[name = "interpretCommands"] 73 | async fn interpret_commands(req: Json) -> HandlerResult; 74 | 75 | #[name = "resolveAwakeable"] 76 | #[shared] 77 | async fn resolve_awakeable(req: Json) -> HandlerResult<()>; 78 | 79 | #[name = "rejectAwakeable"] 80 | #[shared] 81 | async fn reject_awakeable(req: Json) -> HandlerResult<()>; 82 | 83 | #[name = "hasAwakeable"] 84 | #[shared] 85 | async fn has_awakeable(awakeable_key: String) -> HandlerResult; 86 | 87 | #[name = "getResults"] 88 | #[shared] 89 | async fn get_results() -> HandlerResult>>; 90 | } 91 | 92 | pub(crate) struct VirtualObjectCommandInterpreterImpl; 93 | 94 | impl VirtualObjectCommandInterpreter for VirtualObjectCommandInterpreterImpl { 95 | async fn interpret_commands( 96 | &self, 97 | context: ObjectContext<'_>, 98 | Json(req): Json, 99 | ) -> HandlerResult { 100 | let mut last_result: String = Default::default(); 101 | 102 | for cmd in req.commands { 103 | match cmd { 104 | Command::AwaitAny { .. } => { 105 | Err(anyhow!("AwaitAny is currently unsupported in the Rust SDK"))? 106 | } 107 | Command::AwaitAnySuccessful { .. 
} => Err(anyhow!( 108 | "AwaitAnySuccessful is currently unsupported in the Rust SDK" 109 | ))?, 110 | Command::AwaitAwakeableOrTimeout { 111 | awakeable_key, 112 | timeout_millis, 113 | } => { 114 | let (awakeable_id, awk_fut) = context.awakeable::(); 115 | context.set::(&format!("awk-{awakeable_key}"), awakeable_id); 116 | 117 | last_result = restate_sdk::select! { 118 | res = awk_fut => { 119 | res 120 | }, 121 | _ = context.sleep(Duration::from_millis(timeout_millis)) => { 122 | Err(TerminalError::new("await-timeout")) 123 | } 124 | }?; 125 | } 126 | Command::AwaitOne { command } => { 127 | last_result = match command { 128 | AwaitableCommand::CreateAwakeable { awakeable_key } => { 129 | let (awakeable_id, fut) = context.awakeable::(); 130 | context.set::(&format!("awk-{awakeable_key}"), awakeable_id); 131 | fut.await? 132 | } 133 | AwaitableCommand::Sleep { timeout_millis } => { 134 | context 135 | .sleep(Duration::from_millis(timeout_millis)) 136 | .map_ok(|_| "sleep".to_string()) 137 | .await? 138 | } 139 | AwaitableCommand::RunThrowTerminalException { reason } => { 140 | context 141 | .run::<_, _, String>( 142 | || async move { Err(TerminalError::new(reason))? }, 143 | ) 144 | .await? 145 | } 146 | } 147 | } 148 | Command::GetEnvVariable { env_name } => { 149 | last_result = std::env::var(env_name).ok().unwrap_or_default(); 150 | } 151 | Command::ResolveAwakeable { 152 | awakeable_key, 153 | value, 154 | } => { 155 | let Some(awakeable_id) = context 156 | .get::(&format!("awk-{awakeable_key}")) 157 | .await? 158 | else { 159 | Err(TerminalError::new( 160 | "Awakeable is not registered yet".to_string(), 161 | ))? 162 | }; 163 | 164 | context.resolve_awakeable(&awakeable_id, value); 165 | last_result = Default::default(); 166 | } 167 | Command::RejectAwakeable { 168 | awakeable_key, 169 | reason, 170 | } => { 171 | let Some(awakeable_id) = context 172 | .get::(&format!("awk-{awakeable_key}")) 173 | .await? 
174 | else { 175 | Err(TerminalError::new( 176 | "Awakeable is not registered yet".to_string(), 177 | ))? 178 | }; 179 | 180 | context.reject_awakeable(&awakeable_id, TerminalError::new(reason)); 181 | last_result = Default::default(); 182 | } 183 | } 184 | 185 | let mut old_results = context 186 | .get::>>("results") 187 | .await? 188 | .unwrap_or_default() 189 | .into_inner(); 190 | old_results.push(last_result.clone()); 191 | context.set("results", Json(old_results)); 192 | } 193 | 194 | Ok(last_result) 195 | } 196 | 197 | async fn resolve_awakeable( 198 | &self, 199 | context: SharedObjectContext<'_>, 200 | req: Json, 201 | ) -> Result<(), HandlerError> { 202 | let ResolveAwakeable { 203 | awakeable_key, 204 | value, 205 | } = req.into_inner(); 206 | let Some(awakeable_id) = context 207 | .get::(&format!("awk-{awakeable_key}")) 208 | .await? 209 | else { 210 | Err(TerminalError::new( 211 | "Awakeable is not registered yet".to_string(), 212 | ))? 213 | }; 214 | 215 | context.resolve_awakeable(&awakeable_id, value); 216 | 217 | Ok(()) 218 | } 219 | 220 | async fn reject_awakeable( 221 | &self, 222 | context: SharedObjectContext<'_>, 223 | req: Json, 224 | ) -> Result<(), HandlerError> { 225 | let RejectAwakeable { 226 | awakeable_key, 227 | reason, 228 | } = req.into_inner(); 229 | let Some(awakeable_id) = context 230 | .get::(&format!("awk-{awakeable_key}")) 231 | .await? 232 | else { 233 | Err(TerminalError::new( 234 | "Awakeable is not registered yet".to_string(), 235 | ))? 236 | }; 237 | 238 | context.reject_awakeable(&awakeable_id, TerminalError::new(reason)); 239 | 240 | Ok(()) 241 | } 242 | 243 | async fn has_awakeable( 244 | &self, 245 | context: SharedObjectContext<'_>, 246 | awakeable_key: String, 247 | ) -> Result { 248 | Ok(context 249 | .get::(&format!("awk-{awakeable_key}")) 250 | .await? 
251 | .is_some()) 252 | } 253 | 254 | async fn get_results( 255 | &self, 256 | context: SharedObjectContext<'_>, 257 | ) -> Result>, HandlerError> { 258 | Ok(context 259 | .get::>>("results") 260 | .await? 261 | .unwrap_or_default()) 262 | } 263 | } 264 | -------------------------------------------------------------------------------- /testcontainers/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "restate-sdk-testcontainers" 3 | version = "0.5.0" 4 | edition = "2021" 5 | description = "Restate SDK Testcontainers utilities" 6 | license = "MIT" 7 | repository = "https://github.com/restatedev/sdk-rust" 8 | rust-version = "1.76.0" 9 | 10 | 11 | [dependencies] 12 | anyhow = "1.0.95" 13 | futures = "0.3.31" 14 | reqwest = { version= "0.12.12", features = ["json"] } 15 | restate-sdk = { version = "0.5.0", path = "../" } 16 | serde = "1.0.217" 17 | testcontainers = { version = "0.23.3", features = ["http_wait"] } 18 | tokio = "1.43.0" 19 | tracing = "0.1.41" 20 | tracing-subscriber = "0.3.19" 21 | -------------------------------------------------------------------------------- /testcontainers/src/lib.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Context; 2 | use futures::FutureExt; 3 | use restate_sdk::prelude::{Endpoint, HttpServer}; 4 | use serde::{Deserialize, Serialize}; 5 | use testcontainers::core::wait::HttpWaitStrategy; 6 | use testcontainers::{ 7 | core::{IntoContainerPort, WaitFor}, 8 | runners::AsyncRunner, 9 | ContainerAsync, ContainerRequest, GenericImage, ImageExt, 10 | }; 11 | use tokio::{io::AsyncBufReadExt, net::TcpListener, task}; 12 | use tracing::{error, info, warn}; 13 | 14 | // From restate-admin-rest-model 15 | #[derive(Serialize, Deserialize, Debug)] 16 | pub struct RegisterDeploymentRequestHttp { 17 | uri: String, 18 | additional_headers: Option>, 19 | use_http_11: bool, 20 | force: bool, 21 | dry_run: bool, 22 | } 23 | 24 | 
#[derive(Serialize, Deserialize, Debug)] 25 | pub struct RegisterDeploymentRequestLambda { 26 | arn: String, 27 | assume_role_arn: Option, 28 | force: bool, 29 | dry_run: bool, 30 | } 31 | 32 | #[derive(Serialize, Deserialize, Debug)] 33 | struct VersionResponse { 34 | version: String, 35 | min_admin_api_version: u32, 36 | max_admin_api_version: u32, 37 | } 38 | 39 | pub struct TestEnvironment { 40 | container_name: String, 41 | container_tag: String, 42 | logging: bool, 43 | } 44 | 45 | impl Default for TestEnvironment { 46 | fn default() -> Self { 47 | Self { 48 | container_name: "docker.io/restatedev/restate".to_string(), 49 | container_tag: "latest".to_string(), 50 | logging: false, 51 | } 52 | } 53 | } 54 | 55 | impl TestEnvironment { 56 | // --- Builder methods 57 | 58 | pub fn new() -> Self { 59 | Self::default() 60 | } 61 | 62 | pub fn with_container_logging(mut self) -> Self { 63 | self.logging = true; 64 | self 65 | } 66 | 67 | pub fn with_container(mut self, container_name: String, container_tag: String) -> Self { 68 | self.container_name = container_name; 69 | self.container_tag = container_tag; 70 | 71 | self 72 | } 73 | 74 | // --- Start method 75 | 76 | pub async fn start(self, endpoint: Endpoint) -> Result { 77 | let started_endpoint = StartedEndpoint::serve_endpoint(endpoint).await?; 78 | let started_restate_container = StartedRestateContainer::start_container(&self).await?; 79 | if let Err(e) = started_restate_container 80 | .register_endpoint(&started_endpoint) 81 | .await 82 | { 83 | return Err(anyhow::anyhow!("Failed to register endpoint: {e}")); 84 | } 85 | 86 | Ok(StartedTestEnvironment { 87 | _started_endpoint: started_endpoint, 88 | started_restate_container, 89 | }) 90 | } 91 | } 92 | 93 | struct StartedEndpoint { 94 | port: u16, 95 | _cancel_tx: tokio::sync::oneshot::Sender<()>, 96 | } 97 | 98 | impl StartedEndpoint { 99 | async fn serve_endpoint(endpoint: Endpoint) -> Result { 100 | info!("Starting endpoint server..."); 101 | 102 | // 
0.0.0.0:0 will listen on a random port, both IPv4 and IPv6 103 | let host_address = "0.0.0.0:0".to_string(); 104 | let listener = TcpListener::bind(host_address) 105 | .await 106 | .expect("listener can bind"); 107 | let listening_addr = listener.local_addr()?; 108 | let endpoint_server_url = 109 | format!("http://{}:{}", listening_addr.ip(), listening_addr.port()); 110 | 111 | // Start endpoint server 112 | let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel(); 113 | tokio::spawn(async move { 114 | HttpServer::new(endpoint) 115 | .serve_with_cancel(listener, cancel_rx) 116 | .await; 117 | }); 118 | 119 | let client = reqwest::Client::builder().http2_prior_knowledge().build()?; 120 | 121 | // wait for endpoint server to respond 122 | let mut retries = 0; 123 | loop { 124 | match client 125 | .get(format!("{endpoint_server_url}/health",)) 126 | .send() 127 | .await 128 | { 129 | Ok(res) if res.status().is_success() => break, 130 | Ok(res) => { 131 | warn!("Error when waiting for service endpoint server to be healthy, got response {}", res.status()); 132 | retries += 1; 133 | if retries > 10 { 134 | anyhow::bail!("Service endpoint server failed to start") 135 | } 136 | } 137 | Err(err) => { 138 | warn!("Error when waiting for service endpoint server to be healthy, got error {}", err); 139 | retries += 1; 140 | if retries > 10 { 141 | anyhow::bail!("Service endpoint server failed to start") 142 | } 143 | } 144 | } 145 | } 146 | 147 | info!("Service endpoint server listening at: {endpoint_server_url}",); 148 | 149 | Ok(StartedEndpoint { 150 | port: listening_addr.port(), 151 | _cancel_tx: cancel_tx, 152 | }) 153 | } 154 | } 155 | 156 | struct StartedRestateContainer { 157 | _cancel_tx: tokio::sync::oneshot::Sender<()>, 158 | container: ContainerAsync, 159 | ingress_url: String, 160 | } 161 | 162 | impl StartedRestateContainer { 163 | async fn start_container( 164 | test_environment: &TestEnvironment, 165 | ) -> Result { 166 | let image = GenericImage::new( 167 | 
&test_environment.container_name, 168 | &test_environment.container_tag, 169 | ) 170 | .with_exposed_port(8080.tcp()) 171 | .with_exposed_port(9070.tcp()) 172 | .with_wait_for(WaitFor::Http( 173 | HttpWaitStrategy::new("/restate/health") 174 | .with_port(8080.tcp()) 175 | .with_response_matcher(|res| res.status().is_success()), 176 | )) 177 | .with_wait_for(WaitFor::Http( 178 | HttpWaitStrategy::new("/health") 179 | .with_port(9070.tcp()) 180 | .with_response_matcher(|res| res.status().is_success()), 181 | )); 182 | 183 | // Start container 184 | let container = ContainerRequest::from(image) 185 | // have to expose entire host network because testcontainer-rs doesn't implement selective SSH port forward from host 186 | // see https://github.com/testcontainers/testcontainers-rs/issues/535 187 | .with_host( 188 | "host.docker.internal", 189 | testcontainers::core::Host::HostGateway, 190 | ) 191 | .start() 192 | .await?; 193 | 194 | let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel(); 195 | if test_environment.logging { 196 | let container_stdout = container.stdout(true); 197 | let mut stdout_lines = container_stdout.lines(); 198 | let container_stderr = container.stderr(true); 199 | let mut stderr_lines = container_stderr.lines(); 200 | 201 | // Spawn a task to copy data from the AsyncBufRead to stdout 202 | task::spawn(async move { 203 | tokio::pin!(cancel_rx); 204 | loop { 205 | tokio::select! 
{ 206 | Some(stdout_line) = stdout_lines.next_line().map(|res| res.transpose()) => { 207 | match stdout_line { 208 | Ok(line) => info!("{}", line), 209 | Err(e) => { 210 | error!("Error reading stdout from container stream: {}", e); 211 | break; 212 | } 213 | } 214 | }, 215 | Some(stderr_line) = stderr_lines.next_line().map(|res| res.transpose()) => { 216 | match stderr_line { 217 | Ok(line) => warn!("{}", line), 218 | Err(e) => { 219 | error!("Error reading stderr from container stream: {}", e); 220 | break; 221 | } 222 | } 223 | } 224 | _ = &mut cancel_rx => { 225 | break; 226 | } 227 | } 228 | } 229 | }); 230 | } 231 | 232 | // Resolve ingress url 233 | let host = container.get_host().await?; 234 | let ports = container.ports().await?; 235 | let ingress_port = ports.map_to_host_port_ipv4(8080.tcp()).unwrap(); 236 | let ingress_url = format!("http://{}:{}", host, ingress_port); 237 | 238 | info!("Restate container started, listening on requests at {ingress_url}"); 239 | 240 | Ok(StartedRestateContainer { 241 | _cancel_tx: cancel_tx, 242 | container, 243 | ingress_url, 244 | }) 245 | } 246 | 247 | async fn register_endpoint(&self, endpoint: &StartedEndpoint) -> Result<(), anyhow::Error> { 248 | let host = self.container.get_host().await?; 249 | let ports = self.container.ports().await?; 250 | let admin_port = ports.map_to_host_port_ipv4(9070.tcp()).unwrap(); 251 | 252 | let client = reqwest::Client::builder().http2_prior_knowledge().build()?; 253 | 254 | let deployment_uri: String = format!("http://host.docker.internal:{}/", endpoint.port); 255 | let deployment_payload = RegisterDeploymentRequestHttp { 256 | uri: deployment_uri, 257 | additional_headers: None, 258 | use_http_11: false, 259 | force: false, 260 | dry_run: false, 261 | }; 262 | 263 | let register_admin_url = format!("http://{}:{}/deployments", host, admin_port); 264 | 265 | let response = client 266 | .post(register_admin_url) 267 | .json(&deployment_payload) 268 | .send() 269 | .await 270 | 
.context("Error when trying to register the service endpoint")?; 271 | 272 | if !response.status().is_success() { 273 | anyhow::bail!( 274 | "Got non success status code when trying to register the service endpoint: {}", 275 | response.status() 276 | ) 277 | } 278 | 279 | Ok(()) 280 | } 281 | } 282 | 283 | pub struct StartedTestEnvironment { 284 | _started_endpoint: StartedEndpoint, 285 | started_restate_container: StartedRestateContainer, 286 | } 287 | 288 | impl StartedTestEnvironment { 289 | pub fn ingress_url(&self) -> String { 290 | self.started_restate_container.ingress_url.clone() 291 | } 292 | } 293 | -------------------------------------------------------------------------------- /testcontainers/tests/test_container.rs: -------------------------------------------------------------------------------- 1 | use reqwest::StatusCode; 2 | use restate_sdk::prelude::*; 3 | use restate_sdk_testcontainers::TestEnvironment; 4 | use tracing::info; 5 | 6 | #[restate_sdk::service] 7 | trait MyService { 8 | async fn my_handler() -> HandlerResult; 9 | } 10 | 11 | #[restate_sdk::object] 12 | trait MyObject { 13 | async fn my_handler(input: String) -> HandlerResult; 14 | #[shared] 15 | async fn my_shared_handler(input: String) -> HandlerResult; 16 | } 17 | 18 | #[restate_sdk::workflow] 19 | trait MyWorkflow { 20 | async fn my_handler(input: String) -> HandlerResult; 21 | #[shared] 22 | async fn my_shared_handler(input: String) -> HandlerResult; 23 | } 24 | 25 | struct MyServiceImpl; 26 | 27 | impl MyService for MyServiceImpl { 28 | async fn my_handler(&self, _: Context<'_>) -> HandlerResult { 29 | let result = "hello!"; 30 | Ok(result.to_string()) 31 | } 32 | } 33 | 34 | #[tokio::test] 35 | async fn test_container() { 36 | tracing_subscriber::fmt::fmt() 37 | .with_max_level(tracing::Level::INFO) // Set the maximum log level 38 | .init(); 39 | 40 | let endpoint = Endpoint::builder().bind(MyServiceImpl.serve()).build(); 41 | 42 | // simple test container initialization with 
default configuration 43 | //let test_container = TestContainer::default().start(endpoint).await.unwrap(); 44 | 45 | // custom test container initialization with builder 46 | let test_environment = TestEnvironment::new() 47 | // optional passthrough logging from the restate server testcontainers 48 | // prints container logs to tracing::info level 49 | .with_container_logging() 50 | .with_container( 51 | "docker.io/restatedev/restate".to_string(), 52 | "latest".to_string(), 53 | ) 54 | .start(endpoint) 55 | .await 56 | .unwrap(); 57 | 58 | let ingress_url = test_environment.ingress_url(); 59 | 60 | // call container ingress url for /MyService/my_handler 61 | let response = reqwest::Client::new() 62 | .post(format!("{}/MyService/my_handler", ingress_url)) 63 | .header("idempotency-key", "abc") 64 | .send() 65 | .await 66 | .unwrap(); 67 | 68 | assert_eq!(response.status(), StatusCode::OK); 69 | 70 | info!( 71 | "/MyService/my_handler response: {:?}", 72 | response.text().await.unwrap() 73 | ); 74 | } 75 | -------------------------------------------------------------------------------- /tests/compiletest.rs: -------------------------------------------------------------------------------- 1 | #[test] 2 | fn ui() { 3 | let t = trybuild::TestCases::new(); 4 | t.compile_fail("tests/ui/*.rs"); 5 | } 6 | -------------------------------------------------------------------------------- /tests/schema.rs: -------------------------------------------------------------------------------- 1 | use restate_sdk::prelude::*; 2 | use restate_sdk::serde::{Json, PayloadMetadata}; 3 | use restate_sdk::service::Discoverable; 4 | use serde::{Deserialize, Serialize}; 5 | use std::collections::HashMap; 6 | 7 | #[cfg(feature = "schemars")] 8 | use schemars::JsonSchema; 9 | 10 | #[derive(Serialize, Deserialize)] 11 | #[cfg_attr(feature = "schemars", derive(JsonSchema))] 12 | struct TestUser { 13 | name: String, 14 | age: u32, 15 | } 16 | 17 | #[derive(Serialize, Deserialize)] 18 | 
#[cfg_attr(feature = "schemars", derive(JsonSchema))] 19 | struct Person { 20 | name: String, 21 | age: u32, 22 | address: Address, 23 | } 24 | 25 | #[derive(Serialize, Deserialize, Default)] 26 | #[cfg_attr(feature = "schemars", derive(JsonSchema))] 27 | struct Address { 28 | street: String, 29 | city: String, 30 | } 31 | 32 | #[restate_sdk::service] 33 | trait SchemaTestService { 34 | async fn string_handler(input: String) -> HandlerResult; 35 | async fn no_input_handler() -> HandlerResult; 36 | async fn json_handler(input: Json) -> HandlerResult>; 37 | async fn complex_handler(input: Json) -> HandlerResult>>; 38 | async fn empty_output_handler(input: String) -> HandlerResult<()>; 39 | } 40 | 41 | struct SchemaTestServiceImpl; 42 | 43 | impl SchemaTestService for SchemaTestServiceImpl { 44 | async fn string_handler(&self, _ctx: Context<'_>, _input: String) -> HandlerResult { 45 | Ok(42) 46 | } 47 | async fn no_input_handler(&self, _ctx: Context<'_>) -> HandlerResult { 48 | Ok("No input".to_string()) 49 | } 50 | async fn json_handler( 51 | &self, 52 | _ctx: Context<'_>, 53 | input: Json, 54 | ) -> HandlerResult> { 55 | Ok(input) 56 | } 57 | async fn complex_handler( 58 | &self, 59 | _ctx: Context<'_>, 60 | input: Json, 61 | ) -> HandlerResult>> { 62 | Ok(Json(HashMap::from([("original".to_string(), input.0)]))) 63 | } 64 | async fn empty_output_handler(&self, _ctx: Context<'_>, _input: String) -> HandlerResult<()> { 65 | Ok(()) 66 | } 67 | } 68 | 69 | #[test] 70 | fn schema_discovery_and_validation() { 71 | let discovery = ServeSchemaTestService::::discover(); 72 | assert_eq!(discovery.name.to_string(), "SchemaTestService"); 73 | assert_eq!(discovery.handlers.len(), 5); 74 | 75 | for handler in &discovery.handlers { 76 | let input = handler 77 | .input 78 | .as_ref() 79 | .expect("Handler should have input schema"); 80 | let output = handler 81 | .output 82 | .as_ref() 83 | .expect("Handler should have output schema"); 84 | 85 | match 
handler.name.to_string().as_str() { 86 | "string_handler" | "json_handler" | "complex_handler" | "empty_output_handler" => { 87 | let input_schema = input 88 | .json_schema 89 | .as_ref() 90 | .expect("Input schema should exist for handlers with input"); 91 | let output_schema = output.json_schema.as_ref(); 92 | 93 | match handler.name.to_string().as_str() { 94 | "string_handler" => { 95 | assert_eq!( 96 | input_schema.get("type").and_then(|v| v.as_str()), 97 | Some("string") 98 | ); 99 | assert!(output_schema.is_some()); 100 | assert_eq!( 101 | output_schema.unwrap().get("type").and_then(|v| v.as_str()), 102 | Some("integer") 103 | ); 104 | } 105 | "json_handler" => { 106 | #[cfg(feature = "schemars")] 107 | { 108 | let obj = input_schema 109 | .as_object() 110 | .expect("Schema should be an object"); 111 | assert!( 112 | obj.contains_key("properties"), 113 | "Json schema should have properties" 114 | ); 115 | assert!(obj["properties"]["name"]["type"] == "string"); 116 | assert!(obj["properties"]["age"]["type"] == "integer"); 117 | } 118 | #[cfg(not(feature = "schemars"))] 119 | assert_eq!(input_schema, &serde_json::json!({})); 120 | } 121 | "complex_handler" => { 122 | #[cfg(feature = "schemars")] 123 | { 124 | let obj = input_schema 125 | .as_object() 126 | .expect("Schema should be an object"); 127 | assert!(obj.contains_key("properties") || obj.contains_key("$ref")); 128 | let props = obj.get("properties").or_else(|| obj.get("$ref")).unwrap(); 129 | assert!(props.is_object(), "Complex schema should define structure"); 130 | } 131 | #[cfg(not(feature = "schemars"))] 132 | assert_eq!(input_schema, &serde_json::json!({})); 133 | } 134 | "empty_output_handler" => { 135 | assert_eq!( 136 | input_schema.get("type").and_then(|v| v.as_str()), 137 | Some("string") 138 | ); 139 | // For empty output handler, we don't expect json_schema to be set in output 140 | assert!( 141 | output_schema.is_none(), 142 | "Empty output handler should have json_schema set to None" 143 | 
); 144 | // Verify that set_content_type_if_empty is set 145 | assert_eq!(output.set_content_type_if_empty, Some(false)); 146 | } 147 | _ => unreachable!("Unexpected handler"), 148 | } 149 | } 150 | "no_input_handler" => { 151 | // For no_input_handler, we don't expect json_schema to be set 152 | assert!( 153 | input.json_schema.is_none(), 154 | "No input handler should have json_schema set to None" 155 | ); 156 | 157 | let output_schema = output 158 | .json_schema 159 | .as_ref() 160 | .expect("Output schema should exist"); 161 | 162 | assert_eq!( 163 | output_schema.get("type").and_then(|v| v.as_str()), 164 | Some("string") 165 | ); 166 | } 167 | _ => unreachable!("Unexpected handler"), 168 | } 169 | } 170 | } 171 | // Checks PayloadMetadata::json_schema: the string payload's schema has type "string"; the structured payload has a string `name` property when the `schemars` feature is enabled and is the empty schema {} otherwise. 172 | #[test] 173 | fn schema_generation() { 174 | let string_schema = ::json_schema().unwrap(); 175 | assert_eq!(string_schema["type"], "string"); 176 | 177 | let json_schema = as PayloadMetadata>::json_schema().unwrap(); 178 | #[cfg(feature = "schemars")] 179 | assert!(json_schema["properties"]["name"]["type"] == "string"); 180 | #[cfg(not(feature = "schemars"))] 181 | assert_eq!(json_schema, serde_json::json!({})); 182 | } 183 | -------------------------------------------------------------------------------- /tests/service.rs: -------------------------------------------------------------------------------- 1 | use restate_sdk::prelude::*; 2 | 3 | // Should compile: covers service handlers with and without input/output, and with HandlerResult, std::io::Error, TerminalError and HandlerError results. 4 | #[restate_sdk::service] 5 | trait MyService { 6 | async fn my_handler(input: String) -> HandlerResult; 7 | 8 | async fn no_input() -> HandlerResult; 9 | 10 | async fn no_output() -> HandlerResult<()>; 11 | 12 | async fn no_input_no_output() -> HandlerResult<()>; 13 | 14 | async fn std_result() -> Result<(), std::io::Error>; 15 | 16 | async fn std_result_with_terminal_error() -> Result<(), TerminalError>; 17 | 18 | async fn std_result_with_handler_error() -> Result<(), HandlerError>; 19 | } 20 | // #[shared] is accepted on object and workflow handlers; tests/ui/shared_handler_in_service.rs checks that it is rejected on service handlers. 21 | #[restate_sdk::object] 22 | trait MyObject { 23 | async fn my_handler(input: String)
-> HandlerResult; 24 | #[shared] 25 | async fn my_shared_handler(input: String) -> HandlerResult; 26 | } 27 | 28 | #[restate_sdk::workflow] 29 | trait MyWorkflow { 30 | async fn my_handler(input: String) -> HandlerResult; 31 | #[shared] 32 | async fn my_shared_handler(input: String) -> HandlerResult; 33 | } 34 | 35 | #[restate_sdk::service] 36 | #[name = "myRenamedService"] 37 | trait MyRenamedService { 38 | #[name = "myRenamedHandler"] 39 | async fn my_handler() -> HandlerResult<()>; 40 | } 41 | 42 | struct MyRenamedServiceImpl; 43 | 44 | impl MyRenamedService for MyRenamedServiceImpl { 45 | async fn my_handler(&self, _: Context<'_>) -> HandlerResult<()> { 46 | Ok(()) 47 | } 48 | } 49 | 50 | #[test] 51 | fn renamed_service_handler() { 52 | use restate_sdk::service::Discoverable; 53 | 54 | let discovery = ServeMyRenamedService::::discover(); 55 | assert_eq!(discovery.name.to_string(), "myRenamedService"); 56 | assert_eq!(discovery.handlers[0].name.to_string(), "myRenamedHandler"); 57 | } 58 | -------------------------------------------------------------------------------- /tests/ui/shared_handler_in_service.rs: -------------------------------------------------------------------------------- 1 | use restate_sdk::prelude::*; 2 | 3 | #[restate_sdk::service] 4 | trait SharedHandlerInService { 5 | #[shared] 6 | async fn my_handler() -> HandlerResult<()>; 7 | } 8 | 9 | struct SharedHandlerInServiceImpl; 10 | 11 | impl SharedHandlerInService for SharedHandlerInServiceImpl { 12 | async fn my_handler(&self, _: Context<'_>) -> HandlerResult<()> { 13 | Ok(()) 14 | } 15 | } 16 | 17 | #[tokio::main] 18 | async fn main() { 19 | tracing_subscriber::fmt::init(); 20 | HttpServer::new( 21 | Endpoint::builder() 22 | .with_service(SharedHandlerInServiceImpl.serve()) 23 | .build(), 24 | ) 25 | .listen_and_serve("0.0.0.0:9080".parse().unwrap()) 26 | .await; 27 | } -------------------------------------------------------------------------------- /tests/ui/shared_handler_in_service.stderr:
-------------------------------------------------------------------------------- 1 | error: Service handlers cannot be annotated with #[shared] 2 | --> tests/ui/shared_handler_in_service.rs:6:14 3 | | 4 | 6 | async fn my_handler() -> HandlerResult<()>; 5 | | ^^^^^^^^^^ 6 | 7 | error[E0405]: cannot find trait `SharedHandlerInService` in this scope 8 | --> tests/ui/shared_handler_in_service.rs:11:6 9 | | 10 | 11 | impl SharedHandlerInService for SharedHandlerInServiceImpl { 11 | | ^^^^^^^^^^^^^^^^^^^^^^ not found in this scope 12 | --------------------------------------------------------------------------------