├── examples
├── client
│ └── .gitignore
├── helloworld
│ ├── .gitignore
│ └── main.go
├── ticketreservation
│ ├── .gitignore
│ ├── main.go
│ ├── checkout.go
│ ├── ticket_service.go
│ ├── user_session.go
│ └── main_test.go
├── codegen
│ ├── buf.yaml
│ ├── buf.lock
│ ├── buf.gen.yaml
│ ├── proto
│ │ └── helloworld.proto
│ └── main.go
├── otel
│ ├── README.md
│ ├── go.mod
│ └── main.go
└── parallelizework
│ └── main.go
├── shared-core
├── .gitignore
├── build.rs
├── .cargo
│ └── config.toml
└── Cargo.toml
├── .gitignore
├── test-services
├── exclusions.yaml
├── .env
├── README.md
├── registry.go
├── upgradetest.go
├── mapobject.go
├── kill.go
├── awakeableholder.go
├── blockandwaitworkflow.go
├── main.go
├── listobject.go
├── counter.go
├── canceltest.go
├── testutils.go
├── failing.go
├── nondeterministic.go
└── proxy.go
├── generate.go
├── internal
├── ingress
│ ├── invocation.go
│ └── error.go
├── statemachine
│ ├── shared_core_golang_wasm_binding.wasm
│ └── logger.go
├── restatecontext
│ ├── handler.go
│ ├── io_helpers.go
│ ├── sleep.go
│ ├── select.go
│ ├── wait_iterator.go
│ ├── state.go
│ ├── awakeable.go
│ ├── promise.go
│ ├── execute_invocation.go
│ ├── async_results.go
│ └── run.go
├── identity
│ ├── identity.go
│ └── v1.go
├── errors
│ └── error.go
├── converters
│ └── converters.go
├── log
│ └── log.go
└── rand
│ ├── rand_test.go
│ └── rand.go
├── .github
└── workflows
│ ├── cla.yml
│ ├── test.yaml
│ ├── docker.yaml
│ └── integration.yaml
├── buf.yaml
├── buf.lock
├── internal.buf.gen.yaml
├── buf.gen.yaml
├── .mockery.yaml
├── testing
└── test_env_test.go
├── mock.go
├── ingress
├── client.go
├── mock_server_test.go
├── send_requester.go
└── requester.go
├── encoding
├── internal
│ ├── util
│ │ └── util.go
│ └── protojsonschema
│ │ ├── schema.go
│ │ └── wellknown.go
└── encoding_test.go
├── error_test.go
├── rcontext
└── rcontext.go
├── LICENSE
├── proto
├── dev
│ └── restate
│ │ └── sdk
│ │ └── go.proto
└── internal.proto
├── error.go
├── server
├── conn.go
└── lambda.go
├── context.go
├── protoc-gen-go-restate
├── README.md
└── main.go
├── mocks
├── mock_Invocation.go
├── mock_AfterFuture.go
├── mock_Selector.go
├── mock_AttachFuture.go
├── mock_RunAsyncFuture.go
├── mock_WaitIterator.go
├── mock_AwakeableFuture.go
└── mock_ResponseFuture.go
├── go.mod
├── router.go
└── README.md
/examples/client/.gitignore:
--------------------------------------------------------------------------------
1 | test
2 |
--------------------------------------------------------------------------------
/shared-core/.gitignore:
--------------------------------------------------------------------------------
1 | /target
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | test_report
2 | .restate
3 |
--------------------------------------------------------------------------------
/examples/helloworld/.gitignore:
--------------------------------------------------------------------------------
1 | test
2 |
--------------------------------------------------------------------------------
/examples/ticketreservation/.gitignore:
--------------------------------------------------------------------------------
1 | test
2 |
--------------------------------------------------------------------------------
/test-services/exclusions.yaml:
--------------------------------------------------------------------------------
1 | exclusions: {}
2 |
--------------------------------------------------------------------------------
/generate.go:
--------------------------------------------------------------------------------
1 | package restate
2 |
3 | //go:generate buf generate
4 |
--------------------------------------------------------------------------------
/test-services/.env:
--------------------------------------------------------------------------------
1 | RESTATE_LOGGING=trace
2 | CORE_TRACE_LOGGING_ENABLED=true
--------------------------------------------------------------------------------
/internal/ingress/invocation.go:
--------------------------------------------------------------------------------
1 | package ingress
2 |
// Invocation is the JSON shape of an invocation handle: its id (decoded from
// the "invocationId" field) and its status (from the "status" field).
// NOTE(review): presumably decoded from Restate ingress responses — confirm
// against the callers in this package.
type Invocation struct {
	Id     string `json:"invocationId"`
	Status string `json:"status"`
}
7 |
--------------------------------------------------------------------------------
/internal/statemachine/shared_core_golang_wasm_binding.wasm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/restatedev/sdk-go/HEAD/internal/statemachine/shared_core_golang_wasm_binding.wasm
--------------------------------------------------------------------------------
/examples/codegen/buf.yaml:
--------------------------------------------------------------------------------
1 | version: v2
2 | lint:
3 | use:
4 | - DEFAULT
5 | breaking:
6 | use:
7 | - FILE
8 | deps:
9 | - buf.build/restatedev/sdk-go
10 |
--------------------------------------------------------------------------------
/shared-core/build.rs:
--------------------------------------------------------------------------------
1 | use std::io::Result;
2 |
3 | fn main() -> Result<()> {
4 | prost_build::Config::new()
5 | .bytes(["."])
6 | .compile_protos(&["../proto/internal.proto"], &["../proto"])?;
7 | Ok(())
8 | }
9 |
--------------------------------------------------------------------------------
/shared-core/.cargo/config.toml:
--------------------------------------------------------------------------------
1 | [build]
2 | target = "wasm32-unknown-unknown"
3 | rustflags = [
4 | # Set the stack size to a single WASM page (64 KiB)
5 | # to avoid excessive allocation on module instantiation
6 | "-C", "link-arg=-zstack-size=65536"
7 | ]
--------------------------------------------------------------------------------
/.github/workflows/cla.yml:
--------------------------------------------------------------------------------
1 | name: "CLA Assistant"
2 | on:
3 | issue_comment:
4 | types: [created]
5 | pull_request_target:
6 | types: [opened, closed, synchronize]
7 |
8 | jobs:
9 | CLAAssistant:
10 | uses: restatedev/restate/.github/workflows/cla.yml@main
11 | secrets: inherit
12 |
--------------------------------------------------------------------------------
/buf.yaml:
--------------------------------------------------------------------------------
1 | version: v2
2 | modules:
3 | - path: proto
4 | name: buf.build/restatedev/sdk-go
5 | excludes:
6 | - proto/dev/restate/sdk/go
7 | deps:
8 | - buf.build/protocolbuffers/wellknowntypes:v29.3
9 | breaking:
10 | use:
11 | - FILE
12 | lint:
13 | use:
14 | - DEFAULT
15 |
--------------------------------------------------------------------------------
/buf.lock:
--------------------------------------------------------------------------------
1 | # Generated by buf. DO NOT EDIT.
2 | version: v2
3 | deps:
4 | - name: buf.build/protocolbuffers/wellknowntypes
5 | commit: d4f14e5e0a9c40889c90d373c74e95eb
6 | digest: b5:39b4d0887abcd8ee1594086283f4120f688e1c33ec9ccd554ab0362ad9ad482154d0e07e3787d394bb22970930b452aac1c5c105c05efe129cec299ff5b5e05e
7 |
--------------------------------------------------------------------------------
/internal.buf.gen.yaml:
--------------------------------------------------------------------------------
1 | version: v2
2 | managed:
3 | enabled: true
4 | plugins:
5 | - remote: buf.build/protocolbuffers/go:v1.36.5
6 | out: internal/generated
7 | opt:
8 | - paths=source_relative
9 | - default_api_level=API_OPAQUE
10 | inputs:
11 | - proto_file: proto/internal.proto
12 |
--------------------------------------------------------------------------------
/examples/codegen/buf.lock:
--------------------------------------------------------------------------------
1 | # Generated by buf. DO NOT EDIT.
2 | version: v2
3 | deps:
4 | - name: buf.build/restatedev/sdk-go
5 | commit: 9ea0b54286dd4f35b0cb96ecdf09b402
6 | digest: b5:822b9362e943c827c36e44b0db519542259439382f94817989349d0ee590617ba70e35975840c5d96ceff278254806435e7d570db81548f9703c00b01eec398e
7 |
--------------------------------------------------------------------------------
/buf.gen.yaml:
--------------------------------------------------------------------------------
1 | version: v2
2 | managed:
3 | enabled: true
4 | override:
5 | - file_option: go_package_prefix
6 | value: github.com/restatedev/sdk-go/generated
7 | plugins:
8 | - remote: buf.build/protocolbuffers/go:v1.36.5
9 | out: generated
10 | opt: paths=source_relative
11 | inputs:
12 | - proto_file: proto/dev/restate/sdk/go.proto
13 |
--------------------------------------------------------------------------------
/examples/codegen/buf.gen.yaml:
--------------------------------------------------------------------------------
1 | version: v2
2 | managed:
3 | enabled: true
4 | plugins:
5 | - remote: buf.build/protocolbuffers/go:v1.34.2
6 | out: .
7 | opt: paths=source_relative
8 | - local: protoc-gen-go-restate
9 | out: .
10 | opt:
11 | - paths=source_relative
12 | - use_go_service_names=false
13 | inputs:
14 | - directory: .
15 |
--------------------------------------------------------------------------------
/internal/statemachine/logger.go:
--------------------------------------------------------------------------------
1 | package statemachine
2 |
3 | import (
4 | "context"
5 | "log/slog"
6 | )
7 |
// loggerKey is the unexported context key under which a *slog.Logger is stored.
type loggerKey struct{}

// WithLogger returns a child of ctx that carries logger.
func WithLogger(ctx context.Context, logger *slog.Logger) context.Context {
	var key loggerKey
	return context.WithValue(ctx, key, logger)
}

// getLogger returns the *slog.Logger previously attached with WithLogger,
// or nil when ctx carries none.
func getLogger(ctx context.Context) *slog.Logger {
	if logger, ok := ctx.Value(loggerKey{}).(*slog.Logger); ok {
		return logger
	}
	return nil
}
18 |
--------------------------------------------------------------------------------
/.mockery.yaml:
--------------------------------------------------------------------------------
1 | with-expecter: true
2 | issue-845-fix: true
3 | resolve-type-alias: false
4 | dir: mocks
5 | outPkg: mocks
6 | packages:
7 | github.com/restatedev/sdk-go/internal/restatecontext:
8 | interfaces:
9 | AfterFuture:
10 | AttachFuture:
11 | AwakeableFuture:
12 | Client:
13 | Context:
14 | DurablePromise:
15 | Invocation:
16 | ResponseFuture:
17 | RunAsyncFuture:
18 | Selector:
19 | WaitIterator:
20 | github.com/restatedev/sdk-go/internal/rand:
21 | interfaces:
22 | Rand:
23 |
--------------------------------------------------------------------------------
/internal/restatecontext/handler.go:
--------------------------------------------------------------------------------
1 | package restatecontext
2 |
3 | import (
4 | "github.com/restatedev/sdk-go/encoding"
5 | "github.com/restatedev/sdk-go/internal"
6 | "github.com/restatedev/sdk-go/internal/options"
7 | )
8 |
// Handler is implemented by all Restate handlers.
//
// It exposes the handler's configuration and payload metadata, plus a Call
// method that runs the handler on raw (already-serialized) request bytes.
type Handler interface {
	// GetOptions returns the options this handler was configured with.
	GetOptions() *options.HandlerOptions
	// InputPayload returns the encoding metadata for the handler's input.
	InputPayload() *encoding.InputPayload
	// OutputPayload returns the encoding metadata for the handler's output.
	OutputPayload() *encoding.OutputPayload
	// HandlerType returns the handler's service handler type.
	// NOTE(review): nil may be meaningful for some service kinds — confirm
	// against the implementations.
	HandlerType() *internal.ServiceHandlerType
	// Call invokes the handler with the raw request bytes and returns the raw
	// output bytes, or an error.
	Call(ctx Context, request []byte) (output []byte, err error)
}
17 |
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: Go
2 | on: [push]
3 |
4 | permissions:
5 | checks: write
6 | pull-requests: write
7 |
8 | jobs:
9 | build:
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - uses: actions/checkout@v4
14 | - name: Setup Go
15 | uses: actions/setup-go@v4
16 | with:
17 | go-version: "1.21.x"
18 | - name: Install dependencies
19 | run: go get .
20 | - name: Vet
21 | run: go vet -v ./...
22 | - name: Build
23 | run: go build -v ./...
24 | - name: Test with the Go CLI
25 | run: go test -v ./...
26 |
--------------------------------------------------------------------------------
/examples/ticketreservation/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "log/slog"
6 | "os"
7 |
8 | restate "github.com/restatedev/sdk-go"
9 | "github.com/restatedev/sdk-go/server"
10 | )
11 |
12 | func main() {
13 | server := server.NewRestate().
14 | // Handlers can be inferred from object methods
15 | Bind(restate.Reflect(&userSession{})).
16 | Bind(restate.Reflect(&ticketService{})).
17 | Bind(restate.Reflect(&checkout{}))
18 |
19 | if err := server.Start(context.Background(), ":9080"); err != nil {
20 | slog.Error("application exited unexpectedly", "err", err.Error())
21 | os.Exit(1)
22 | }
23 | }
24 |
--------------------------------------------------------------------------------
/shared-core/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "shared-core-golang-wasm-binding"
3 | version = "0.1.0"
4 | edition = "2021"
5 |
6 | [lib]
7 | crate-type = ["cdylib"]
8 |
9 | [dependencies]
10 | restate-sdk-shared-core = { version = "=0.6.0" }
11 | bytes = "1.10"
12 | tracing = "0.1.40"
13 | tracing-subscriber = { version = "0.3.18", default-features = false, features = ["fmt", "std"] }
14 | prost = "0.13.5"
15 |
16 | [build-dependencies]
17 | prost-build = "0.13.5"
18 |
19 | # Below settings dramatically reduce wasm output size
20 | # See https://rustwasm.github.io/book/reference/code-size.html#optimizing-builds-for-code-size
21 | # See https://doc.rust-lang.org/cargo/reference/profiles.html#codegen-units
22 | [profile.release]
23 | opt-level = 3
24 | lto = true
25 |
--------------------------------------------------------------------------------
/examples/otel/README.md:
--------------------------------------------------------------------------------
1 | # Distributed tracing example
2 |
3 | To test out distributed tracing, you can run Jaeger locally:
4 | ```shell
5 | docker run -d --name jaeger \
6 | -e COLLECTOR_OTLP_ENABLED=true \
7 | -p 4317:4317 -p 16686:16686 \
8 | jaegertracing/all-in-one:1.46
9 | ```
10 |
11 | And start the Restate server configured to send traces to Jaeger:
12 | ```shell
13 | npx @restatedev/restate-server --tracing-endpoint http://localhost:4317
14 | ```
15 |
16 | Finally start this example service and register it with the Restate server:
17 | ```shell
18 | go run ./examples/otel
19 | restate dep register http://localhost:9080
20 | ```
21 |
22 | And you can now make invocations with `curl localhost:8080/Greeter/Greet --json '"hello"'`,
23 | and they should appear in the [Jaeger UI](http://localhost:16686) with spans from both the
24 | Restate server and the Go service.
25 |
--------------------------------------------------------------------------------
/testing/test_env_test.go:
--------------------------------------------------------------------------------
1 | package testing
2 |
3 | import (
4 | "testing"
5 |
6 | restate "github.com/restatedev/sdk-go"
7 | "github.com/restatedev/sdk-go/ingress"
8 | "github.com/restatedev/sdk-go/server"
9 | "github.com/stretchr/testify/require"
10 | )
11 |
12 | type Greeter struct{}
13 |
14 | func (Greeter) Greet(ctx restate.Context, name string) (string, error) {
15 | // Respond to caller
16 | return "You said hi to " + name + "!", nil
17 | }
18 |
19 | func TestWithTestcontainers(t *testing.T) {
20 | tEnv := StartWithOptions(t, server.NewRestate().Bind(restate.Reflect(Greeter{})), WithRestateImage("ghcr.io/restatedev/restate:latest"))
21 | client := tEnv.Ingress()
22 |
23 | out, err := ingress.Service[string, string](client, "Greeter", "Greet").Request(t.Context(), "Francesco")
24 | require.NoError(t, err)
25 | require.Equal(t, "You said hi to Francesco!", out)
26 | }
27 |
--------------------------------------------------------------------------------
/mock.go:
--------------------------------------------------------------------------------
1 | package restate
2 |
3 | import (
4 | "github.com/restatedev/sdk-go/internal/restatecontext"
5 | )
6 |
7 | type mockContext struct {
8 | restatecontext.Context
9 | }
10 |
11 | func (m mockContext) inner() restatecontext.Context {
12 | return m.Context
13 | }
14 | func (m mockContext) object() {}
15 | func (m mockContext) exclusiveObject() {}
16 | func (m mockContext) workflow() {}
17 | func (m mockContext) runWorkflow() {}
18 |
19 | var _ RunContext = mockContext{}
20 | var _ Context = mockContext{}
21 | var _ ObjectSharedContext = mockContext{}
22 | var _ ObjectContext = mockContext{}
23 | var _ WorkflowSharedContext = mockContext{}
24 | var _ WorkflowContext = mockContext{}
25 |
26 | // WithMockContext allows providing a mocked state.Context to handlers
27 | func WithMockContext(ctx restatecontext.Context) mockContext {
28 | return mockContext{ctx}
29 | }
30 |
--------------------------------------------------------------------------------
/test-services/README.md:
--------------------------------------------------------------------------------
1 | # Test services to run the sdk-test-suite
2 |
3 | ## To run locally
4 |
5 | * Grab the release of sdk-test-suite: https://github.com/restatedev/sdk-test-suite/releases
6 |
7 | * Prepare the docker image:
8 | ```shell
9 | KO_DOCKER_REPO=restatedev ko build -B -L github.com/restatedev/sdk-go/test-services
10 | ```
11 |
12 | * Run the tests (requires JVM >= 17):
13 | ```shell
14 | java -jar restate-sdk-test-suite.jar run --exclusions-file exclusions.yaml restatedev/test-services
15 | ```
16 |
17 | ## To debug a single test
18 |
19 | * Run the golang service using your IDE
20 | * Run the test runner in debug mode specifying test suite and test:
21 | ```shell
22 | java -jar restate-sdk-test-suite.jar debug --image-pull-policy=CACHED --test-config=lazyState --test-name=dev.restate.sdktesting.tests.State default-service=9080
23 | ```
24 |
25 | For more info: https://github.com/restatedev/sdk-test-suite
--------------------------------------------------------------------------------
/ingress/client.go:
--------------------------------------------------------------------------------
1 | package ingress
2 |
3 | import (
4 | "github.com/restatedev/sdk-go/internal/ingress"
5 | "github.com/restatedev/sdk-go/internal/options"
6 | )
7 |
// Client is the ingress client used to call Restate services from outside a
// Restate context. It is an alias of the internal ingress.Client type.
type Client = ingress.Client
9 |
10 | // NewClient creates a new ingress client for calling Restate services from outside a Restate context.
11 | // The baseUri should point to your Restate ingress endpoint (e.g., "http://localhost:8080").
12 | //
13 | // Options can be used to configure the client, such as setting a custom HTTP client, authentication key, or codec:
14 | //
15 | // client := ingress.NewClient("http://localhost:8080",
16 | // restate.WithAuthKey("my-auth-key"),
17 | // )
18 | func NewClient(baseUri string, opts ...options.IngressClientOption) *Client {
19 | clientOpts := options.IngressClientOptions{}
20 | for _, opt := range opts {
21 | opt.BeforeIngress(&clientOpts)
22 | }
23 | return ingress.NewClient(baseUri, clientOpts)
24 | }
25 |
--------------------------------------------------------------------------------
/encoding/internal/util/util.go:
--------------------------------------------------------------------------------
1 | package util
2 |
3 | import (
4 | "strings"
5 |
6 | "github.com/invopop/jsonschema"
7 | )
8 |
9 | // Schemas that have a top-level ref can be problematic for some parsers, like the playground in the UI.
10 | // To be more forgiving, we can yank the definition up to the top level
11 | func ExpandSchema(rootSchema *jsonschema.Schema) *jsonschema.Schema {
12 | if !strings.HasPrefix(rootSchema.Ref, `#/$defs/`) {
13 | return rootSchema
14 | }
15 | defName := rootSchema.Ref[len(`#/$defs/`):]
16 | def, ok := rootSchema.Definitions[defName]
17 | if !ok {
18 | return rootSchema
19 | }
20 | // allow references to #/$defs/name to still work by redirecting to the root
21 | rootSchema.Definitions[defName] = &jsonschema.Schema{Ref: "#"}
22 |
23 | expandedSchema := &*def
24 | expandedSchema.ID = rootSchema.ID
25 | expandedSchema.Version = rootSchema.Version
26 | expandedSchema.Definitions = rootSchema.Definitions
27 |
28 | return expandedSchema
29 | }
30 |
--------------------------------------------------------------------------------
/internal/ingress/error.go:
--------------------------------------------------------------------------------
1 | package ingress
2 |
// restateError mirrors a JSON error payload (message, code, description,
// stacktrace). NOTE(review): presumably the error body returned by the Restate
// ingress — confirm against the decoding call sites.
type restateError struct {
	Message     string `json:"message"`
	Code        int    `json:"code,omitempty"`
	Description string `json:"description,omitempty"`
	Stacktrace  string `json:"stacktrace,omitempty"`
}

// GenericError wraps an error payload not mapped to a more specific type.
type GenericError struct {
	*restateError
}

// InvocationNotFoundError wraps an error payload reporting a missing invocation.
type InvocationNotFoundError struct {
	*restateError
}

// InvocationNotReadyError wraps an error payload reporting an invocation whose
// result is not available yet.
type InvocationNotReadyError struct {
	*restateError
}

// Error implements the error interface by returning the payload's message.
// The wrapper types above get it through their embedded *restateError.
func (re restateError) Error() string {
	return re.Message
}

// newGenericError wraps err as a *GenericError.
func newGenericError(err *restateError) *GenericError {
	return &GenericError{err}
}

// newInvocationNotFoundError wraps err as an *InvocationNotFoundError.
func newInvocationNotFoundError(err *restateError) *InvocationNotFoundError {
	return &InvocationNotFoundError{err}
}

// newInvocationNotReadyError wraps err as an *InvocationNotReadyError.
func newInvocationNotReadyError(err *restateError) *InvocationNotReadyError {
	return &InvocationNotReadyError{err}
}
43 |
--------------------------------------------------------------------------------
/internal/identity/identity.go:
--------------------------------------------------------------------------------
1 | package identity
2 |
3 | import "fmt"
4 |
// SIGNATURE_SCHEME_HEADER is the request header that declares which signature
// scheme the request identity uses.
const SIGNATURE_SCHEME_HEADER = "X-Restate-Signature-Scheme"

// SignatureScheme is the value carried by SIGNATURE_SCHEME_HEADER.
type SignatureScheme string

var (
	// SchemeUnsigned marks a request with no signature; ValidateRequestIdentity
	// treats it the same as a missing identity.
	SchemeUnsigned SignatureScheme = "unsigned"
	// errMissingIdentity is returned when a request carries no identity.
	errMissingIdentity = fmt.Errorf("request has no identity")
)
13 |
14 | func ValidateRequestIdentity(keySet KeySetV1, path string, headers map[string][]string) error {
15 | switch len(headers[SIGNATURE_SCHEME_HEADER]) {
16 | case 0:
17 | return errMissingIdentity
18 | case 1:
19 | switch SignatureScheme(headers[SIGNATURE_SCHEME_HEADER][0]) {
20 | case SchemeV1:
21 | return validateV1(keySet, path, headers)
22 | case SchemeUnsigned:
23 | return errMissingIdentity
24 | default:
25 | return fmt.Errorf("unexpected signature scheme %v, allowed values are [%s %s]", headers[SIGNATURE_SCHEME_HEADER][0], SchemeUnsigned, SchemeV1)
26 | }
27 | default:
28 | return fmt.Errorf("unexpected multi-value signature scheme header: %v", headers[SIGNATURE_SCHEME_HEADER])
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/error_test.go:
--------------------------------------------------------------------------------
1 | package restate
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/require"
9 | )
10 |
11 | func TestTerminal(t *testing.T) {
12 | require.False(t, IsTerminalError(fmt.Errorf("not terminal")))
13 |
14 | err := TerminalErrorf("failed terminally")
15 | require.True(t, IsTerminalError(err))
16 |
17 | //terminal with code
18 | err = TerminalError(fmt.Errorf("terminal with code"), 500)
19 |
20 | require.True(t, IsTerminalError(err))
21 | require.EqualValues(t, 500, ErrorCode(err))
22 | }
23 |
24 | func TestCode(t *testing.T) {
25 |
26 | err := WithErrorCode(fmt.Errorf("some error"), 16)
27 |
28 | code := ErrorCode(err)
29 |
30 | require.EqualValues(t, 16, code)
31 |
32 | require.EqualValues(t, http.StatusInternalServerError, ErrorCode(fmt.Errorf("unknown error")))
33 | }
34 |
35 | func TestCombine(t *testing.T) {
36 | err := WithErrorCode(TerminalError(fmt.Errorf("some error")), 100)
37 |
38 | require.True(t, IsTerminalError(err))
39 | require.EqualValues(t, 100, ErrorCode(err))
40 | }
41 |
--------------------------------------------------------------------------------
/test-services/registry.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "log"
5 |
6 | restate "github.com/restatedev/sdk-go"
7 | "github.com/restatedev/sdk-go/server"
8 | )
9 |
// REGISTRY is the process-wide set of test-service components; the init
// functions in this package add their service definitions to it.
var REGISTRY = Registry{components: map[string]Component{}}

// Registry maps a component's fully-qualified name to the Component.
type Registry struct {
	components map[string]Component
}

// Component pairs a service name with a function that binds the service to a
// server.Restate endpoint.
type Component struct {
	Fqdn   string
	Binder func(endpoint *server.Restate)
}
20 |
21 | func (r *Registry) Add(c Component) {
22 | r.components[c.Fqdn] = c
23 | }
24 |
25 | func (r *Registry) AddDefinition(definition restate.ServiceDefinition) {
26 | r.Add(Component{
27 | Fqdn: definition.Name(),
28 | Binder: func(e *server.Restate) { e.Bind(definition) },
29 | })
30 | }
31 |
32 | func (r *Registry) RegisterAll(e *server.Restate) {
33 | for _, c := range r.components {
34 | c.Binder(e)
35 | }
36 | }
37 |
38 | func (r *Registry) Register(fqdns map[string]struct{}, e *server.Restate) {
39 | for fqdn := range fqdns {
40 | c, ok := r.components[fqdn]
41 | if !ok {
42 | log.Fatalf("unknown fqdn %s. Did you remember to register it?", fqdn)
43 | }
44 | c.Binder(e)
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
/examples/helloworld/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log/slog"
7 | "os"
8 |
9 | restate "github.com/restatedev/sdk-go"
10 | "github.com/restatedev/sdk-go/server"
11 | )
12 |
13 | type Greeter struct{}
14 |
15 | func (Greeter) Greet(ctx restate.Context, name string) (string, error) {
16 | return "You said hi to " + name + "!", nil
17 | }
18 |
19 | type GreeterCounter struct{}
20 |
21 | func (GreeterCounter) Greet(ctx restate.ObjectContext, name string) (string, error) {
22 | count, err := restate.Get[uint32](ctx, "count")
23 | if err != nil {
24 | return "", err
25 | }
26 | count++
27 |
28 | restate.Set[uint32](ctx, "count", count)
29 |
30 | return fmt.Sprintf("You said hi to %s for the %d time!", name, count), nil
31 | }
32 |
33 | func main() {
34 | server := server.NewRestate().
35 | // Handlers can be inferred from object methods
36 | Bind(restate.Reflect(Greeter{})).
37 | Bind(restate.Reflect(GreeterCounter{}))
38 |
39 | if err := server.Start(context.Background(), ":9080"); err != nil {
40 | slog.Error("application exited unexpectedly", "err", err.Error())
41 | os.Exit(1)
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/test-services/upgradetest.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "os"
6 | "strings"
7 |
8 | restate "github.com/restatedev/sdk-go"
9 | )
10 |
11 | func init() {
12 | version := func() string {
13 | return strings.TrimSpace(os.Getenv("E2E_UPGRADETEST_VERSION"))
14 | }
15 | REGISTRY.AddDefinition(
16 | restate.NewService("UpgradeTest").
17 | Handler("executeSimple", restate.NewServiceHandler(
18 | func(ctx restate.Context, _ restate.Void) (string, error) {
19 | return version(), nil
20 | })).
21 | Handler("executeComplex", restate.NewServiceHandler(
22 | func(ctx restate.Context, _ restate.Void) (string, error) {
23 | if version() != "v1" {
24 | return "", fmt.Errorf("executeComplex should not be invoked with version different from 1!")
25 | }
26 | awakeable := restate.Awakeable[string](ctx)
27 | restate.ObjectSend(ctx, "AwakeableHolder", "upgrade", "hold").Send(awakeable.Id())
28 | if _, err := awakeable.Result(); err != nil {
29 | return "", err
30 | }
31 | restate.ObjectSend(ctx, "ListObject", "upgrade-test", "append").Send(version())
32 | return version(), nil
33 | })))
34 | }
35 |
--------------------------------------------------------------------------------
/rcontext/rcontext.go:
--------------------------------------------------------------------------------
1 | package rcontext
2 |
3 | import "context"
4 |
// LogSource is an enum to describe the source of a logline.
type LogSource int

// The constants are explicitly typed as LogSource. A bare `= iota` would make
// them untyped int constants, which are not LogSource values.
const (
	// LogSourceRestate logs come from the sdk-go library
	LogSourceRestate LogSource = iota
	// LogSourceUser logs come from user handlers that use the Context.Log() logger.
	LogSourceUser
)

// LogContext contains information stored in the context that is passed to loggers
type LogContext struct {
	// The source of the logline
	Source LogSource
	// Whether the user code is currently replaying
	IsReplaying bool
}

// logContextKey is the unexported context key for a *LogContext.
type logContextKey struct{}

// WithLogContext stores a [LogContext] in the provided [context.Context], returning a new context
func WithLogContext(parent context.Context, logContext *LogContext) context.Context {
	return context.WithValue(parent, logContextKey{}, logContext)
}

// LogContextFrom retrieves the [LogContext] stored in this [context.Context], or otherwise returns nil
func LogContextFrom(ctx context.Context) *LogContext {
	if val, ok := ctx.Value(logContextKey{}).(*LogContext); ok {
		return val
	}
	return nil
}
37 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 - Restate Software, Inc., Restate GmbH
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/test-services/mapobject.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | restate "github.com/restatedev/sdk-go"
5 | )
6 |
// Entry is a single key/value pair of MapObject state. It is used as the input
// of "set" and in the output of "clearAll".
type Entry struct {
	Key   string `json:"key"`
	Value string `json:"value"`
}
11 |
12 | func init() {
13 | REGISTRY.AddDefinition(
14 | restate.NewObject("MapObject").
15 | Handler("set", restate.NewObjectHandler(
16 | func(ctx restate.ObjectContext, value Entry) (restate.Void, error) {
17 | restate.Set(ctx, value.Key, value.Value)
18 | return restate.Void{}, nil
19 | })).
20 | Handler("get", restate.NewObjectHandler(
21 | func(ctx restate.ObjectContext, key string) (string, error) {
22 | return restate.Get[string](ctx, key)
23 | })).
24 | Handler("clearAll", restate.NewObjectHandler(
25 | func(ctx restate.ObjectContext, _ restate.Void) ([]Entry, error) {
26 | keys, err := restate.Keys(ctx)
27 | if err != nil {
28 | return nil, err
29 | }
30 | out := make([]Entry, 0, len(keys))
31 | for _, k := range keys {
32 | value, err := restate.Get[string](ctx, k)
33 | if err != nil {
34 | return nil, err
35 | }
36 | out = append(out, Entry{Key: k, Value: value})
37 | }
38 | restate.ClearAll(ctx)
39 | return out, nil
40 | })))
41 | }
42 |
--------------------------------------------------------------------------------
/test-services/kill.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | restate "github.com/restatedev/sdk-go"
5 | )
6 |
7 | func init() {
8 | REGISTRY.AddDefinition(restate.NewObject("KillTestRunner").Handler("startCallTree", restate.NewObjectHandler(func(ctx restate.ObjectContext, _ restate.Void) (restate.Void, error) {
9 | return restate.Object[restate.Void](ctx, "KillTestSingleton", restate.Key(ctx), "recursiveCall").Request(restate.Void{})
10 | })))
11 |
12 | REGISTRY.AddDefinition(
13 | restate.NewObject("KillTestSingleton").
14 | Handler("recursiveCall", restate.NewObjectHandler(
15 | func(ctx restate.ObjectContext, _ restate.Void) (restate.Void, error) {
16 | awakeable := restate.Awakeable[restate.Void](ctx)
17 | restate.ObjectSend(ctx, "AwakeableHolder", restate.Key(ctx), "hold").Send(awakeable.Id())
18 | if _, err := awakeable.Result(); err != nil {
19 | return restate.Void{}, err
20 | }
21 |
22 | return restate.Object[restate.Void](ctx, "KillTestSingleton", restate.Key(ctx), "recursiveCall").Request(restate.Void{})
23 | })).
24 | Handler("isUnlocked", restate.NewObjectHandler(
25 | func(ctx restate.ObjectContext, _ restate.Void) (restate.Void, error) {
26 | // no-op
27 | return restate.Void{}, nil
28 | })))
29 | }
30 |
--------------------------------------------------------------------------------
/test-services/awakeableholder.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 |
6 | restate "github.com/restatedev/sdk-go"
7 | )
8 |
9 | const ID_KEY = "id"
10 |
11 | func init() {
12 | REGISTRY.AddDefinition(
13 | restate.NewObject("AwakeableHolder").
14 | Handler("hold", restate.NewObjectHandler(
15 | func(ctx restate.ObjectContext, id string) (restate.Void, error) {
16 | restate.Set(ctx, ID_KEY, id)
17 | return restate.Void{}, nil
18 | })).
19 | Handler("hasAwakeable", restate.NewObjectHandler(
20 | func(ctx restate.ObjectContext, _ restate.Void) (bool, error) {
21 | id, err := restate.Get[string](ctx, ID_KEY)
22 | if err != nil {
23 | return false, err
24 | }
25 | return id != "", nil
26 | })).
27 | Handler("unlock", restate.NewObjectHandler(
28 | func(ctx restate.ObjectContext, payload string) (restate.Void, error) {
29 | id, err := restate.Get[string](ctx, ID_KEY)
30 | if err != nil {
31 | return restate.Void{}, err
32 | }
33 | if id == "" {
34 | return restate.Void{}, restate.TerminalError(fmt.Errorf("No awakeable registered"), 404)
35 | }
36 | restate.ResolveAwakeable(ctx, id, payload)
37 | restate.Clear(ctx, ID_KEY)
38 | return restate.Void{}, nil
39 | })))
40 | }
41 |
--------------------------------------------------------------------------------
/internal/errors/error.go:
--------------------------------------------------------------------------------
1 | package errors
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | )
7 |
// Code is a numeric status code attached to an error, typically an HTTP status code.
type Code uint16

// CodeError wraps Inner and attaches a status Code to it.
type CodeError struct {
	Code  Code
	Inner error
}

// Error renders the error as "[<code>] <inner message>".
func (ce *CodeError) Error() string {
	code, inner := ce.Code, ce.Inner
	return fmt.Sprintf("[%d] %s", code, inner)
}

// Unwrap exposes the wrapped error to errors.Is / errors.As.
func (ce *CodeError) Unwrap() error { return ce.Inner }

// ErrorCode returns the Code of the first *CodeError found in err's chain,
// or 500 when there is none (including when err is nil).
func ErrorCode(err error) Code {
	var coded *CodeError
	if !errors.As(err, &coded) {
		return 500
	}
	return coded.Code
}
31 |
// TerminalError marks Inner as terminal. IsTerminalError detects it anywhere in
// an error chain.
type TerminalError struct {
	Inner error
}

// Error returns the wrapped error's message. A nil receiver or a zero-value
// TerminalError (nil Inner) yields a generic message instead of panicking.
func (e *TerminalError) Error() string {
	if e == nil || e.Inner == nil {
		return "terminal error"
	}
	return e.Inner.Error()
}

// Unwrap exposes the wrapped error to errors.Is / errors.As (nil for a nil receiver).
func (e *TerminalError) Unwrap() error {
	if e == nil {
		return nil
	}
	return e.Inner
}

// IsTerminalError reports whether err's chain contains a *TerminalError.
// A nil err is never terminal.
func IsTerminalError(err error) bool {
	if err == nil {
		return false
	}
	var t *TerminalError
	return errors.As(err, &t)
}
51 |
52 | func NewTerminalError(err error, code ...Code) error {
53 | if err == nil {
54 | return nil
55 | }
56 |
57 | if len(code) > 1 {
58 | panic("only single code is allowed")
59 | }
60 |
61 | err = &TerminalError{
62 | Inner: err,
63 | }
64 |
65 | if len(code) == 1 {
66 | err = &CodeError{
67 | Inner: err,
68 | Code: code[0],
69 | }
70 | }
71 |
72 | return err
73 | }
74 |
--------------------------------------------------------------------------------
/test-services/blockandwaitworkflow.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | restate "github.com/restatedev/sdk-go"
5 | )
6 |
7 | const MY_STATE = "my-state"
8 | const MY_DURABLE_PROMISE = "durable-promise"
9 |
10 | func init() {
11 | REGISTRY.AddDefinition(
12 | restate.NewWorkflow("BlockAndWaitWorkflow").
13 | Handler("run", restate.NewWorkflowHandler(
14 | func(ctx restate.WorkflowContext, input string) (string, error) {
15 | restate.Set(ctx, MY_STATE, input)
16 | output, err := restate.Promise[string](ctx, MY_DURABLE_PROMISE).Result()
17 | if err != nil {
18 | return "", err
19 | }
20 |
21 | peek, err := restate.Promise[*string](ctx, MY_DURABLE_PROMISE).Peek()
22 | if peek == nil {
23 | return "", restate.TerminalErrorf("Durable promise should be completed")
24 | }
25 |
26 | return output, nil
27 | })).
28 | Handler("unblock", restate.NewWorkflowSharedHandler(
29 | func(ctx restate.WorkflowSharedContext, output string) (restate.Void, error) {
30 | return restate.Void{}, restate.Promise[string](ctx, MY_DURABLE_PROMISE).Resolve(output)
31 | })).
32 | Handler("getState", restate.NewWorkflowSharedHandler(
33 | func(ctx restate.WorkflowSharedContext, input restate.Void) (*string, error) {
34 | return restate.Get[*string](ctx, MY_STATE)
35 | })))
36 | }
37 |
--------------------------------------------------------------------------------
/proto/dev/restate/sdk/go.proto:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
3 | *
4 | * This file is part of the Restate SDK for Go,
5 | * which is released under the MIT license.
6 | *
7 | * You can find a copy of the license in file LICENSE in the root
8 | * directory of this repository or package, or at
 * https://github.com/restatedev/sdk-go/blob/main/LICENSE
10 | */
11 |
12 | syntax = "proto3";
13 |
14 | package dev.restate.sdk.go;
15 |
16 | import "google/protobuf/descriptor.proto";
17 |
18 | option go_package = "github.com/restatedev/sdk-go/generated/dev/restate/sdk";
19 |
// ServiceType is the Restate service kind declared for a protobuf service via
// the (dev.restate.sdk.go.service_type) service option.
enum ServiceType {
  // SERVICE is the default and need not be provided
  SERVICE = 0;
  // VIRTUAL_OBJECT declares a keyed Virtual Object.
  VIRTUAL_OBJECT = 1;
  // WORKFLOW declares a Workflow.
  WORKFLOW = 2;
}
26 |
// HandlerType is the handler kind declared for an rpc via the
// (dev.restate.sdk.go.handler_type) method option.
enum HandlerType {
  // Handler type is ignored for service type SERVICE.
  // For VIRTUAL_OBJECT, defaults to EXCLUSIVE.
  // For WORKFLOW, defaults to SHARED.
  UNSET = 0;
  EXCLUSIVE = 1;
  SHARED = 2;
  // Signifies that this is the primary function for the workflow, typically named 'Run'.
  WORKFLOW_RUN = 3;
}
37 |
// Declares the (dev.restate.sdk.go.handler_type) option for rpc methods.
// 2051 is the extension field number and must not change.
extend google.protobuf.MethodOptions {
  HandlerType handler_type = 2051;
}

// Declares the (dev.restate.sdk.go.service_type) option for services.
// 2051 is the extension field number on ServiceOptions and must not change.
extend google.protobuf.ServiceOptions {
  ServiceType service_type = 2051;
}
45 |
--------------------------------------------------------------------------------
/examples/ticketreservation/checkout.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "math/rand"
6 |
7 | restate "github.com/restatedev/sdk-go"
8 | )
9 |
// PaymentRequest asks Checkout to charge UserID for the listed Tickets.
type PaymentRequest struct {
	UserID  string   `json:"userId"`
	Tickets []string `json:"tickets"`
}

// PaymentResponse carries the generated payment ID and the total Price
// charged (30 per ticket in this example).
type PaymentResponse struct {
	ID    string `json:"id"`
	Price int    `json:"price"`
}
19 |
// CheckoutServiceName is the Restate service name under which checkout is exposed.
const CheckoutServiceName = "Checkout"

// checkout is the example service that simulates ticket payments.
type checkout struct{}

// ServiceName reports the Restate service name for checkout.
func (*checkout) ServiceName() string { return CheckoutServiceName }
27 |
28 | func (c *checkout) Payment(ctx restate.Context, request PaymentRequest) (response PaymentResponse, err error) {
29 | uuid := restate.Rand(ctx).UUID().String()
30 |
31 | response.ID = uuid
32 |
33 | // We are a uniform shop where everything costs 30 USD
34 | // that is cheaper than the official example :P
35 | price := len(request.Tickets) * 30
36 |
37 | response.Price = price
38 | _, err = restate.Run(ctx, func(ctx restate.RunContext) (bool, error) {
39 | log := ctx.Log().With("uuid", uuid, "price", price)
40 | if rand.Float64() < 0.5 {
41 | log.Info("payment succeeded")
42 | return true, nil
43 | } else {
44 | log.Error("payment failed")
45 | return false, fmt.Errorf("failed to pay")
46 | }
47 | })
48 |
49 | if err != nil {
50 | return response, err
51 | }
52 |
53 | // todo: send email
54 |
55 | return response, nil
56 | }
57 |
--------------------------------------------------------------------------------
/error.go:
--------------------------------------------------------------------------------
1 | package restate
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/restatedev/sdk-go/internal/errors"
7 | )
8 |
9 | // Code is a numeric status code for an error, typically a HTTP status code.
10 | type Code = errors.Code
11 |
12 | // WithErrorCode returns an error with specific [Code] attached.
13 | func WithErrorCode(err error, code Code) error {
14 | if err == nil {
15 | return nil
16 | }
17 |
18 | return &errors.CodeError{
19 | Inner: err,
20 | Code: code,
21 | }
22 | }
23 |
24 | // TerminalError returns a terminal error with optional code. Code is optional but only one code is allowed.
25 | // By default, restate will retry the invocation or Run function forever unless a terminal error is returned
26 | func TerminalError(err error, code ...errors.Code) error {
27 | return errors.NewTerminalError(err, code...)
28 | }
29 |
30 | // TerminalErrorf is a shorthand for combining fmt.Errorf with TerminalError
31 | func TerminalErrorf(format string, a ...any) error {
32 | return TerminalError(fmt.Errorf(format, a...))
33 | }
34 |
35 | // IsTerminalError checks if err is terminal - ie, that returning it in a handler or Run function will finish
36 | // the invocation with the error as a result.
37 | func IsTerminalError(err error) bool {
38 | return errors.IsTerminalError(err)
39 | }
40 |
41 | // ErrorCode returns [Code] associated with error, defaulting to 500
42 | func ErrorCode(err error) errors.Code {
43 | return errors.ErrorCode(err)
44 | }
45 |
--------------------------------------------------------------------------------
/test-services/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "github.com/restatedev/sdk-go/internal/log"
6 | "log/slog"
7 | "os"
8 | "strings"
9 |
10 | "github.com/restatedev/sdk-go/server"
11 | )
12 |
13 | func main() {
14 | // Accommodating for verification tests here
15 | logging := strings.ToLower(os.Getenv("RESTATE_LOGGING"))
16 | if logging == "error" {
17 | slog.SetLogLoggerLevel(slog.LevelError)
18 | } else if logging == "warn" {
19 | slog.SetLogLoggerLevel(slog.LevelWarn)
20 | } else if logging == "info" {
21 | slog.SetLogLoggerLevel(slog.LevelInfo)
22 | } else if logging == "debug" {
23 | slog.SetLogLoggerLevel(slog.LevelDebug)
24 | } else if logging == "trace" {
25 | slog.SetLogLoggerLevel(log.LevelTrace)
26 | }
27 |
28 | services := "*"
29 | if os.Getenv("SERVICES") != "" {
30 | services = os.Getenv("SERVICES")
31 | }
32 |
33 | server := server.NewRestate()
34 |
35 | if services == "*" {
36 | REGISTRY.RegisterAll(server)
37 | } else {
38 | fqdns := strings.Split(services, ",")
39 | set := make(map[string]struct{}, len(fqdns))
40 | for _, fqdn := range fqdns {
41 | set[fqdn] = struct{}{}
42 | }
43 | REGISTRY.Register(set, server)
44 | }
45 |
46 | port := os.Getenv("PORT")
47 | if port == "" {
48 | port = "9080"
49 | }
50 |
51 | if err := server.Start(context.Background(), ":"+port); err != nil {
52 | slog.Error("application exited unexpectedly", "err", err)
53 | os.Exit(1)
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/test-services/listobject.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | restate "github.com/restatedev/sdk-go"
5 | )
6 |
7 | const LIST_KEY = "list"
8 |
9 | func init() {
10 | REGISTRY.AddDefinition(
11 | restate.NewObject("ListObject").
12 | Handler("append", restate.NewObjectHandler(
13 | func(ctx restate.ObjectContext, value string) (restate.Void, error) {
14 | list, err := restate.Get[[]string](ctx, LIST_KEY)
15 | if err != nil {
16 | return restate.Void{}, err
17 | }
18 | list = append(list, value)
19 | restate.Set(ctx, LIST_KEY, list)
20 | return restate.Void{}, nil
21 | })).
22 | Handler("get", restate.NewObjectHandler(
23 | func(ctx restate.ObjectContext, _ restate.Void) ([]string, error) {
24 | list, err := restate.Get[[]string](ctx, LIST_KEY)
25 | if err != nil {
26 | return nil, err
27 | }
28 | if list == nil {
29 | // or go would encode this as JSON null
30 | list = []string{}
31 | }
32 |
33 | return list, nil
34 | })).
35 | Handler("clear", restate.NewObjectHandler(
36 | func(ctx restate.ObjectContext, _ restate.Void) ([]string, error) {
37 | list, err := restate.Get[[]string](ctx, LIST_KEY)
38 | if err != nil {
39 | return nil, err
40 | }
41 | if list == nil {
42 | // or go would encode this as JSON null
43 | list = []string{}
44 | }
45 | restate.Clear(ctx, LIST_KEY)
46 | return list, nil
47 | })))
48 | }
49 |
--------------------------------------------------------------------------------
/examples/otel/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/restatedev/sdk-go/examples/otel
2 |
3 | go 1.24.0
4 |
5 | toolchain go1.24.4
6 |
7 | require (
8 | github.com/restatedev/sdk-go v0.9.1
9 | go.opentelemetry.io/otel v1.38.0
10 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0
11 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.28.0
12 | go.opentelemetry.io/otel/sdk v1.38.0
13 | )
14 |
15 | require (
16 | github.com/cenkalti/backoff/v4 v4.3.0 // indirect
17 | github.com/go-logr/logr v1.4.3 // indirect
18 | github.com/go-logr/stdr v1.2.2 // indirect
19 | github.com/golang-jwt/jwt/v5 v5.2.3 // indirect
20 | github.com/google/uuid v1.6.0 // indirect
21 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
22 | github.com/mr-tron/base58 v1.2.0 // indirect
23 | go.opentelemetry.io/otel/metric v1.38.0 // indirect
24 | go.opentelemetry.io/otel/trace v1.38.0 // indirect
25 | go.opentelemetry.io/proto/otlp v1.9.0 // indirect
26 | golang.org/x/net v0.43.0 // indirect
27 | golang.org/x/sys v0.37.0 // indirect
28 | golang.org/x/text v0.28.0 // indirect
29 | google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
30 | google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
31 | google.golang.org/grpc v1.75.1 // indirect
32 | google.golang.org/protobuf v1.36.10 // indirect
33 | )
34 |
35 | replace (
36 | github.com/restatedev/sdk-go => ../../
37 | github.com/restatedev/sdk-go/server => ../../server
38 | )
39 |
--------------------------------------------------------------------------------
/server/conn.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "errors"
5 | "io"
6 | "net/http"
7 | "sync"
8 | )
9 |
// connection adapts an HTTP exchange to a byte stream. Reads come from the
// request body. Writes go to the response and are flushed immediately when the
// ResponseWriter implements http.Flusher.
type connection struct {
	r       io.ReadCloser
	flusher http.Flusher // nil when w does not implement http.Flusher
	w       http.ResponseWriter
	cancel  func()

	wLock sync.Mutex // serializes Write
	rLock sync.Mutex // serializes Read
}

// newConnection wraps w and r's body. cancel is invoked by Close.
func newConnection(w http.ResponseWriter, r *http.Request, cancel func()) *connection {
	conn := &connection{
		r:      r.Body,
		w:      w,
		cancel: cancel,
	}
	if f, ok := w.(http.Flusher); ok {
		conn.flusher = f
	}
	return conn
}

// Write writes data to the response and flushes it, even if the write failed.
func (c *connection) Write(data []byte) (int, error) {
	c.wLock.Lock()
	defer c.wLock.Unlock()

	written, writeErr := c.w.Write(data)
	if c.flusher != nil {
		c.flusher.Flush()
	}
	return written, writeErr
}

// Read reads from the request body. The two "body is closed" conditions are
// reported as io.EOF, since to the state machine they mean the same thing.
func (c *connection) Read(data []byte) (int, error) {
	c.rLock.Lock()
	defer c.rLock.Unlock()

	n, err := c.r.Read(data)
	switch {
	case err == nil:
		return n, nil
	case errors.Is(err, http.ErrBodyReadAfterClose),
		// Returned when Close() comes while a Read is blocked; the stdlib
		// exposes no sentinel for it, hence the string match.
		err.Error() == "body closed by handler":
		return n, io.EOF
	default:
		return n, err
	}
}

// Close invokes cancel and closes the request body to unblock a pending Read.
// The body's Close error is deliberately ignored; Close always returns nil.
func (c *connection) Close() error {
	c.cancel()
	_ = c.r.Close()
	return nil
}
59 |
--------------------------------------------------------------------------------
/internal/converters/converters.go:
--------------------------------------------------------------------------------
1 | package converters
2 |
3 | import (
4 | "github.com/restatedev/sdk-go/internal/restatecontext"
5 | )
6 |
// ToInnerFuture is implemented by the typed future wrappers in this package.
// It recovers the untyped restatecontext.Selectable that each wrapper embeds.
type ToInnerFuture interface {
	InnerFuture() restatecontext.Selectable
}
10 |
11 | type ResponseFuture[O any] struct {
12 | restatecontext.ResponseFuture
13 | }
14 |
15 | func (t ResponseFuture[O]) Response() (output O, err error) {
16 | err = t.ResponseFuture.Response(&output)
17 | return
18 | }
19 |
20 | func (t ResponseFuture[O]) InnerFuture() restatecontext.Selectable {
21 | return t.ResponseFuture
22 | }
23 |
24 | type RunAsyncFuture[O any] struct {
25 | restatecontext.RunAsyncFuture
26 | }
27 |
28 | func (t RunAsyncFuture[O]) Result() (output O, err error) {
29 | err = t.RunAsyncFuture.Result(&output)
30 | return
31 | }
32 |
33 | func (t RunAsyncFuture[O]) InnerFuture() restatecontext.Selectable {
34 | return t.RunAsyncFuture
35 | }
36 |
37 | type AttachFuture[O any] struct {
38 | restatecontext.AttachFuture
39 | }
40 |
41 | func (t AttachFuture[O]) Response() (output O, err error) {
42 | err = t.AttachFuture.Response(&output)
43 | return
44 | }
45 |
46 | func (t AttachFuture[O]) InnerFuture() restatecontext.Selectable {
47 | return t.AttachFuture
48 | }
49 |
50 | type AwakeableFuture[T any] struct {
51 | restatecontext.AwakeableFuture
52 | }
53 |
54 | func (t AwakeableFuture[T]) Result() (output T, err error) {
55 | err = t.AwakeableFuture.Result(&output)
56 | return
57 | }
58 |
59 | func (t AwakeableFuture[T]) InnerFuture() restatecontext.Selectable {
60 | return t.AwakeableFuture
61 | }
62 |
--------------------------------------------------------------------------------
/.github/workflows/docker.yaml:
--------------------------------------------------------------------------------
name: Docker

on:
  push:
    branches: [main]
    tags:
      - v**

env:
  REPOSITORY_OWNER: ${{ github.repository_owner }}
  GHCR_REGISTRY: "ghcr.io"
  GHCR_REGISTRY_USERNAME: ${{ github.actor }}
  GHCR_REGISTRY_TOKEN: ${{ secrets.GITHUB_TOKEN }}

jobs:
  sdk-test-docker:
    if: github.repository_owner == 'restatedev'
    runs-on: ubuntu-latest
    name: "Create test-services Docker Image"

    steps:
      - uses: actions/checkout@v4
        with:
          repository: restatedev/sdk-go

      - name: Setup Go
        uses: actions/setup-go@v5
        with:
          go-version: "1.21.x"

      - name: Setup ko
        uses: ko-build/setup-ko@v0.6
        with:
          version: v0.16.0

      - name: Log into GitHub container registry
        uses: docker/login-action@v2
        with:
          registry: ${{ env.GHCR_REGISTRY }}
          username: ${{ env.GHCR_REGISTRY_USERNAME }}
          password: ${{ env.GHCR_REGISTRY_TOKEN }}

      - name: Install dependencies
        run: go get .

      - name: Build Docker image
        run: KO_DOCKER_REPO=restatedev ko build --platform=linux/amd64,linux/arm64 -B -L github.com/restatedev/sdk-go/test-services

      # The step name previously said "test-services-java"; this job publishes the Go image.
      - name: Push ghcr.io/restatedev/test-services-go:main image
        run: |
          docker tag restatedev/test-services ghcr.io/restatedev/test-services-go:main
          docker push ghcr.io/restatedev/test-services-go:main
53 |
--------------------------------------------------------------------------------
/examples/codegen/proto/helloworld.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | option go_package = "github.com/restatedev/sdk-go/examples/codegen/proto";
4 |
5 | import "dev/restate/sdk/go.proto";
6 |
7 | package helloworld;
8 |
// Greeter sets no service_type option, so it uses the default SERVICE type.
service Greeter {
  rpc SayHello (HelloRequest) returns (HelloResponse) {}
}

// Counter is a Virtual Object; handlers without a handler_type option default
// to EXCLUSIVE (see dev/restate/sdk/go.proto).
service Counter {
  option (dev.restate.sdk.go.service_type) = VIRTUAL_OBJECT;
  // Mutate the value
  rpc Add (AddRequest) returns (GetResponse) {}
  // Get the current value
  rpc Get (GetRequest) returns (GetResponse) {
    option (dev.restate.sdk.go.handler_type) = SHARED;
  }
}

// Workflow is a Workflow service.
// NOTE(review): no rpc sets handler_type. Per go.proto, WORKFLOW handlers default
// to SHARED and WORKFLOW_RUN marks the primary handler "typically named 'Run'";
// confirm the generator treats Run as the workflow's run handler by name.
service Workflow {
  option (dev.restate.sdk.go.service_type) = WORKFLOW;
  // Execute the workflow
  rpc Run (RunRequest) returns (RunResponse) {}
  // Unblock the workflow
  rpc Finish(FinishRequest) returns (FinishResponse) {}
  // Check the current status
  rpc Status (StatusRequest) returns (StatusResponse) {}
}
32 |
// Input of Greeter.SayHello.
message HelloRequest {
  string name = 1;
}

// Output of Greeter.SayHello.
message HelloResponse {
  string message = 1;
}

// Input of Counter.Add: the amount to add.
message AddRequest {
  int64 delta = 1;
}

// Input of Counter.Get (no fields).
message GetRequest {}

// Output of Counter.Add and Counter.Get: the counter value.
message GetResponse {
  int64 value = 1;
}

// Input of Workflow.Run (no fields).
message RunRequest {}

// Output of Workflow.Run.
message RunResponse {
  string status = 1;
}

// Input of Workflow.Status (no fields).
message StatusRequest {}

// Output of Workflow.Status.
message StatusResponse {
  string status = 1;
}

// Input of Workflow.Finish (no fields).
message FinishRequest {}

// Output of Workflow.Finish (no fields).
message FinishResponse {}

// Test is not used by any rpc here. It combines a self-referencing field,
// a primitive and another message as a field.
message Test {
  Test inner = 1;
  string primitive = 2;
  StatusResponse another_inner = 3;
}
--------------------------------------------------------------------------------
/context.go:
--------------------------------------------------------------------------------
1 | package restate
2 |
3 | import (
4 | "github.com/restatedev/sdk-go/internal/restatecontext"
5 | )
6 |
// RunContext is passed to [Run] closures and provides the limited set of Restate operations that are safe to use there.
// It is an alias of the internal restatecontext.RunContext.
type RunContext = restatecontext.RunContext

// Request contains a set of information about the request that started an invocation.
// It is an alias of the internal restatecontext.Request.
type Request = restatecontext.Request
12 |
// Context is an extension of [RunContext] which is passed to Restate service handlers and enables
// interaction with Restate.
//
// Each interface in this hierarchy includes an unexported marker method, so it
// cannot be implemented from scratch outside this package.
type Context interface {
	RunContext
	inner() restatecontext.Context
}

// ObjectSharedContext is an extension of [Context] which is passed to shared-mode Virtual Object handlers,
// giving read-only access to a snapshot of state.
type ObjectSharedContext interface {
	Context
	object()
}

// ObjectContext is an extension of [ObjectSharedContext] which is passed to exclusive-mode Virtual Object handlers,
// giving mutable access to state.
type ObjectContext interface {
	ObjectSharedContext
	exclusiveObject()
}

// WorkflowSharedContext is an extension of [ObjectSharedContext] which is passed to shared-mode Workflow handlers,
// giving read-only access to a snapshot of state.
type WorkflowSharedContext interface {
	ObjectSharedContext
	workflow()
}

// WorkflowContext is an extension of [WorkflowSharedContext] and [ObjectContext] which is passed to Workflow 'run' handlers,
// giving mutable access to state.
type WorkflowContext interface {
	WorkflowSharedContext
	ObjectContext
	runWorkflow()
}
48 |
--------------------------------------------------------------------------------
/internal/restatecontext/io_helpers.go:
--------------------------------------------------------------------------------
1 | package restatecontext
2 |
3 | import (
4 | "context"
5 | "github.com/restatedev/sdk-go/internal/log"
6 | "github.com/restatedev/sdk-go/internal/statemachine"
7 | "io"
8 | "log/slog"
9 | "sync"
10 | )
11 |
// BufPool recycles the 1024-byte read buffers handed out by readInputLoop.
// NOTE(review): pooling a []byte (not a pointer) boxes it on every Put
// (staticcheck SA6002). Changing that requires updating all Get/Put sites.
var BufPool = sync.Pool{
	New: func() any {
		return make([]byte, 1024)
	},
}
19 |
20 | func takeOutputAndWriteOut(ctx context.Context, machine *statemachine.StateMachine, conn io.WriteCloser) error {
21 | buffer, err := machine.TakeOutput(ctx)
22 | if err == io.EOF {
23 | return conn.Close()
24 | } else if err != nil {
25 | return err
26 | }
27 | _, err = conn.Write(buffer)
28 | return err
29 | }
30 |
31 | func consumeOutput(ctx context.Context, machine *statemachine.StateMachine, conn io.WriteCloser) error {
32 | for {
33 | buffer, err := machine.TakeOutput(ctx)
34 | if err == io.EOF {
35 | return conn.Close()
36 | } else if err != nil {
37 | return err
38 | }
39 |
40 | _, err = conn.Write(buffer)
41 | if err != nil {
42 | return err
43 | }
44 | }
45 | }
46 |
// readResult is one chunk read from the connection by readInputLoop.
// buf was taken from BufPool and only its first nRead bytes are valid.
type readResult struct {
	nRead int
	buf   []byte
}
51 |
52 | func (restateCtx *ctx) readInputLoop(logger *slog.Logger) {
53 | for {
54 | // Acquire buf
55 | tempBuf := BufPool.Get().([]byte)
56 | read, err := restateCtx.conn.Read(tempBuf)
57 | if err != nil {
58 | BufPool.Put(tempBuf)
59 | if err != io.EOF {
60 | logger.WarnContext(restateCtx, "Unexpected when reading input", log.Error(err))
61 | }
62 | close(restateCtx.readChan)
63 | return
64 | }
65 | if read != 0 {
66 | restateCtx.readChan <- readResult{
67 | nRead: read,
68 | buf: tempBuf,
69 | }
70 | }
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/test-services/counter.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | restate "github.com/restatedev/sdk-go"
5 | )
6 |
// COUNTER_KEY is the state key under which the Counter object stores its value.
const COUNTER_KEY = "counter"

// CounterUpdateResponse is returned by the Counter "add" handler with the
// counter value before and after the addition.
type CounterUpdateResponse struct {
	OldValue int64 `json:"oldValue"`
	NewValue int64 `json:"newValue"`
}
13 |
14 | func init() {
15 | REGISTRY.AddDefinition(
16 | restate.NewObject("Counter").
17 | Handler("reset", restate.NewObjectHandler(
18 | func(ctx restate.ObjectContext, _ restate.Void) (restate.Void, error) {
19 | restate.Clear(ctx, COUNTER_KEY)
20 | return restate.Void{}, nil
21 | })).
22 | Handler("get", restate.NewObjectSharedHandler(
23 | func(ctx restate.ObjectSharedContext, _ restate.Void) (int64, error) {
24 | return restate.Get[int64](ctx, COUNTER_KEY)
25 | })).
26 | Handler("add", restate.NewObjectHandler(
27 | func(ctx restate.ObjectContext, addend int64) (CounterUpdateResponse, error) {
28 | oldValue, err := restate.Get[int64](ctx, COUNTER_KEY)
29 | if err != nil {
30 | return CounterUpdateResponse{}, err
31 | }
32 |
33 | newValue := oldValue + addend
34 | restate.Set(ctx, COUNTER_KEY, newValue)
35 |
36 | return CounterUpdateResponse{
37 | OldValue: oldValue,
38 | NewValue: newValue,
39 | }, nil
40 | })).
41 | Handler("addThenFail", restate.NewObjectHandler(
42 | func(ctx restate.ObjectContext, addend int64) (restate.Void, error) {
43 | oldValue, err := restate.Get[int64](ctx, COUNTER_KEY)
44 | if err != nil {
45 | return restate.Void{}, err
46 | }
47 |
48 | newValue := oldValue + addend
49 | restate.Set(ctx, COUNTER_KEY, newValue)
50 |
51 | return restate.Void{}, restate.TerminalErrorf("%s", restate.Key(ctx))
52 | })))
53 | }
54 |
--------------------------------------------------------------------------------
/internal/restatecontext/sleep.go:
--------------------------------------------------------------------------------
1 | package restatecontext
2 |
3 | import (
4 | "fmt"
5 | "github.com/restatedev/sdk-go/internal/options"
6 | "github.com/restatedev/sdk-go/internal/statemachine"
7 | "time"
8 | )
9 |
10 | func (restateCtx *ctx) Sleep(d time.Duration, opts ...options.SleepOption) error {
11 | return restateCtx.After(d, opts...).Done()
12 | }
13 |
// AfterFuture is a handle on a Sleep operation which allows you to do other work concurrently
// with the sleep.
type AfterFuture interface {
	// Done blocks waiting on the remaining duration of the sleep.
	// It is *not* safe to call this in a goroutine - use Context.Select if you want to wait on multiple
	// results at once. Can return a terminal error in the case where the invocation was cancelled mid-sleep,
	// hence Done() should always be called, even after using Context.Select.
	Done() error
	Selectable
}
24 |
25 | func (restateCtx *ctx) After(d time.Duration, opts ...options.SleepOption) AfterFuture {
26 | o := options.SleepOptions{}
27 | for _, opt := range opts {
28 | opt.BeforeSleep(&o)
29 | }
30 |
31 | handle, err := restateCtx.stateMachine.SysSleep(restateCtx, o.Name, d)
32 | if err != nil {
33 | panic(err)
34 | }
35 | restateCtx.checkStateTransition()
36 |
37 | return &afterFuture{
38 | asyncResult: newAsyncResult(restateCtx, handle),
39 | }
40 | }
41 |
42 | type afterFuture struct {
43 | asyncResult
44 | }
45 |
46 | func (a *afterFuture) Done() error {
47 | switch result := a.pollProgressAndLoadValue().(type) {
48 | case statemachine.ValueVoid:
49 | return nil
50 | case statemachine.ValueFailure:
51 | return errorFromFailure(result)
52 | default:
53 | panic(fmt.Errorf("unexpected value %s", result))
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/examples/parallelizework/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log"
7 | "log/slog"
8 | "math/rand"
9 | "os"
10 | "strings"
11 |
12 | restate "github.com/restatedev/sdk-go"
13 | "github.com/restatedev/sdk-go/server"
14 | )
15 |
// FanOutWorkerServiceName is the Restate service name of fanOutWorker.
const FanOutWorkerServiceName = "FanOutWorker"

// fanOutWorker runs a comma-separated list of tasks in parallel and joins the results.
type fanOutWorker struct{}

// ServiceName reports the service name used when binding fanOutWorker.
func (*fanOutWorker) ServiceName() string { return FanOutWorkerServiceName }
23 |
24 | func (c *fanOutWorker) Run(ctx restate.Context, commaSeparatedTasks string) (aggregatedResults string, err error) {
25 | tasks := strings.Split(commaSeparatedTasks, ",")
26 |
27 | // Run tasks in parallel
28 | var futs []restate.Selectable
29 | for _, task := range tasks {
30 | futs = append(futs, restate.RunAsync[string](ctx, func(ctx restate.RunContext) (string, error) {
31 | log.Printf("Heavy task %s running", task)
32 | if rand.Intn(2) == 1 {
33 | log.Printf("Heavy task %s failed", task)
34 | panic(fmt.Errorf("failed to complete heavy task %s", task))
35 | }
36 | log.Printf("Heavy task %s done", task)
37 | return task, nil
38 | }))
39 | }
40 |
41 | // Aggregate
42 | var results []string
43 | for fu, err := range restate.Wait(ctx, futs...) {
44 | if err != nil {
45 | return "", err
46 | }
47 | result, err := fu.(restate.RunAsyncFuture[string]).Result()
48 | if err != nil {
49 | return "", err
50 | }
51 | results = append(results, result)
52 | }
53 |
54 | return strings.Join(results, "-"), nil
55 | }
56 |
57 | func main() {
58 | slog.SetLogLoggerLevel(slog.LevelDebug)
59 | server := server.NewRestate().Bind(restate.Reflect(&fanOutWorker{}))
60 |
61 | if err := server.Start(context.Background(), ":9080"); err != nil {
62 | slog.Error("application exited unexpectedly", "err", err.Error())
63 | os.Exit(1)
64 | }
65 | }
66 |
--------------------------------------------------------------------------------
/examples/otel/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log"
7 |
8 | restate "github.com/restatedev/sdk-go"
9 | "github.com/restatedev/sdk-go/server"
10 | "go.opentelemetry.io/otel"
11 | "go.opentelemetry.io/otel/attribute"
12 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace"
13 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
14 | "go.opentelemetry.io/otel/propagation"
15 | "go.opentelemetry.io/otel/sdk/resource"
16 | sdktrace "go.opentelemetry.io/otel/sdk/trace"
17 | )
18 |
19 | type Greeter struct{}
20 |
21 | func (Greeter) Greet(ctx restate.Context, message string) (string, error) {
22 | _, span := otel.Tracer("").Start(ctx, "Greet")
23 | defer span.End()
24 |
25 | return fmt.Sprintf("%s!", message), nil
26 | }
27 |
28 | func main() {
29 | exporter, err := otlptrace.New(
30 | context.Background(),
31 | otlptracegrpc.NewClient(
32 | otlptracegrpc.WithInsecure(),
33 | otlptracegrpc.WithEndpoint("localhost:4317"),
34 | ),
35 | )
36 |
37 | if err != nil {
38 | log.Fatalf("Could not set exporter: %v", err)
39 | }
40 |
41 | resources, err := resource.New(
42 | context.Background(),
43 | resource.WithAttributes(
44 | attribute.String("service.name", "restate-sdk-go-otel-example-greeter"),
45 | ),
46 | )
47 | if err != nil {
48 | log.Fatalf("Could not set resources: %v", err)
49 | }
50 |
51 | otel.SetTracerProvider(
52 | sdktrace.NewTracerProvider(
53 | sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.AlwaysSample())),
54 | sdktrace.WithSpanProcessor(sdktrace.NewBatchSpanProcessor(exporter)),
55 | sdktrace.WithResource(resources),
56 | ),
57 | )
58 |
59 | otel.SetTextMapPropagator(propagation.TraceContext{})
60 |
61 | if err := server.NewRestate().
62 | Bind(restate.Reflect(Greeter{})).
63 | Start(context.Background(), ":9080"); err != nil {
64 | log.Fatal(err)
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/examples/ticketreservation/ticket_service.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | restate "github.com/restatedev/sdk-go"
5 | )
6 |
// TicketStatus is the lifecycle state of a ticket, stored under the "status" key.
type TicketStatus int

// Ticket lifecycle states, in order. TicketAvailable is the zero value.
const (
	TicketAvailable TicketStatus = iota // 0
	TicketReserved                      // 1
	TicketSold                          // 2
)

// TicketServiceName is the name ticketService reports via ServiceName.
const TicketServiceName = "TicketService"

// ticketService is a Restate virtual object whose key is the ticket ID.
type ticketService struct{}

// ServiceName returns TicketServiceName.
func (*ticketService) ServiceName() string {
	return TicketServiceName
}
20 |
21 | func (t *ticketService) Reserve(ctx restate.ObjectContext, _ restate.Void) (bool, error) {
22 | status, err := restate.Get[TicketStatus](ctx, "status")
23 | if err != nil {
24 | return false, err
25 | }
26 |
27 | if status == TicketAvailable {
28 | restate.Set(ctx, "status", TicketReserved)
29 | return true, nil
30 | }
31 |
32 | return false, nil
33 | }
34 |
35 | func (t *ticketService) Unreserve(ctx restate.ObjectContext, _ restate.Void) (void restate.Void, err error) {
36 | ticketId := restate.Key(ctx)
37 | ctx.Log().Info("un-reserving ticket", "ticket", ticketId)
38 | status, err := restate.Get[TicketStatus](ctx, "status")
39 | if err != nil {
40 | return void, err
41 | }
42 |
43 | if status != TicketSold {
44 | restate.Clear(ctx, "status")
45 | return void, nil
46 | }
47 |
48 | return void, nil
49 | }
50 |
51 | func (t *ticketService) MarkAsSold(ctx restate.ObjectContext, _ restate.Void) (void restate.Void, err error) {
52 | ticketId := restate.Key(ctx)
53 | ctx.Log().Info("mark ticket as sold", "ticket", ticketId)
54 |
55 | status, err := restate.Get[TicketStatus](ctx, "status")
56 | if err != nil {
57 | return void, err
58 | }
59 |
60 | if status == TicketReserved {
61 | restate.Set(ctx, "status", TicketSold)
62 | return void, nil
63 | }
64 |
65 | return void, nil
66 | }
67 |
68 | func (t *ticketService) Status(ctx restate.ObjectSharedContext, _ restate.Void) (TicketStatus, error) {
69 | ticketId := restate.Key(ctx)
70 | ctx.Log().Info("mark ticket as sold", "ticket", ticketId)
71 |
72 | return restate.Get[TicketStatus](ctx, "status")
73 | }
74 |
--------------------------------------------------------------------------------
/protoc-gen-go-restate/README.md:
--------------------------------------------------------------------------------
1 | # protoc-gen-go-restate
2 |
3 | This tool generates Go language bindings of `service`s in protobuf definition
4 | files for Restate.
5 |
6 | An example of their use can be found in [examples/codegen](../examples/codegen)
7 |
8 | ## Usage
9 | Via protoc:
10 | ```shell
11 | go install github.com/restatedev/sdk-go/protoc-gen-go-restate@latest
12 | protoc --go_out=. --go_opt=paths=source_relative \
13 | --go-restate_out=. --go-restate_opt=paths=source_relative service.proto
14 | ```
15 |
16 | Via [buf](https://buf.build/):
17 | ```yaml
18 | # buf.gen.yaml
19 | plugins:
20 | - remote: buf.build/protocolbuffers/go:v1.34.2
21 | out: .
22 | opt: paths=source_relative
23 | - local: protoc-gen-go-restate
24 | out: .
25 | opt: paths=source_relative
26 | ```
27 |
28 | ## Providing options
29 | This protoc plugin supports the service and method extensions defined in
30 | [proto/dev/restate/sdk/go.proto](../proto/dev/restate/sdk/go.proto).
31 | You will need to use these extensions to define virtual objects in proto.
32 |
33 | You can import the extensions with the statement `import "dev/restate/sdk/go.proto";`. Protoc will expect an equivalent directory
34 | structure containing the go.proto file either locally, or under any of the
35 | paths provided with `--proto_path`. It may be easier to use
36 | [buf](https://buf.build/docs/bsr/module/dependency-management) to import:
37 | ```yaml
38 | # buf.yaml
39 | version: v2
40 | deps:
41 | - buf.build/restatedev/sdk-go
42 | ```
43 |
44 | ## Upgrading from pre-v0.14
45 | This generator used to create Restate services and methods using the Go names (e.g. `Greeter/SayHello`) instead of the fully qualified protobuf names (e.g. `helloworld.Greeter/SayHello`).
46 | This was changed to make this package more compatible with gRPC.
47 | To maintain the old behaviour, pass `--go-restate_opt=use_go_service_names=true` to `protoc`. With buf:
48 | ```yaml
49 | ...
50 | - local: protoc-gen-go-restate
51 | out: .
52 | opt:
53 | - paths=source_relative
54 | - use_go_service_names=true
55 | ```
56 |
--------------------------------------------------------------------------------
/protoc-gen-go-restate/main.go:
--------------------------------------------------------------------------------
1 | // protoc-gen-go-restate is a plugin for the Google protocol buffer compiler to
2 | // generate Restate servers and clients. Install it by building this program and
3 | // making it accessible within your PATH with the name:
4 | //
5 | // protoc-gen-go-restate
6 | //
7 | // The 'go-restate' suffix becomes part of the argument for the protocol compiler,
8 | // such that it can be invoked as:
9 | //
10 | // protoc --go-restate_out=. path/to/file.proto
11 | //
12 | // This generates Restate service definitions for the protocol buffer defined by
13 | // file.proto. With that input, the output will be written to:
14 | //
15 | // path/to/file_restate.pb.go
16 | //
17 | // Lots of code copied from protoc-gen-go-grpc:
18 | // https://github.com/grpc/grpc-go/tree/master/cmd/protoc-gen-go-grpc
19 | // ! License Apache-2.0
20 | package main
21 |
22 | import (
23 | "flag"
24 | "fmt"
25 |
26 | "google.golang.org/protobuf/compiler/protogen"
27 | "google.golang.org/protobuf/types/pluginpb"
28 | )
29 |
// version is printed when the plugin is run with -version.
var version = "0.1"

// requireUnimplemented and useGoServiceNames point at the plugin parameters
// registered in main and set from the protoc options (--go-restate_opt).
// They are presumably read by the code generator (generateFile); confirm there.
var requireUnimplemented *bool
var useGoServiceNames *bool
34 |
35 | func main() {
36 | showVersion := flag.Bool("version", false, "print the version and exit")
37 | flag.Parse()
38 | if *showVersion {
39 | fmt.Printf("protoc-gen-go-grpc %v\n", version)
40 | return
41 | }
42 |
43 | var flags flag.FlagSet
44 | requireUnimplemented = flags.Bool("require_unimplemented_servers", false, "set to true to disallow servers that have unimplemented fields")
45 | useGoServiceNames = flags.Bool("use_go_service_names", false, "set to true to use Go names for service and method names instead of the Protobuf fully qualified names. This used to be the default behaviour")
46 |
47 | protogen.Options{
48 | ParamFunc: flags.Set,
49 | }.Run(func(gen *protogen.Plugin) error {
50 | gen.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)
51 | for _, f := range gen.Files {
52 | if !f.Generate {
53 | continue
54 | }
55 | generateFile(gen, f)
56 | }
57 | return nil
58 | })
59 | }
60 |
--------------------------------------------------------------------------------
/internal/restatecontext/select.go:
--------------------------------------------------------------------------------
1 | package restatecontext
2 |
// Selector is an iterator over a list of blocking Restate operations that are running
// in the background.
//
// The implementation returned by ctx.Select panics if the invocation is
// cancelled while selecting; use Wait/WaitFirst/WaitIter for cancellation-aware waiting.
type Selector interface {
	// Remaining returns whether there are still operations that haven't been returned by Select().
	// There will always be exactly the same number of results as there were operations
	// given to Context.Select (operations sharing a handle are collapsed into one entry).
	Remaining() bool
	// Select blocks on the next completed operation or returns nil if there are none left
	Select() Selectable
}
13 |
// selector implements Selector on top of the invocation's state machine.
type selector struct {
	restateCtx *ctx
	// indexedFuts holds the futures not yet returned by Select, keyed by their
	// state machine handle; Select deletes each entry as it returns it.
	indexedFuts map[uint32]Selectable
}
18 |
19 | func (restateCtx *ctx) Select(futs ...Selectable) Selector {
20 | indexedFuts := make(map[uint32]Selectable, len(futs))
21 | for i := range futs {
22 | handle := futs[i].handle()
23 | indexedFuts[handle] = futs[i]
24 | }
25 |
26 | return &selector{
27 | restateCtx: restateCtx,
28 | indexedFuts: indexedFuts,
29 | }
30 | }
31 |
32 | func (s *selector) Select() Selectable {
33 | if !s.Remaining() {
34 | return nil
35 | }
36 |
37 | remainingHandles := make([]uint32, len(s.indexedFuts))
38 | for k := range s.indexedFuts {
39 | remainingHandles = append(remainingHandles, k)
40 | }
41 |
42 | // Do progress
43 | cancelled := s.restateCtx.pollProgress(remainingHandles)
44 | if cancelled {
45 | panic("cancellation is not supported by the Selector API, please use Wait/WaitFirst/WaitIter instead")
46 | }
47 |
48 | // If we exit, one of them is completed, gotta figure out which one
49 | for _, handle := range remainingHandles {
50 | completed, err := s.restateCtx.stateMachine.IsCompleted(s.restateCtx, handle)
51 | if err != nil {
52 | panic(err)
53 | }
54 | if completed {
55 | fut := s.indexedFuts[handle]
56 | delete(s.indexedFuts, handle)
57 | return fut
58 | }
59 | }
60 |
61 | panic("Unexpectedly none of the remaining handles completed, this looks like a bug")
62 | }
63 |
64 | func (s *selector) Remaining() bool {
65 | return len(s.indexedFuts) != 0
66 | }
67 |
--------------------------------------------------------------------------------
/mocks/mock_Invocation.go:
--------------------------------------------------------------------------------
1 | // Code generated by mockery v2.52.1. DO NOT EDIT.
2 |
3 | package mocks
4 |
5 | import mock "github.com/stretchr/testify/mock"
6 |
// MockInvocation is an autogenerated mock type for the Invocation type
type MockInvocation struct {
	mock.Mock
}

// MockInvocation_Expecter offers typed helpers for registering expected calls
// on a MockInvocation; obtain one via EXPECT.
type MockInvocation_Expecter struct {
	mock *mock.Mock
}

// EXPECT returns an expecter bound to this mock's embedded mock.Mock.
func (_m *MockInvocation) EXPECT() *MockInvocation_Expecter {
	return &MockInvocation_Expecter{mock: &_m.Mock}
}

// GetInvocationId provides a mock function with no fields
//
// The configured return value may be a string, or a func() string that is
// called to compute the result. It panics if no return value was configured.
func (_m *MockInvocation) GetInvocationId() string {
	ret := _m.Called()

	if len(ret) == 0 {
		panic("no return value specified for GetInvocationId")
	}

	var r0 string
	if rf, ok := ret.Get(0).(func() string); ok {
		r0 = rf()
	} else {
		r0 = ret.Get(0).(string)
	}

	return r0
}

// MockInvocation_GetInvocationId_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetInvocationId'
type MockInvocation_GetInvocationId_Call struct {
	*mock.Call
}

// GetInvocationId is a helper method to define mock.On call
func (_e *MockInvocation_Expecter) GetInvocationId() *MockInvocation_GetInvocationId_Call {
	return &MockInvocation_GetInvocationId_Call{Call: _e.mock.On("GetInvocationId")}
}

// Run registers run as a callback for this expected call; the mock.Arguments
// are discarded because the method takes no parameters.
func (_c *MockInvocation_GetInvocationId_Call) Run(run func()) *MockInvocation_GetInvocationId_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run()
	})
	return _c
}

// Return sets the string that GetInvocationId returns for this expected call.
func (_c *MockInvocation_GetInvocationId_Call) Return(_a0 string) *MockInvocation_GetInvocationId_Call {
	_c.Call.Return(_a0)
	return _c
}

// RunAndReturn stores run itself as the return value; GetInvocationId detects
// the func() string and calls it to produce each result.
func (_c *MockInvocation_GetInvocationId_Call) RunAndReturn(run func() string) *MockInvocation_GetInvocationId_Call {
	_c.Call.Return(run)
	return _c
}

// NewMockInvocation creates a new instance of MockInvocation. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockInvocation(t interface {
	mock.TestingT
	Cleanup(func())
}) *MockInvocation {
	mock := &MockInvocation{}
	mock.Mock.Test(t)

	// Verify all registered expectations were met when the test finishes.
	t.Cleanup(func() { mock.AssertExpectations(t) })

	return mock
}
78 |
--------------------------------------------------------------------------------
/internal/log/log.go:
--------------------------------------------------------------------------------
1 | package log
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log/slog"
7 | "reflect"
8 | "sync/atomic"
9 |
10 | "github.com/restatedev/sdk-go/rcontext"
11 | )
12 |
const (
	// LevelTrace is a custom slog level one step (4) below slog.LevelDebug (-4),
	// for output more verbose than debug.
	LevelTrace slog.Level = -8
)
16 |
// typeValue renders the dynamic Go type of inner lazily, only when a log
// record is actually handled.
type typeValue struct{ inner any }

// LogValue implements slog.LogValuer. A nil inner has no dynamic type
// (reflect.TypeOf returns nil), so it is rendered as "<nil>" instead of
// panicking on the nil reflect.Type.
func (t typeValue) LogValue() slog.Value {
	typ := reflect.TypeOf(t.inner)
	if typ == nil {
		return slog.StringValue("<nil>")
	}
	return slog.StringValue(typ.String())
}

// Type returns an attribute whose value is the dynamic type name of value
// (e.g. "int", "*pkg.Foo"), computed only when the record is handled.
func Type(key string, value any) slog.Attr {
	return slog.Any(key, typeValue{value})
}
26 |
// stringerValue defers calling String on inner until a log record is handled.
type stringerValue[T fmt.Stringer] struct{ inner T }

// LogValue implements slog.LogValuer by rendering inner through its String method.
func (s stringerValue[T]) LogValue() slog.Value {
	text := s.inner.String()
	return slog.StringValue(text)
}

// Stringer returns an attribute whose value is value.String(), evaluated lazily.
func Stringer[T fmt.Stringer](key string, value T) slog.Attr {
	lazy := stringerValue[T]{inner: value}
	return slog.Any(key, lazy)
}
36 |
// Error returns an "err" attribute holding err's message. A nil err is
// logged as "<nil>" instead of panicking on the nil interface.
func Error(err error) slog.Attr {
	if err == nil {
		return slog.String("err", "<nil>")
	}
	return slog.String("err", err.Error())
}
40 |
41 | type contextInjectingHandler struct {
42 | logContext *atomic.Pointer[rcontext.LogContext]
43 | dropReplay bool
44 | inner slog.Handler
45 | }
46 |
47 | func NewUserContextHandler(logContext *atomic.Pointer[rcontext.LogContext], dropReplay bool, inner slog.Handler) slog.Handler {
48 | return &contextInjectingHandler{logContext, dropReplay, inner}
49 | }
50 |
51 | func NewRestateContextHandler(inner slog.Handler) slog.Handler {
52 | logContext := atomic.Pointer[rcontext.LogContext]{}
53 | logContext.Store(&rcontext.LogContext{Source: rcontext.LogSourceRestate, IsReplaying: false})
54 | return &contextInjectingHandler{&logContext, false, inner}
55 | }
56 |
57 | func (d *contextInjectingHandler) Enabled(ctx context.Context, l slog.Level) bool {
58 | lc := d.logContext.Load()
59 | if d.dropReplay && lc.IsReplaying {
60 | return false
61 | }
62 | return d.inner.Enabled(rcontext.WithLogContext(ctx, lc), l)
63 | }
64 |
65 | func (d *contextInjectingHandler) Handle(ctx context.Context, record slog.Record) error {
66 | return d.inner.Handle(rcontext.WithLogContext(ctx, d.logContext.Load()), record)
67 | }
68 |
69 | func (d *contextInjectingHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
70 | return &contextInjectingHandler{d.logContext, d.dropReplay, d.inner.WithAttrs(attrs)}
71 | }
72 |
73 | func (d *contextInjectingHandler) WithGroup(name string) slog.Handler {
74 | return &contextInjectingHandler{d.logContext, d.dropReplay, d.inner.WithGroup(name)}
75 | }
76 |
77 | var _ slog.Handler = &contextInjectingHandler{}
78 |
--------------------------------------------------------------------------------
/test-services/canceltest.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "time"
6 |
7 | restate "github.com/restatedev/sdk-go"
8 | )
9 |
// BlockingOperation selects what CancelTestBlockingService.block waits on
// once its awakeable has been resolved.
type BlockingOperation string

const (
	// CanceledState is the state key where CancelTestRunner records that its
	// call to CancelTestBlockingService failed with error code 409.
	CanceledState = "canceled"

	// Supported BlockingOperation values.
	CallOp      BlockingOperation = "CALL"
	SleepOp     BlockingOperation = "SLEEP"
	AwakeableOp BlockingOperation = "AWAKEABLE"
)
19 |
// init registers the two virtual objects used by the cancellation test:
// CancelTestRunner drives the test, and CancelTestBlockingService blocks on
// the requested operation until it is cancelled.
func init() {
	REGISTRY.AddDefinition(
		restate.NewObject("CancelTestRunner").
			// startTest calls block and records in CanceledState whether the call
			// failed with code 409; any other error is returned unchanged.
			Handler("startTest", restate.NewObjectHandler(
				func(ctx restate.ObjectContext, operation BlockingOperation) (restate.Void, error) {
					if _, err := restate.Object[restate.Void](ctx, "CancelTestBlockingService", restate.Key(ctx), "block").Request(operation); err != nil {
						if restate.ErrorCode(err) == 409 {
							restate.Set(ctx, CanceledState, true)
							return restate.Void{}, nil
						}
						return restate.Void{}, err
					}
					return restate.Void{}, nil
				})).
			// verifyTest returns the flag stored by startTest.
			Handler("verifyTest", restate.NewObjectHandler(
				func(ctx restate.ObjectContext, _ restate.Void) (bool, error) {
					return restate.Get[bool](ctx, CanceledState)
				})))
	REGISTRY.AddDefinition(
		restate.NewObject("CancelTestBlockingService").
			// block hands an awakeable to AwakeableHolder, waits for it to be
			// resolved, then blocks on the requested operation.
			Handler("block", restate.NewObjectHandler(
				func(ctx restate.ObjectContext, operation BlockingOperation) (restate.Void, error) {
					awakeable := restate.Awakeable[restate.Void](ctx)
					if _, err := restate.Object[restate.Void](ctx, "AwakeableHolder", restate.Key(ctx), "hold").Request(awakeable.Id()); err != nil {
						return restate.Void{}, err
					}
					if _, err := awakeable.Result(); err != nil {
						return restate.Void{}, err
					}
					switch operation {
					case CallOp:
						// Calls block again on the same key; presumably this waits
						// until cancelled, since this invocation holds the key.
						return restate.Object[restate.Void](ctx, "CancelTestBlockingService", restate.Key(ctx), "block").Request(operation)
					case SleepOp:
						// 1024 days: effectively forever for the test's purposes.
						return restate.Void{}, restate.Sleep(ctx, 1024*time.Hour*24)
					case AwakeableOp:
						// A fresh awakeable whose ID is never shared, so nobody resolves it.
						return restate.Awakeable[restate.Void](ctx).Result()
					default:
						return restate.Void{}, restate.TerminalError(fmt.Errorf("unexpected operation %s", operation), 400)
					}
				})).
			Handler("isUnlocked", restate.NewObjectHandler(
				func(ctx restate.ObjectContext, _ restate.Void) (restate.Void, error) {
					// no-op
					return restate.Void{}, nil
				})))
}
66 |
--------------------------------------------------------------------------------
/internal/identity/v1.go:
--------------------------------------------------------------------------------
1 | package identity
2 |
3 | import (
4 | "crypto/ed25519"
5 | "fmt"
6 | "strings"
7 |
8 | jwt "github.com/golang-jwt/jwt/v5"
9 | "github.com/mr-tron/base58"
10 | )
11 |
const (
	// JWT_HEADER is the request header that must carry exactly one v1 identity JWT.
	JWT_HEADER = "X-Restate-Jwt-V1"
	// SchemeV1 identifies the v1 request identity scheme (Ed25519-signed JWT).
	SchemeV1 SignatureScheme = "v1"
)

// KeySetV1 maps a full public key string, as given to ParseKeySetV1 and
// including its "publickeyv1_" prefix, to the decoded Ed25519 key. validateV1
// looks keys up by the JWT's "kid" header.
type KeySetV1 = map[string]ed25519.PublicKey
18 |
19 | func validateV1(keySet KeySetV1, path string, headers map[string][]string) error {
20 | switch len(headers[JWT_HEADER]) {
21 | case 0:
22 | return fmt.Errorf("v1 signature scheme expects the following headers: [%s]", JWT_HEADER)
23 | case 1:
24 | default:
25 | return fmt.Errorf("unexpected multi-value JWT header: %v", headers[JWT_HEADER])
26 | }
27 |
28 | token, err := jwt.Parse(headers[JWT_HEADER][0], func(token *jwt.Token) (interface{}, error) {
29 | if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
30 | return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
31 | }
32 |
33 | kid, ok := token.Header["kid"]
34 | if !ok {
35 | return nil, fmt.Errorf("Token missing 'kid' header field")
36 | }
37 |
38 | kidS, ok := kid.(string)
39 | if !ok {
40 | return nil, fmt.Errorf("Token 'kid' header field was not a string: %v", kid)
41 | }
42 |
43 | key, ok := keySet[kidS]
44 | if !ok {
45 | return nil, fmt.Errorf("Key ID %s is not present in key set", kid)
46 | }
47 |
48 | return key, nil
49 | }, jwt.WithValidMethods([]string{"EdDSA"}), jwt.WithAudience(path), jwt.WithExpirationRequired())
50 | if err != nil {
51 | return fmt.Errorf("failed to validate v1 request identity jwt: %w", err)
52 | }
53 |
54 | nbf, _ := token.Claims.GetNotBefore()
55 | if nbf == nil {
56 | // jwt library only validates nbf if its present, so we should check it was present
57 | return fmt.Errorf("'nbf' claim is missing in v1 request identity jwt")
58 | }
59 |
60 | return nil
61 | }
62 |
63 | func ParseKeySetV1(keys []string) (KeySetV1, error) {
64 | out := make(KeySetV1, len(keys))
65 | for _, key := range keys {
66 | if !strings.HasPrefix(key, "publickeyv1_") {
67 | return nil, fmt.Errorf("v1 public key must start with 'publickeyv1_'")
68 | }
69 |
70 | pubBytes, err := base58.Decode(key[len("publickeyv1_"):])
71 | if err != nil {
72 | return nil, fmt.Errorf("v1 public key must be valid base58: %w", err)
73 | }
74 |
75 | if len(pubBytes) != ed25519.PublicKeySize {
76 | return nil, fmt.Errorf("v1 public key must have exactly %d bytes, found %d", ed25519.PublicKeySize, len(pubBytes))
77 | }
78 |
79 | out[key] = ed25519.PublicKey(pubBytes)
80 | }
81 |
82 | return out, nil
83 | }
84 |
--------------------------------------------------------------------------------
/test-services/testutils.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "os"
5 | "strings"
6 | "sync/atomic"
7 | "time"
8 |
9 | restate "github.com/restatedev/sdk-go"
10 | )
11 |
12 | func init() {
13 | REGISTRY.AddDefinition(
14 | restate.NewService("TestUtilsService").
15 | Handler("echo", restate.NewServiceHandler(
16 | func(ctx restate.Context, input string) (string, error) {
17 | return input, nil
18 | })).
19 | Handler("uppercaseEcho", restate.NewServiceHandler(
20 | func(ctx restate.Context, input string) (string, error) {
21 | return strings.ToUpper(input), nil
22 | })).
23 | Handler("echoHeaders", restate.NewServiceHandler(
24 | func(ctx restate.Context, _ restate.Void) (map[string]string, error) {
25 | return ctx.Request().Headers, nil
26 | })).
27 | Handler("rawEcho", restate.NewServiceHandler(
28 | func(ctx restate.Context, input []byte) ([]byte, error) {
29 | return input, nil
30 | }, restate.WithBinary)).
31 | Handler("sleepConcurrently", restate.NewServiceHandler(
32 | func(ctx restate.Context, millisDuration []int64) (restate.Void, error) {
33 | timers := make([]restate.Future, 0, len(millisDuration))
34 | for _, d := range millisDuration {
35 | timers = append(timers, restate.After(ctx, time.Duration(d)*time.Millisecond))
36 | }
37 | i := 0
38 | for _, err := range restate.Wait(ctx, timers...) {
39 | if err != nil {
40 | return restate.Void{}, err
41 | }
42 | i++
43 | }
44 | if i != len(timers) {
45 | return restate.Void{}, restate.TerminalErrorf("unexpected number of timers fired: %d", i)
46 | }
47 | return restate.Void{}, nil
48 | })).
49 | Handler("countExecutedSideEffects", restate.NewServiceHandler(
50 | func(ctx restate.Context, increments int32) (int32, error) {
51 | invokedSideEffects := atomic.Int32{}
52 | for i := int32(0); i < increments; i++ {
53 | restate.Run(ctx, func(ctx restate.RunContext) (int32, error) {
54 | return invokedSideEffects.Add(1), nil
55 | })
56 | }
57 | return invokedSideEffects.Load(), nil
58 | })).
59 | Handler("getEnvVariable", restate.NewServiceHandler(getEnvVariable)).
60 | Handler("cancelInvocation", restate.NewServiceHandler(
61 | func(ctx restate.Context, invocationId string) (restate.Void, error) {
62 | restate.CancelInvocation(ctx, invocationId)
63 | return restate.Void{}, nil
64 | })),
65 | )
66 | }
67 |
68 | func getEnvVariable(ctx restate.Context, envName string) (string, error) {
69 | return restate.Run(ctx, func(ctx restate.RunContext) (string, error) {
70 | return os.Getenv(envName), nil
71 | })
72 | }
73 |
--------------------------------------------------------------------------------
/internal/restatecontext/wait_iterator.go:
--------------------------------------------------------------------------------
1 | package restatecontext
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/restatedev/sdk-go/internal/errors"
7 | )
8 |
// WaitIterator lets you consume a set of Restate operations one at a time, in
// the order in which they complete.
type WaitIterator interface {
	// Next blocks until another operation completes and reports whether one was produced.
	// If returns false, no more operations will be completed. After returning false, Err() should be checked.
	Next() bool

	// Err returns an error if the waiter was canceled using Restate's cancellation feature
	// (a terminal error with code 409), and nil otherwise.
	Err() error

	// Value returns the operation produced by the most recent call to Next.
	// Call it only after Next returned true; it returns nil after cancellation.
	// Panics if called before the first Next
	Value() Selectable
}
22 |
23 | func (restateCtx *ctx) WaitIter(futs ...Selectable) WaitIterator {
24 | indexedFuts := make(map[uint32]Selectable, len(futs))
25 | for i := range futs {
26 | handle := futs[i].handle()
27 | indexedFuts[handle] = futs[i]
28 | }
29 |
30 | return &waitIterator{
31 | restateCtx: restateCtx,
32 | indexedFuts: indexedFuts,
33 | }
34 | }
35 |
// waitIterator implements WaitIterator on top of the invocation's state machine.
type waitIterator struct {
	restateCtx *ctx
	// indexedFuts holds the operations not yet yielded, keyed by state machine handle.
	indexedFuts map[uint32]Selectable
	// lastCompleted is the operation yielded by the latest successful Next; Next
	// resets it to nil when it returns false.
	lastCompleted Selectable
	// cancelled is set once pollProgress reports cancellation; Next then always returns false.
	cancelled bool
}
42 |
// Next blocks until one of the remaining operations completes, removes it
// from the set and makes it available through Value. It returns false when
// no operations remain or the invocation was cancelled (Err distinguishes the
// two), and keeps returning false afterwards.
func (s *waitIterator) Next() bool {
	// Cancelled or exhausted: nothing more will be yielded.
	if s.cancelled || len(s.indexedFuts) == 0 {
		s.lastCompleted = nil
		return false
	}

	remainingHandles := make([]uint32, 0, len(s.indexedFuts))
	for k := range s.indexedFuts {
		remainingHandles = append(remainingHandles, k)
	}

	// Do progress
	cancelled := s.restateCtx.pollProgress(remainingHandles)
	if cancelled {
		// Remember cancellation so Err reports it and later Next calls stop early.
		s.lastCompleted = nil
		s.cancelled = true
		return false
	}

	// If we exit, one of them is completed, gotta figure out which one
	for _, handle := range remainingHandles {
		completed, err := s.restateCtx.stateMachine.IsCompleted(s.restateCtx, handle)
		if err != nil {
			panic(err)
		}
		if completed {
			fut := s.indexedFuts[handle]
			delete(s.indexedFuts, handle)
			s.lastCompleted = fut
			return true
		}
	}

	panic("Unexpectedly none of the remaining handles completed, this looks like a bug")
}
78 |
79 | func (s *waitIterator) Err() error {
80 | if s.cancelled {
81 | return &errors.CodeError{Inner: &errors.TerminalError{Inner: fmt.Errorf("cancelled")}, Code: errors.Code(409)}
82 | }
83 | return nil
84 | }
85 |
86 | func (s *waitIterator) Value() Selectable {
87 | if !s.cancelled && s.lastCompleted == nil {
88 | panic("Unexpected call to Value() before first call to Next()")
89 | }
90 | if s.cancelled {
91 | return nil
92 | }
93 | return s.lastCompleted
94 | }
95 |
--------------------------------------------------------------------------------
/encoding/internal/protojsonschema/schema.go:
--------------------------------------------------------------------------------
1 | package protojsonschema
2 |
3 | import (
4 | "log/slog"
5 | "reflect"
6 | "runtime/debug"
7 |
8 | "github.com/invopop/jsonschema"
9 | "github.com/restatedev/sdk-go/encoding/internal/util"
10 | "google.golang.org/protobuf/reflect/protoreflect"
11 | )
12 |
// protoMessageType and protoEnumType are the reflect.Types of the protobuf
// message and enum interfaces; descriptor uses them to detect generated types.
var protoMessageType = reflect.TypeOf((*protoreflect.ProtoMessage)(nil)).Elem()
var protoEnumType = reflect.TypeOf((*protoreflect.Enum)(nil)).Elem()
15 |
16 | func descriptor(typ reflect.Type) protoreflect.Descriptor {
17 | if typ.Implements(protoEnumType) {
18 | zero := reflect.Zero(typ).Interface().(protoreflect.Enum)
19 | return zero.Descriptor()
20 | }
21 |
22 | pointerTyp := reflect.PointerTo(typ)
23 | if pointerTyp.Implements(protoMessageType) {
24 | zero := reflect.Zero(pointerTyp).Interface().(protoreflect.ProtoMessage)
25 | return zero.ProtoReflect().Descriptor()
26 | }
27 |
28 | return nil
29 | }
30 |
31 | func GenerateSchema(v any) (schema *jsonschema.Schema) {
32 | defer func() {
33 | if err := recover(); err != nil {
34 | slog.Warn("Error when trying to generate schema for object. Using `any` inestead", "object", reflect.TypeOf(v), "cause", err)
35 | debug.PrintStack()
36 |
37 | schema = jsonschema.ReflectFromType(reflect.TypeFor[map[string]string]())
38 | }
39 |
40 | }()
41 |
42 | reflector := jsonschema.Reflector{
43 | // Unfortunately we can't enable this due to a panic bug https://github.com/invopop/jsonschema/issues/163
44 | // So we use ExpandSchema instead, which has the same effect but without the panic
45 | // ExpandedStruct: true,
46 | KeyNamer: func(fieldName string) string {
47 | return jsonCamelCase(fieldName)
48 | },
49 | Mapper: func(typ reflect.Type) *jsonschema.Schema {
50 | desc := descriptor(typ)
51 | if desc == nil {
52 | return nil
53 | }
54 |
55 | schemaFn, ok := wellKnownToSchemaFns[string(desc.FullName())]
56 | if !ok {
57 | return nil
58 | }
59 |
60 | return schemaFn(desc)
61 | },
62 | Namer: func(typ reflect.Type) string {
63 | desc := descriptor(typ)
64 | if desc == nil {
65 | return ""
66 | }
67 |
68 | return string(desc.FullName())
69 | },
70 | }
71 | return util.ExpandSchema(reflector.Reflect(v))
72 | }
73 |
// jsonCamelCase converts a snake_case identifier to lowerCamelCase as the
// protobuf JSON mapping does. Underscores are dropped, and a lowercase ASCII
// letter that follows an underscore is upper-cased. All other bytes are copied
// unchanged. Proto identifiers are ASCII, so the input is processed bytewise.
func jsonCamelCase(s string) string {
	out := make([]byte, 0, len(s))
	upperNext := false
	for _, c := range []byte(s) {
		if c == '_' {
			upperNext = true
			continue
		}
		if upperNext && c >= 'a' && c <= 'z' {
			c -= 'a' - 'A'
		}
		out = append(out, c)
		upperNext = false
	}
	return string(out)
}
91 |
--------------------------------------------------------------------------------
/internal/rand/rand_test.go:
--------------------------------------------------------------------------------
1 | package rand
2 |
3 | import (
4 | "encoding/hex"
5 | "testing"
6 | )
7 |
8 | func TestUint64(t *testing.T) {
9 | id, err := hex.DecodeString("f311f1fdcb9863f0018bd3400ecd7d69b547204e776218b2")
10 | if err != nil {
11 | t.Fatal(err)
12 | }
13 | rand := NewFromInvocationId(id)
14 |
15 | expected := []uint64{
16 | 6541268553928124324,
17 | 1632128201851599825,
18 | 3999496359968271420,
19 | 9099219592091638755,
20 | 2609122094717920550,
21 | 16569362788292807660,
22 | 14955958648458255954,
23 | 15581072429430901841,
24 | 4951852598761288088,
25 | 2380816196140950843,
26 | }
27 |
28 | for _, e := range expected {
29 | if found := rand.Uint64(); e != found {
30 | t.Fatalf("Unexpected uint64 %d, expected %d", found, e)
31 | }
32 | }
33 | }
34 |
35 | func TestFloat64(t *testing.T) {
36 | source := &source{state: [4]uint64{1, 2, 3, 4}}
37 | rand := &rand{source}
38 |
39 | expected := []float64{
40 | 4.656612984099695e-9, 6.519269457605503e-9, 0.39843750651926946,
41 | 0.3986824029416509, 0.5822761557370711, 0.2997488042907357,
42 | 0.5336032865255543, 0.36335061693258097, 0.5968067925950846,
43 | 0.18570456306457928,
44 | }
45 |
46 | for _, e := range expected {
47 | if found := rand.Float64(); e != found {
48 | t.Fatalf("Unexpected float64 %v, expected %v", found, e)
49 | }
50 | }
51 | }
52 |
53 | func TestUUID(t *testing.T) {
54 | source := &source{state: [4]uint64{1, 2, 3, 4}}
55 | rand := &rand{source}
56 |
57 | expected := []string{
58 | "01008002-0000-4000-a700-800300000000",
59 | "67008003-00c0-4c00-b200-449901c20c00",
60 | "cd33c49a-01a2-4280-ba33-eecd8a97698a",
61 | "bd4a1533-4713-41c2-979e-167991a02bac",
62 | "d83f078f-0a19-43db-a092-22b24af10591",
63 | "677c91f7-146e-4769-a4fd-df3793e717e8",
64 | "f15179b2-f220-4427-8d90-7b5437d9828d",
65 | "9e97720f-42b8-4d09-a449-914cf221df26",
66 | "09d0a109-6f11-4ef9-93fa-f013d0ad3808",
67 | "41eb0e0c-41c9-4828-85d0-59fb901b4df4",
68 | }
69 |
70 | for _, e := range expected {
71 | if found := rand.UUID().String(); e != found {
72 | t.Fatalf("Unexpected uuid %s, expected %s", found, e)
73 | }
74 | }
75 | }
76 |
77 | func TestUUIDFromSeed(t *testing.T) {
78 | rand := NewFromSeed(1)
79 |
80 | expected := []string{
81 | "9bc2036f-7fd0-45cf-8de0-3f96324142bf",
82 | "20f5aa57-577d-4319-9656-cd059f1108bf",
83 | "a46f1886-4b18-472f-8523-20e7ca9f2997",
84 | "0715f408-95c7-43fc-a1f2-6303c9a5fe85",
85 | "d04b330d-b3e5-4a18-96ec-26f0c9136122",
86 | "49e6cfdc-f90e-4eeb-b3ff-6c9fddaeef57",
87 | "d6407669-d5e2-4a12-8950-50ee4e3a0365",
88 | "ba884ad5-3e45-4916-bba8-28f4a85a0628",
89 | "a21d045f-1647-408e-b32e-f7a4d9321079",
90 | "9eed3928-5482-48f5-8a0f-8040b1de6aa4",
91 | }
92 |
93 | for _, e := range expected {
94 | if found := rand.UUID().String(); e != found {
95 | t.Fatalf("Unexpected uuid %s, expected %s", found, e)
96 | }
97 | }
98 | }
99 |
--------------------------------------------------------------------------------
/internal/restatecontext/state.go:
--------------------------------------------------------------------------------
1 | package restatecontext
2 |
3 | import (
4 | _ "embed"
5 | "fmt"
6 | "github.com/restatedev/sdk-go/encoding"
7 | "github.com/restatedev/sdk-go/internal/options"
8 | "github.com/restatedev/sdk-go/internal/statemachine"
9 | )
10 |
11 | func (restateCtx *ctx) Set(key string, value any, opts ...options.SetOption) {
12 | o := options.SetOptions{}
13 | for _, opt := range opts {
14 | opt.BeforeSet(&o)
15 | }
16 | if o.Codec == nil {
17 | o.Codec = encoding.JSONCodec
18 | }
19 |
20 | bytes, err := encoding.Marshal(o.Codec, value)
21 | if err != nil {
22 | panic(fmt.Errorf("failed to marshal Set value: %w", err))
23 | }
24 |
25 | err = restateCtx.stateMachine.SysStateSet(restateCtx, key, bytes)
26 | if err != nil {
27 | panic(err)
28 | }
29 | restateCtx.checkStateTransition()
30 | }
31 |
32 | func (restateCtx *ctx) Clear(key string) {
33 | err := restateCtx.stateMachine.SysStateClear(restateCtx, key)
34 | if err != nil {
35 | panic(err)
36 | }
37 | restateCtx.checkStateTransition()
38 | }
39 |
40 | // ClearAll drops all associated keys
41 | func (restateCtx *ctx) ClearAll() {
42 | err := restateCtx.stateMachine.SysStateClearAll(restateCtx)
43 | if err != nil {
44 | panic(err)
45 | }
46 | restateCtx.checkStateTransition()
47 | }
48 |
49 | func (restateCtx *ctx) Get(key string, output any, opts ...options.GetOption) (bool, error) {
50 | o := options.GetOptions{}
51 | for _, opt := range opts {
52 | opt.BeforeGet(&o)
53 | }
54 | if o.Codec == nil {
55 | o.Codec = encoding.JSONCodec
56 | }
57 |
58 | handle, err := restateCtx.stateMachine.SysStateGet(restateCtx, key)
59 | if err != nil {
60 | panic(err)
61 | }
62 | restateCtx.checkStateTransition()
63 |
64 | ar := newAsyncResult(restateCtx, handle)
65 | switch result := ar.pollProgressAndLoadValue().(type) {
66 | case statemachine.ValueVoid:
67 | return false, nil
68 | case statemachine.ValueSuccess:
69 | {
70 | if err := encoding.Unmarshal(o.Codec, result.Success, output); err != nil {
71 | panic(fmt.Errorf("failed to unmarshal Get state into output: %w", err))
72 | }
73 | return true, err
74 | }
75 | case statemachine.ValueFailure:
76 | return true, errorFromFailure(result)
77 | default:
78 | panic(fmt.Errorf("unexpected value %s", result))
79 |
80 | }
81 | }
82 |
83 | func (restateCtx *ctx) Keys() ([]string, error) {
84 | handle, err := restateCtx.stateMachine.SysStateGetKeys(restateCtx)
85 | if err != nil {
86 | panic(err)
87 | }
88 | restateCtx.checkStateTransition()
89 |
90 | ar := newAsyncResult(restateCtx, handle)
91 | switch result := ar.pollProgressAndLoadValue().(type) {
92 | case statemachine.ValueStateKeys:
93 | return result.Keys, nil
94 | case statemachine.ValueFailure:
95 | return nil, errorFromFailure(result)
96 | default:
97 | panic(fmt.Errorf("unexpected value %s", result))
98 | }
99 | }
100 |
--------------------------------------------------------------------------------
/internal/restatecontext/awakeable.go:
--------------------------------------------------------------------------------
1 | package restatecontext
2 |
3 | import (
4 | "fmt"
5 | "github.com/restatedev/sdk-go/encoding"
6 | "github.com/restatedev/sdk-go/internal/errors"
7 | pbinternal "github.com/restatedev/sdk-go/internal/generated"
8 | "github.com/restatedev/sdk-go/internal/options"
9 | "github.com/restatedev/sdk-go/internal/statemachine"
10 | )
11 |
12 | func (restateCtx *ctx) Awakeable(opts ...options.AwakeableOption) AwakeableFuture {
13 | o := options.AwakeableOptions{}
14 | for _, opt := range opts {
15 | opt.BeforeAwakeable(&o)
16 | }
17 | if o.Codec == nil {
18 | o.Codec = encoding.JSONCodec
19 | }
20 |
21 | id, handle, err := restateCtx.stateMachine.SysAwakeable(restateCtx)
22 | if err != nil {
23 | panic(err)
24 | }
25 | restateCtx.checkStateTransition()
26 |
27 | return &awakeableFuture{
28 | asyncResult: newAsyncResult(restateCtx, handle),
29 | id: id,
30 | codec: o.Codec,
31 | }
32 | }
33 |
// AwakeableFuture is the future returned by Awakeable. Being Selectable, it can
// be raced against other futures.
type AwakeableFuture interface {
	Selectable
	// Id returns the identifier assigned to the awakeable by the state machine.
	Id() string
	// Result decodes the awakeable's successful value into output, or returns
	// the failure it was completed with as an error.
	Result(output any) error
}

// awakeableFuture implements AwakeableFuture on top of an asyncResult.
type awakeableFuture struct {
	asyncResult
	// id is the awakeable identifier returned by SysAwakeable.
	id string
	// codec decodes successful results in Result.
	codec encoding.Codec
}

// Id returns the awakeable identifier.
func (d *awakeableFuture) Id() string { return d.id }
47 |
48 | func (d *awakeableFuture) Result(output any) error {
49 | switch result := d.pollProgressAndLoadValue().(type) {
50 | case statemachine.ValueSuccess:
51 | {
52 | if err := encoding.Unmarshal(d.codec, result.Success, output); err != nil {
53 | panic(fmt.Errorf("failed to unmarshal awakeable result into output: %w", err))
54 | }
55 | return nil
56 | }
57 | case statemachine.ValueFailure:
58 | return errorFromFailure(result)
59 | default:
60 | panic(fmt.Errorf("unexpected value %s", result))
61 |
62 | }
63 | }
64 |
65 | func (restateCtx *ctx) ResolveAwakeable(id string, value any, opts ...options.ResolveAwakeableOption) {
66 | o := options.ResolveAwakeableOptions{}
67 | for _, opt := range opts {
68 | opt.BeforeResolveAwakeable(&o)
69 | }
70 | if o.Codec == nil {
71 | o.Codec = encoding.JSONCodec
72 | }
73 | bytes, err := encoding.Marshal(o.Codec, value)
74 | if err != nil {
75 | panic(fmt.Errorf("failed to marshal ResolveAwakeable value: %w", err))
76 | }
77 |
78 | input := pbinternal.VmSysCompleteAwakeableParameters{}
79 | input.SetId(id)
80 | input.SetSuccess(bytes)
81 | if err := restateCtx.stateMachine.SysCompleteAwakeable(restateCtx, &input); err != nil {
82 | panic(err)
83 | }
84 | restateCtx.checkStateTransition()
85 | }
86 |
87 | func (restateCtx *ctx) RejectAwakeable(id string, reason error) {
88 | failure := pbinternal.Failure{}
89 | failure.SetCode(uint32(errors.ErrorCode(reason)))
90 | failure.SetMessage(reason.Error())
91 |
92 | input := pbinternal.VmSysCompleteAwakeableParameters{}
93 | input.SetId(id)
94 | input.SetFailure(&failure)
95 | if err := restateCtx.stateMachine.SysCompleteAwakeable(restateCtx, &input); err != nil {
96 | panic(err)
97 | }
98 | restateCtx.checkStateTransition()
99 | }
100 |
--------------------------------------------------------------------------------
/examples/ticketreservation/user_session.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "slices"
5 | "time"
6 |
7 | restate "github.com/restatedev/sdk-go"
8 | )
9 |
// UserSessionServiceName is the name reported by userSession.ServiceName and
// used when sending messages to this object.
const UserSessionServiceName = "UserSession"

// userSession is a virtual object keyed by user ID that keeps the user's
// reserved tickets in its "tickets" state.
type userSession struct{}

// ServiceName returns UserSessionServiceName.
func (*userSession) ServiceName() string { return UserSessionServiceName }
17 |
18 | func (u *userSession) AddTicket(ctx restate.ObjectContext, ticketId string) (bool, error) {
19 | userId := restate.Key(ctx)
20 |
21 | success, err := restate.Object[bool](ctx, TicketServiceName, ticketId, "Reserve").Request(userId)
22 | if err != nil {
23 | return false, err
24 | }
25 |
26 | if !success {
27 | return false, nil
28 | }
29 |
30 | // add ticket to list of tickets
31 | tickets, err := restate.Get[[]string](ctx, "tickets")
32 | if err != nil {
33 | return false, err
34 | }
35 |
36 | tickets = append(tickets, ticketId)
37 |
38 | restate.Set(ctx, "tickets", tickets)
39 | restate.ObjectSend(ctx, UserSessionServiceName, userId, "ExpireTicket").Send(ticketId, restate.WithDelay(15*time.Minute))
40 |
41 | return true, nil
42 | }
43 |
44 | func (u *userSession) ExpireTicket(ctx restate.ObjectContext, ticketId string) (void restate.Void, err error) {
45 | tickets, err := restate.Get[[]string](ctx, "tickets")
46 | if err != nil {
47 | return void, err
48 | }
49 |
50 | deleted := false
51 | tickets = slices.DeleteFunc(tickets, func(ticket string) bool {
52 | if ticket == ticketId {
53 | deleted = true
54 | return true
55 | }
56 | return false
57 | })
58 | if !deleted {
59 | return void, nil
60 | }
61 |
62 | restate.Set(ctx, "tickets", tickets)
63 | restate.ObjectSend(ctx, TicketServiceName, ticketId, "Unreserve").Send(restate.Void{})
64 |
65 | return void, nil
66 | }
67 |
68 | func (u *userSession) Checkout(ctx restate.ObjectContext, _ restate.Void) (bool, error) {
69 | userId := restate.Key(ctx)
70 | tickets, err := restate.Get[[]string](ctx, "tickets")
71 | if err != nil {
72 | return false, err
73 | }
74 |
75 | ctx.Log().Info("tickets in basket", "tickets", tickets)
76 |
77 | if len(tickets) == 0 {
78 | return false, nil
79 | }
80 |
81 | timeout := restate.After(ctx, time.Minute)
82 |
83 | request := restate.Object[PaymentResponse](ctx, CheckoutServiceName, "", "Payment").
84 | RequestFuture(PaymentRequest{UserID: userId, Tickets: tickets})
85 |
86 | // race between the request and the timeout
87 | resultFut, err := restate.WaitFirst(ctx, timeout, request)
88 | if err != nil {
89 | return false, err
90 | }
91 | switch resultFut {
92 | case request:
93 | // happy path
94 | case timeout:
95 | // we could choose to fail here with terminal error, but we'd also have to refund the payment!
96 | ctx.Log().Warn("slow payment")
97 | }
98 |
99 | // block on the eventual response
100 | response, err := request.Response()
101 | if err != nil {
102 | return false, err
103 | }
104 |
105 | ctx.Log().Info("payment details", "id", response.ID, "price", response.Price)
106 |
107 | for _, ticket := range tickets {
108 | restate.ObjectSend(ctx, TicketServiceName, ticket, "MarkAsSold").Send(restate.Void{})
109 | }
110 |
111 | restate.Clear(ctx, "tickets")
112 | return true, nil
113 | }
114 |
--------------------------------------------------------------------------------
/encoding/encoding_test.go:
--------------------------------------------------------------------------------
1 | package encoding
2 |
3 | import (
4 | "encoding/json"
5 | "reflect"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/require"
9 | )
10 |
11 | func TestVoid(t *testing.T) {
12 | codecs := map[string]Codec{
13 | "json": JSONCodec,
14 | "proto": ProtoCodec,
15 | "protojson": ProtoJSONCodec,
16 | "binary": BinaryCodec,
17 | }
18 | for name, codec := range codecs {
19 | t.Run(name, func(t *testing.T) {
20 | bytes, err := Marshal(codec, Void{})
21 | if err != nil {
22 | t.Fatal(err)
23 | }
24 |
25 | if bytes != nil {
26 | t.Fatalf("expected bytes to be nil, found %v", bytes)
27 | }
28 |
29 | if err := Unmarshal(codec, []byte{1, 2, 3}, &Void{}); err != nil {
30 | t.Fatal(err)
31 | }
32 |
33 | if err := Unmarshal(codec, []byte{1, 2, 3}, Void{}); err != nil {
34 | t.Fatal(err)
35 | }
36 | })
37 | }
38 | }
39 |
// jsonSchemaCases pairs a sample value with the exact JSON text that
// json.Marshal(generateJsonSchema(object)) is expected to produce.
// TestGenerateJsonSchema compares the strings verbatim, so key order and
// spacing matter.
var jsonSchemaCases = []struct {
	object any
	schema string
}{
	{
		object: "abc",
		schema: `{"$schema":"https://json-schema.org/draft/2020-12/schema","type":"string"}`,
	},
	{
		object: 123,
		schema: `{"$schema":"https://json-schema.org/draft/2020-12/schema","type":"integer"}`,
	},
	{
		object: 1.1,
		schema: `{"$schema":"https://json-schema.org/draft/2020-12/schema","type":"number"}`,
	},
	{
		// Anonymous struct: a single required string property, no extra properties.
		object: struct {
			Foo string `json:"foo"`
		}{},
		schema: `{"$schema":"https://json-schema.org/draft/2020-12/schema","properties":{"foo":{"type":"string"}},"additionalProperties":false,"type":"object","required":["foo"]}`,
	},
	{
		// Self-referencing type: the expected schema refers back to the root ("#").
		object: recursive{},
		schema: `{"$schema":"https://json-schema.org/draft/2020-12/schema","$id":"https://github.com/restatedev/sdk-go/encoding/recursive","$defs":{"recursive":{"$ref":"#"}},"properties":{"inner":{"$ref":"#/$defs/recursive"}},"additionalProperties":false,"type":"object","required":["inner"]}`,
	},
	{
		// Three-type cycle A -> B -> C -> A.
		object: nestedRecursiveA{},
		schema: `{"$schema":"https://json-schema.org/draft/2020-12/schema","$id":"https://github.com/restatedev/sdk-go/encoding/nested-recursive-a","$defs":{"nestedRecursiveA":{"$ref":"#"},"nestedRecursiveB":{"properties":{"inner":{"$ref":"#/$defs/nestedRecursiveC"}},"additionalProperties":false,"type":"object","required":["inner"]},"nestedRecursiveC":{"properties":{"inner":{"$ref":"#/$defs/nestedRecursiveA"}},"additionalProperties":false,"type":"object","required":["inner"]}},"properties":{"inner":{"$ref":"#/$defs/nestedRecursiveB"}},"additionalProperties":false,"type":"object","required":["inner"]}`,
	},
}

// The types below exist only as jsonSchemaCases inputs. Their names appear in
// the expected $id/$defs strings above, so renaming one requires updating
// those strings too.

// recursive refers to itself directly.
type recursive struct {
	Inner *recursive `json:"inner"`
}

// nestedRecursiveA starts the A -> B -> C -> A reference cycle.
type nestedRecursiveA struct {
	Inner *nestedRecursiveB `json:"inner"`
}

// nestedRecursiveB is the middle link of the cycle.
type nestedRecursiveB struct {
	Inner *nestedRecursiveC `json:"inner"`
}

// nestedRecursiveC closes the cycle back to nestedRecursiveA.
type nestedRecursiveC struct {
	Inner *nestedRecursiveA `json:"inner"`
}
87 |
88 | func TestGenerateJsonSchema(t *testing.T) {
89 | for _, test := range jsonSchemaCases {
90 | t.Run(reflect.TypeOf(test.object).String(), func(t *testing.T) {
91 | schema := generateJsonSchema(test.object)
92 | data, err := json.Marshal(schema)
93 | require.NoError(t, err)
94 | require.Equal(t, test.schema, string(data))
95 | })
96 | }
97 | }
98 |
--------------------------------------------------------------------------------
/test-services/failing.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "sync/atomic"
6 | "time"
7 |
8 | restate "github.com/restatedev/sdk-go"
9 | )
10 |
// init registers the "Failing" virtual object. Its handlers exercise terminal
// errors, retried handler errors and retried side effects. The counters below
// are process-wide and persist across retries of the same invocation.
func init() {
	// Attempts of failingCallWithEventualSuccess; reset to 0 once it succeeds.
	var eventualSuccessCalls atomic.Int32
	// Side-effect attempts in sideEffectSucceedsAfterGivenAttempts; reset to 0 on success.
	var eventualSuccessSideEffectCalls atomic.Int32
	// Side-effect attempts in sideEffectFailsAfterGivenAttempts.
	// NOTE(review): never reset, so repeated invocations in one process accumulate.
	var eventualFailureSideEffectCalls atomic.Int32

	REGISTRY.AddDefinition(
		restate.NewObject("Failing").
			// Always fails terminally with the given message.
			Handler("terminallyFailingCall", restate.NewObjectHandler(
				func(ctx restate.ObjectContext, errorMessage string) (restate.Void, error) {
					return restate.Void{}, restate.TerminalErrorf("%s", errorMessage)
				})).
			// Calls terminallyFailingCall on a key generated with restate.UUID and propagates its error.
			Handler("callTerminallyFailingCall", restate.NewObjectHandler(
				func(ctx restate.ObjectContext, errorMessage string) (string, error) {
					if _, err := restate.Object[restate.Void](ctx, "Failing", restate.UUID(ctx).String(), "terminallyFailingCall").Request(errorMessage); err != nil {
						return "", err
					}

					return "", restate.TerminalErrorf("This should be unreachable")
				})).
			// Returns a non-terminal error on attempts 1-3 and succeeds with the attempt number from attempt 4 on.
			Handler("failingCallWithEventualSuccess", restate.NewObjectHandler(
				func(ctx restate.ObjectContext, _ restate.Void) (int32, error) {
					currentAttempt := eventualSuccessCalls.Add(1)
					if currentAttempt >= 4 {
						eventualSuccessCalls.Store(0)
						return currentAttempt, nil
					} else {
						return 0, fmt.Errorf("Failed at attempt: %d", currentAttempt)
					}
				})).
			// Runs a side effect that fails terminally and returns the resulting error.
			Handler("terminallyFailingSideEffect", restate.NewObjectHandler(
				func(ctx restate.ObjectContext, errorMessage string) (restate.Void, error) {
					err := restate.RunVoid(ctx, func(ctx restate.RunContext) error {
						return restate.TerminalErrorf("%s", errorMessage)
					})
					return restate.Void{}, err
				})).
			// Runs a side effect that fails until minimumAttempts attempts have been made,
			// retrying every 10ms (factor 1.0), then returns the attempt number.
			Handler("sideEffectSucceedsAfterGivenAttempts", restate.NewObjectHandler(
				func(ctx restate.ObjectContext, minimumAttempts int32) (int32, error) {
					return restate.Run(ctx, func(ctx restate.RunContext) (int32, error) {
						currentAttempt := eventualSuccessSideEffectCalls.Add(1)
						if currentAttempt >= minimumAttempts {
							eventualSuccessSideEffectCalls.Store(0)
							return currentAttempt, nil
						} else {
							return 0, fmt.Errorf("Failed at attempt: %d", currentAttempt)
						}
					},
						restate.WithName("failing_side_effect"),
						restate.WithInitialRetryInterval(time.Millisecond*10),
						restate.WithRetryIntervalFactor(1.0))
				})).
			// Runs an always-failing side effect, capped by WithMaxRetryAttempts. Once Run
			// returns an error, reports how many attempts were counted; otherwise fails terminally.
			Handler("sideEffectFailsAfterGivenAttempts", restate.NewObjectHandler(
				func(ctx restate.ObjectContext, retryPolicyMaxRetryCount uint) (int32, error) {
					_, err := restate.Run(ctx, func(ctx restate.RunContext) (int32, error) {
						currentAttempt := eventualFailureSideEffectCalls.Add(1)
						return 0, fmt.Errorf("Failed at attempt: %d", currentAttempt)
					}, restate.WithName("failing_side_effect"), restate.WithInitialRetryInterval(time.Millisecond*10), restate.WithRetryIntervalFactor(1.0), restate.WithMaxRetryAttempts(retryPolicyMaxRetryCount))
					if err != nil {
						return eventualFailureSideEffectCalls.Load(), nil
					}
					return 0, restate.TerminalErrorf("Expecting the side effect to fail!")
				})))
}
74 |
--------------------------------------------------------------------------------
/test-services/nondeterministic.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "sync"
5 | "time"
6 |
7 | restate "github.com/restatedev/sdk-go"
8 | )
9 |
// State keys written by the setDifferentKey handler; which one is used
// depends on the branch taken.
const (
	STATE_A = "a"
	STATE_B = "b"
)
12 |
// init registers the "NonDeterministic" virtual object. Each handler takes a
// different journal action depending on a process-local counter, so
// re-executing the same invocation takes the other branch. The exact order of
// the operations below is part of what is being tested.
func init() {
	// Per-object-key count of handler executions in this process.
	invocationCounts := map[string]int32{}
	invocationCountsMtx := sync.RWMutex{}

	// doLeftAction increments the count for the current object key and
	// reports true on odd counts (1st, 3rd, ...) and false on even counts.
	doLeftAction := func(ctx restate.ObjectContext) bool {
		countKey := restate.Key(ctx)
		invocationCountsMtx.Lock()
		defer invocationCountsMtx.Unlock()

		invocationCounts[countKey] += 1
		return invocationCounts[countKey]%2 == 1
	}
	// incrementCounter sends a one-way "add" of 1 to the Counter object with the same key.
	incrementCounter := func(ctx restate.ObjectContext) {
		restate.ObjectSend(ctx, "Counter", restate.Key(ctx), "add").Send(int64(1))
	}

	REGISTRY.AddDefinition(
		restate.NewObject("NonDeterministic").
			// Branches between a sleep and a call to Counter.get.
			Handler("eitherSleepOrCall", restate.NewObjectHandler(
				func(ctx restate.ObjectContext, _ restate.Void) (restate.Void, error) {
					if doLeftAction(ctx) {
						restate.Sleep(ctx, 100*time.Millisecond)
					} else {
						if _, err := restate.Object[restate.Void](ctx, "Counter", "abc", "get").Request(restate.Void{}); err != nil {
							return restate.Void{}, err
						}
					}

					// This is required to cause a suspension after the non-deterministic operation
					restate.Sleep(ctx, 100*time.Millisecond)
					incrementCounter(ctx)
					return restate.Void{}, nil
				})).
			// Branches between calling Counter.get and Counter.reset.
			Handler("callDifferentMethod", restate.NewObjectHandler(
				func(ctx restate.ObjectContext, _ restate.Void) (restate.Void, error) {
					if doLeftAction(ctx) {
						if _, err := restate.Object[restate.Void](ctx, "Counter", "abc", "get").Request(restate.Void{}); err != nil {
							return restate.Void{}, err
						}
					} else {
						if _, err := restate.Object[restate.Void](ctx, "Counter", "abc", "reset").Request(restate.Void{}); err != nil {
							return restate.Void{}, err
						}
					}

					// This is required to cause a suspension after the non-deterministic operation
					restate.Sleep(ctx, 100*time.Millisecond)
					incrementCounter(ctx)
					return restate.Void{}, nil
				})).
			// Branches between one-way sends to Counter.get and Counter.reset.
			Handler("backgroundInvokeWithDifferentTargets", restate.NewObjectHandler(
				func(ctx restate.ObjectContext, _ restate.Void) (restate.Void, error) {
					if doLeftAction(ctx) {
						restate.ObjectSend(ctx, "Counter", "abc", "get").Send(restate.Void{})
					} else {
						restate.ObjectSend(ctx, "Counter", "abc", "reset").Send(restate.Void{})
					}

					// This is required to cause a suspension after the non-deterministic operation
					restate.Sleep(ctx, 100*time.Millisecond)
					incrementCounter(ctx)
					return restate.Void{}, nil
				})).
			// Branches between writing state key STATE_A and STATE_B.
			Handler("setDifferentKey", restate.NewObjectHandler(
				func(ctx restate.ObjectContext, _ restate.Void) (restate.Void, error) {
					if doLeftAction(ctx) {
						restate.Set(ctx, STATE_A, "my-state")
					} else {
						restate.Set(ctx, STATE_B, "my-state")
					}

					// This is required to cause a suspension after the non-deterministic operation
					restate.Sleep(ctx, 100*time.Millisecond)
					incrementCounter(ctx)
					return restate.Void{}, nil
				})))
}
90 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/restatedev/sdk-go
2 |
3 | go 1.24.0
4 |
5 | require (
6 | github.com/golang-jwt/jwt/v5 v5.2.3
7 | github.com/google/uuid v1.6.0
8 | github.com/invopop/jsonschema v0.13.0
9 | github.com/mr-tron/base58 v1.2.0
10 | github.com/nsf/jsondiff v0.0.0-20230430225905-43f6cf3098c1
11 | github.com/stretchr/testify v1.11.1
12 | github.com/testcontainers/testcontainers-go v0.40.0
13 | github.com/tetratelabs/wazero v1.9.0
14 | github.com/wk8/go-ordered-map/v2 v2.1.8
15 | go.opentelemetry.io/otel v1.38.0
16 | google.golang.org/protobuf v1.36.10
17 | )
18 |
19 | require (
20 | dario.cat/mergo v1.0.2 // indirect
21 | github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
22 | github.com/Microsoft/go-winio v0.6.2 // indirect
23 | github.com/bahlo/generic-list-go v0.2.0 // indirect
24 | github.com/buger/jsonparser v1.1.1 // indirect
25 | github.com/cenkalti/backoff/v4 v4.3.0 // indirect
26 | github.com/containerd/errdefs v1.0.0 // indirect
27 | github.com/containerd/errdefs/pkg v0.3.0 // indirect
28 | github.com/containerd/log v0.1.0 // indirect
29 | github.com/containerd/platforms v0.2.1 // indirect
30 | github.com/cpuguy83/dockercfg v0.3.2 // indirect
31 | github.com/davecgh/go-spew v1.1.1 // indirect
32 | github.com/distribution/reference v0.6.0 // indirect
33 | github.com/docker/docker v28.5.1+incompatible // indirect
34 | github.com/docker/go-connections v0.6.0 // indirect
35 | github.com/docker/go-units v0.5.0 // indirect
36 | github.com/ebitengine/purego v0.8.4 // indirect
37 | github.com/felixge/httpsnoop v1.0.4 // indirect
38 | github.com/go-logr/logr v1.4.3 // indirect
39 | github.com/go-logr/stdr v1.2.2 // indirect
40 | github.com/go-ole/go-ole v1.2.6 // indirect
41 | github.com/klauspost/compress v1.18.0 // indirect
42 | github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
43 | github.com/magiconair/properties v1.8.10 // indirect
44 | github.com/mailru/easyjson v0.7.7 // indirect
45 | github.com/moby/docker-image-spec v1.3.1 // indirect
46 | github.com/moby/go-archive v0.1.0 // indirect
47 | github.com/moby/patternmatcher v0.6.0 // indirect
48 | github.com/moby/sys/sequential v0.6.0 // indirect
49 | github.com/moby/sys/user v0.4.0 // indirect
50 | github.com/moby/sys/userns v0.1.0 // indirect
51 | github.com/moby/term v0.5.0 // indirect
52 | github.com/morikuni/aec v1.0.0 // indirect
53 | github.com/opencontainers/go-digest v1.0.0 // indirect
54 | github.com/opencontainers/image-spec v1.1.1 // indirect
55 | github.com/pkg/errors v0.9.1 // indirect
56 | github.com/pmezard/go-difflib v1.0.0 // indirect
57 | github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
58 | github.com/shirou/gopsutil/v4 v4.25.6 // indirect
59 | github.com/sirupsen/logrus v1.9.3 // indirect
60 | github.com/stretchr/objx v0.5.2 // indirect
61 | github.com/tklauser/go-sysconf v0.3.12 // indirect
62 | github.com/tklauser/numcpus v0.6.1 // indirect
63 | github.com/yusufpapurcu/wmi v1.2.4 // indirect
64 | go.opentelemetry.io/auto/sdk v1.1.0 // indirect
65 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
66 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect
67 | go.opentelemetry.io/otel/metric v1.38.0 // indirect
68 | go.opentelemetry.io/otel/sdk v1.38.0 // indirect
69 | go.opentelemetry.io/otel/trace v1.38.0 // indirect
70 | go.opentelemetry.io/proto/otlp v1.9.0 // indirect
71 | golang.org/x/crypto v0.43.0 // indirect
72 | golang.org/x/sys v0.37.0 // indirect
73 | gopkg.in/yaml.v3 v3.0.1 // indirect
74 | )
75 |
--------------------------------------------------------------------------------
/encoding/internal/protojsonschema/wellknown.go:
--------------------------------------------------------------------------------
1 | package protojsonschema
2 |
3 | import (
4 | "github.com/invopop/jsonschema"
5 | orderedmap "github.com/wk8/go-ordered-map/v2"
6 | "google.golang.org/protobuf/reflect/protoreflect"
7 | )
8 |
9 | var wellKnownToSchemaFns = map[string]func(protoreflect.Descriptor) *jsonschema.Schema{
10 | "google.protobuf.Duration": func(d protoreflect.Descriptor) *jsonschema.Schema {
11 | return &jsonschema.Schema{
12 | Type: "string",
13 | Format: "regex",
14 | Pattern: `^[-\+]?([0-9]+\.?[0-9]*|\.[0-9]+)s$`,
15 | }
16 | },
17 | "google.protobuf.Timestamp": func(d protoreflect.Descriptor) *jsonschema.Schema {
18 | return &jsonschema.Schema{
19 | Type: "string",
20 | Format: "date-time",
21 | }
22 | },
23 | "google.protobuf.Empty": func(d protoreflect.Descriptor) *jsonschema.Schema {
24 | return &jsonschema.Schema{
25 | Type: "object",
26 | AdditionalProperties: jsonschema.FalseSchema,
27 | }
28 | },
29 | "google.protobuf.Any": func(d protoreflect.Descriptor) *jsonschema.Schema {
30 | return &jsonschema.Schema{
31 | Type: "object",
32 | Properties: orderedmap.New[string, *jsonschema.Schema](orderedmap.WithInitialData[string, *jsonschema.Schema](
33 | orderedmap.Pair[string, *jsonschema.Schema]{
34 | Key: "@type",
35 | Value: &jsonschema.Schema{Type: "string"},
36 | },
37 | orderedmap.Pair[string, *jsonschema.Schema]{
38 | Key: "value",
39 | Value: &jsonschema.Schema{Type: "string", Format: "binary"},
40 | },
41 | )),
42 | AdditionalProperties: jsonschema.TrueSchema,
43 | }
44 | },
45 | "google.protobuf.FieldMask": func(d protoreflect.Descriptor) *jsonschema.Schema {
46 | return &jsonschema.Schema{
47 | Type: "string",
48 | }
49 | },
50 |
51 | "google.protobuf.Struct": func(d protoreflect.Descriptor) *jsonschema.Schema {
52 | return &jsonschema.Schema{
53 | Type: "object",
54 | AdditionalProperties: jsonschema.TrueSchema,
55 | }
56 | },
57 | "google.protobuf.Value": func(d protoreflect.Descriptor) *jsonschema.Schema {
58 | return jsonschema.TrueSchema
59 | },
60 | "google.protobuf.NullValue": func(d protoreflect.Descriptor) *jsonschema.Schema {
61 | return &jsonschema.Schema{
62 | Type: "null",
63 | }
64 | },
65 | "google.protobuf.StringValue": func(d protoreflect.Descriptor) *jsonschema.Schema {
66 | return &jsonschema.Schema{
67 | Type: "string",
68 | }
69 | },
70 | "google.protobuf.BytesValue": func(d protoreflect.Descriptor) *jsonschema.Schema {
71 | return &jsonschema.Schema{
72 | Type: "string",
73 | Format: "binary",
74 | }
75 | },
76 | "google.protobuf.BoolValue": func(d protoreflect.Descriptor) *jsonschema.Schema {
77 | return &jsonschema.Schema{
78 | Type: "boolean",
79 | }
80 | },
81 | "google.protobuf.DoubleValue": google64BitNumberValue,
82 | "google.protobuf.Int64Value": google64BitNumberValue,
83 | "google.protobuf.UInt64Value": google64BitNumberValue,
84 | "google.protobuf.FloatValue": google64BitNumberValue,
85 | "google.protobuf.Int32Value": google32BitNumberValue,
86 | "google.protobuf.UInt32Value": google32BitNumberValue,
87 | }
88 |
89 | var google64BitNumberValue = func(d protoreflect.Descriptor) *jsonschema.Schema {
90 | return &jsonschema.Schema{
91 | OneOf: []*jsonschema.Schema{
92 | &jsonschema.Schema{Type: "number"},
93 | &jsonschema.Schema{Type: "string"},
94 | },
95 | }
96 | }
97 | var google32BitNumberValue = func(d protoreflect.Descriptor) *jsonschema.Schema {
98 | return &jsonschema.Schema{
99 | Type: "number",
100 | }
101 | }
102 |
--------------------------------------------------------------------------------
/mocks/mock_AfterFuture.go:
--------------------------------------------------------------------------------
1 | // Code generated by mockery v2.52.1. DO NOT EDIT.
2 |
3 | package mocks
4 |
5 | import mock "github.com/stretchr/testify/mock"
6 |
7 | /* moved to helpers.go
8 | // MockAfterFuture is an autogenerated mock type for the AfterFuture type
9 | type MockAfterFuture struct {
10 | mock.Mock
11 | }
12 | */
13 |
// MockAfterFuture_Expecter provides typed helpers for registering expectations
// on a MockAfterFuture; obtain one through EXPECT.
type MockAfterFuture_Expecter struct {
	mock *mock.Mock
}

// EXPECT returns an expecter bound to this mock's embedded mock.Mock.
func (_m *MockAfterFuture) EXPECT() *MockAfterFuture_Expecter {
	return &MockAfterFuture_Expecter{mock: &_m.Mock}
}
21 |
// Done provides a mock function with no fields.
// It returns the value configured via Return or RunAndReturn and panics when
// no return value was configured.
func (_m *MockAfterFuture) Done() error {
	ret := _m.Called()

	if len(ret) == 0 {
		panic("no return value specified for Done")
	}

	var r0 error
	// RunAndReturn stores a func instead of a value; call it to produce the result.
	if rf, ok := ret.Get(0).(func() error); ok {
		r0 = rf()
	} else {
		r0 = ret.Error(0)
	}

	return r0
}

// MockAfterFuture_Done_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Done'
type MockAfterFuture_Done_Call struct {
	*mock.Call
}

// Done is a helper method to define mock.On call
func (_e *MockAfterFuture_Expecter) Done() *MockAfterFuture_Done_Call {
	return &MockAfterFuture_Done_Call{Call: _e.mock.On("Done")}
}

// Run registers run to be invoked each time Done is called.
func (_c *MockAfterFuture_Done_Call) Run(run func()) *MockAfterFuture_Done_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run()
	})
	return _c
}

// Return sets the error that Done returns.
func (_c *MockAfterFuture_Done_Call) Return(_a0 error) *MockAfterFuture_Done_Call {
	_c.Call.Return(_a0)
	return _c
}

// RunAndReturn makes Done return the result of calling run on each invocation.
func (_c *MockAfterFuture_Done_Call) RunAndReturn(run func() error) *MockAfterFuture_Done_Call {
	_c.Call.Return(run)
	return _c
}
66 |
// handle provides a mock function with no fields.
// It returns the value configured via Return or RunAndReturn and panics when
// no return value was configured.
func (_m *MockAfterFuture) handle() uint32 {
	ret := _m.Called()

	if len(ret) == 0 {
		panic("no return value specified for handle")
	}

	var r0 uint32
	// RunAndReturn stores a func instead of a value; call it to produce the result.
	if rf, ok := ret.Get(0).(func() uint32); ok {
		r0 = rf()
	} else {
		r0 = ret.Get(0).(uint32)
	}

	return r0
}

// MockAfterFuture_handle_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'handle'
type MockAfterFuture_handle_Call struct {
	*mock.Call
}

// handle is a helper method to define mock.On call
func (_e *MockAfterFuture_Expecter) handle() *MockAfterFuture_handle_Call {
	return &MockAfterFuture_handle_Call{Call: _e.mock.On("handle")}
}

// Run registers run to be invoked each time handle is called.
func (_c *MockAfterFuture_handle_Call) Run(run func()) *MockAfterFuture_handle_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run()
	})
	return _c
}

// Return sets the value that handle returns.
func (_c *MockAfterFuture_handle_Call) Return(_a0 uint32) *MockAfterFuture_handle_Call {
	_c.Call.Return(_a0)
	return _c
}

// RunAndReturn makes handle return the result of calling run on each invocation.
func (_c *MockAfterFuture_handle_Call) RunAndReturn(run func() uint32) *MockAfterFuture_handle_Call {
	_c.Call.Return(run)
	return _c
}
111 |
// NewMockAfterFuture creates a new instance of MockAfterFuture. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockAfterFuture(t interface {
	mock.TestingT
	Cleanup(func())
}) *MockAfterFuture {
	mock := &MockAfterFuture{}
	mock.Mock.Test(t)

	// Verify all registered expectations once the test finishes.
	t.Cleanup(func() { mock.AssertExpectations(t) })

	return mock
}
125 |
--------------------------------------------------------------------------------
/internal/rand/rand.go:
--------------------------------------------------------------------------------
1 | package rand
2 |
3 | import (
4 | "crypto/sha256"
5 | "encoding/binary"
6 |
7 | "github.com/google/uuid"
8 | )
9 |
// Rand is a deterministic pseudo-random generator. Instances are built either
// from an invocation ID (NewFromInvocationId) or from an explicit seed
// (NewFromSeed), so the same input always yields the same sequence.
type Rand interface {
	// UUID returns a random version 4 UUID.
	//
	// Deprecated: Use restate.UUID directly, instead of restate.Rand().UUID()
	UUID() uuid.UUID
	// Float64 returns a pseudo-random number in [0.0, 1.0).
	Float64() float64
	// Uint64 returns 64 pseudo-random bits.
	Uint64() uint64
	// Source returns a deterministic random source that can be provided to math/rand.New()
	// and math/rand/v2.New(). The v2 version of rand is strongly recommended where Go 1.22
	// is used, and once this library begins to depend on 1.22, it will be embedded in Rand.
	// NOTE(review): go.mod now declares go 1.24, so the embedding mentioned above may be due.
	//
	// Deprecated: Use restate.RandSource directly, instead of restate.Rand().Source()
	Source() Source
}
22 |
23 | type rand struct {
24 | source *source
25 | }
26 |
27 | func NewFromInvocationId(invocationID []byte) *rand {
28 | return &rand{newSourceFromInvocationId(invocationID)}
29 | }
30 |
31 | func NewFromSeed(seed uint64) *rand {
32 | return &rand{newSource(seed)}
33 | }
34 |
35 | func (r *rand) UUID() uuid.UUID {
36 | var uuid [16]byte
37 | binary.LittleEndian.PutUint64(uuid[:8], r.Uint64())
38 | binary.LittleEndian.PutUint64(uuid[8:], r.Uint64())
39 | uuid[6] = (uuid[6] & 0x0f) | 0x40 // Version 4
40 | uuid[8] = (uuid[8] & 0x3f) | 0x80 // Variant is 10
41 | return uuid
42 | }
43 |
44 | func (r *rand) Float64() float64 {
45 | // use the math/rand/v2 implementation of Float64() which is more correct
46 | // and also matches our TS implementation
47 | return float64(r.Uint64()<<11>>11) / (1 << 53)
48 | }
49 |
50 | func (r *rand) Uint64() uint64 {
51 | return r.source.Uint64()
52 | }
53 |
54 | func (r *rand) Source() Source {
55 | return r.source
56 | }
57 |
58 | type Source interface {
59 | Int63() int64
60 | Uint64() uint64
61 |
62 | // only the v1 rand package requires this method; we do *not* support it and will panic if its called.
63 | Seed(int64)
64 | }
65 |
66 | type source struct {
67 | state [4]uint64
68 | }
69 |
70 | // From https://xoroshiro.di.unimi.it/splitmix64.c
71 | func splitMix64(x *uint64) uint64 {
72 | *x += 0x9e3779b97f4a7c15
73 | z := *x
74 | z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9
75 | z = (z ^ (z >> 27)) * 0x94d049bb133111eb
76 | return z ^ (z >> 31)
77 | }
78 |
79 | func newSource(seed uint64) *source {
80 | return &source{state: [4]uint64{
81 | splitMix64(&seed),
82 | splitMix64(&seed),
83 | splitMix64(&seed),
84 | splitMix64(&seed),
85 | }}
86 | }
87 |
88 | func newSourceFromInvocationId(invocationId []byte) *source {
89 | hash := sha256.New()
90 | hash.Write(invocationId)
91 | var sum [32]byte
92 | hash.Sum(sum[:0])
93 |
94 | return &source{state: [4]uint64{
95 | binary.LittleEndian.Uint64(sum[:8]),
96 | binary.LittleEndian.Uint64(sum[8:16]),
97 | binary.LittleEndian.Uint64(sum[16:24]),
98 | binary.LittleEndian.Uint64(sum[24:32]),
99 | }}
100 | }
101 |
102 | func (s *source) Int63() int64 {
103 | return int64(s.Uint64() & ((1 << 63) - 1))
104 | }
105 |
106 | // only the v1 rand package has this method
107 | func (s *source) Seed(int64) {
108 | panic("The Restate random source is already deterministic based on invocation ID and must not be seeded")
109 | }
110 |
111 | func (s *source) Uint64() uint64 {
112 | result := rotl((s.state[0]+s.state[3]), 23) + s.state[0]
113 |
114 | t := (s.state[1] << 17)
115 |
116 | s.state[2] ^= s.state[0]
117 | s.state[3] ^= s.state[1]
118 | s.state[1] ^= s.state[2]
119 | s.state[0] ^= s.state[3]
120 |
121 | s.state[2] ^= t
122 |
123 | s.state[3] = rotl(s.state[3], 45)
124 |
125 | return result
126 | }
127 |
128 | func rotl(x uint64, k uint64) uint64 {
129 | return (x << k) | (x >> (64 - k))
130 | }
131 |
--------------------------------------------------------------------------------
/mocks/mock_Selector.go:
--------------------------------------------------------------------------------
1 | // Code generated by mockery v2.52.1. DO NOT EDIT.
2 |
3 | package mocks
4 |
5 | import (
6 | restatecontext "github.com/restatedev/sdk-go/internal/restatecontext"
7 | mock "github.com/stretchr/testify/mock"
8 | )
9 |
// MockSelector is an autogenerated mock type for the Selector type
type MockSelector struct {
	mock.Mock
}

// MockSelector_Expecter exposes typed helpers that register expectations on
// the wrapped mock.Mock. Obtain one via MockSelector.EXPECT().
type MockSelector_Expecter struct {
	mock *mock.Mock
}

// EXPECT returns an expecter bound to this mock's embedded mock.Mock.
func (_m *MockSelector) EXPECT() *MockSelector_Expecter {
	return &MockSelector_Expecter{mock: &_m.Mock}
}

// Remaining provides a mock function with no fields
//
// It panics if no return value was registered. If the registered value is a
// func() bool (as stored by RunAndReturn), that function is called to compute
// the result.
func (_m *MockSelector) Remaining() bool {
	ret := _m.Called()

	if len(ret) == 0 {
		panic("no return value specified for Remaining")
	}

	var r0 bool
	if rf, ok := ret.Get(0).(func() bool); ok {
		r0 = rf()
	} else {
		r0 = ret.Get(0).(bool)
	}

	return r0
}

// MockSelector_Remaining_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Remaining'
type MockSelector_Remaining_Call struct {
	*mock.Call
}

// Remaining is a helper method to define mock.On call
func (_e *MockSelector_Expecter) Remaining() *MockSelector_Remaining_Call {
	return &MockSelector_Remaining_Call{Call: _e.mock.On("Remaining")}
}

// Run registers a callback invoked on each matching Remaining call.
func (_c *MockSelector_Remaining_Call) Run(run func()) *MockSelector_Remaining_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run()
	})
	return _c
}

// Return sets the value Remaining returns.
func (_c *MockSelector_Remaining_Call) Return(_a0 bool) *MockSelector_Remaining_Call {
	_c.Call.Return(_a0)
	return _c
}

// RunAndReturn stores run as the return value. Remaining detects the func and
// calls it to produce its result.
func (_c *MockSelector_Remaining_Call) RunAndReturn(run func() bool) *MockSelector_Remaining_Call {
	_c.Call.Return(run)
	return _c
}

// Select provides a mock function with no fields
//
// It panics if no return value was registered. A registered nil yields a nil
// Selectable, and a registered func() restatecontext.Selectable is called to
// compute the result.
func (_m *MockSelector) Select() restatecontext.Selectable {
	ret := _m.Called()

	if len(ret) == 0 {
		panic("no return value specified for Select")
	}

	var r0 restatecontext.Selectable
	if rf, ok := ret.Get(0).(func() restatecontext.Selectable); ok {
		r0 = rf()
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(restatecontext.Selectable)
		}
	}

	return r0
}

// MockSelector_Select_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Select'
type MockSelector_Select_Call struct {
	*mock.Call
}

// Select is a helper method to define mock.On call
func (_e *MockSelector_Expecter) Select() *MockSelector_Select_Call {
	return &MockSelector_Select_Call{Call: _e.mock.On("Select")}
}

// Run registers a callback invoked on each matching Select call.
func (_c *MockSelector_Select_Call) Run(run func()) *MockSelector_Select_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run()
	})
	return _c
}

// Return sets the value Select returns.
func (_c *MockSelector_Select_Call) Return(_a0 restatecontext.Selectable) *MockSelector_Select_Call {
	_c.Call.Return(_a0)
	return _c
}

// RunAndReturn stores run as the return value. Select detects the func and
// calls it to produce its result.
func (_c *MockSelector_Select_Call) RunAndReturn(run func() restatecontext.Selectable) *MockSelector_Select_Call {
	_c.Call.Return(run)
	return _c
}

// NewMockSelector creates a new instance of MockSelector. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockSelector(t interface {
	mock.TestingT
	Cleanup(func())
}) *MockSelector {
	mock := &MockSelector{}
	mock.Mock.Test(t)

	t.Cleanup(func() { mock.AssertExpectations(t) })

	return mock
}
128 |
--------------------------------------------------------------------------------
/mocks/mock_AttachFuture.go:
--------------------------------------------------------------------------------
1 | // Code generated by mockery v2.52.1. DO NOT EDIT.
2 |
3 | package mocks
4 |
5 | import mock "github.com/stretchr/testify/mock"
6 |
7 | /* moved to helpers.go
8 | // MockAttachFuture is an autogenerated mock type for the AttachFuture type
9 | type MockAttachFuture struct {
10 | mock.Mock
11 | }
12 | */
13 |
// MockAttachFuture_Expecter exposes typed helpers that register expectations
// on the wrapped mock.Mock. Obtain one via MockAttachFuture.EXPECT().
// The MockAttachFuture struct itself is not declared in this file; its
// generated declaration is commented out with a "moved to helpers.go" note.
type MockAttachFuture_Expecter struct {
	mock *mock.Mock
}

// EXPECT returns an expecter bound to this mock's embedded mock.Mock.
func (_m *MockAttachFuture) EXPECT() *MockAttachFuture_Expecter {
	return &MockAttachFuture_Expecter{mock: &_m.Mock}
}

// Response provides a mock function with given fields: output
//
// It panics if no return value was registered. A registered func(any) error
// (as stored by RunAndReturn) is called with output to compute the result.
func (_m *MockAttachFuture) Response(output any) error {
	ret := _m.Called(output)

	if len(ret) == 0 {
		panic("no return value specified for Response")
	}

	var r0 error
	if rf, ok := ret.Get(0).(func(any) error); ok {
		r0 = rf(output)
	} else {
		r0 = ret.Error(0)
	}

	return r0
}

// MockAttachFuture_Response_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Response'
type MockAttachFuture_Response_Call struct {
	*mock.Call
}

// Response is a helper method to define mock.On call
//   - output any
func (_e *MockAttachFuture_Expecter) Response(output interface{}) *MockAttachFuture_Response_Call {
	return &MockAttachFuture_Response_Call{Call: _e.mock.On("Response", output)}
}

// Run registers a callback invoked with the call's output argument.
func (_c *MockAttachFuture_Response_Call) Run(run func(output any)) *MockAttachFuture_Response_Call {
	_c.Call.Run(func(args mock.Arguments) {
		// NOTE(review): args[0].(any) is a type assertion and panics when the
		// recorded argument is a nil interface value (e.g. Response(nil)).
		run(args[0].(any))
	})
	return _c
}

// Return sets the error Response returns.
func (_c *MockAttachFuture_Response_Call) Return(_a0 error) *MockAttachFuture_Response_Call {
	_c.Call.Return(_a0)
	return _c
}

// RunAndReturn stores run as the return value. Response detects the func and
// calls it with output.
func (_c *MockAttachFuture_Response_Call) RunAndReturn(run func(any) error) *MockAttachFuture_Response_Call {
	_c.Call.Return(run)
	return _c
}

// handle provides a mock function with no fields
//
// NOTE(review): handle is unexported; presumably this lets the mock satisfy an
// unexported method of the mocked interface. Confirm against restatecontext.
func (_m *MockAttachFuture) handle() uint32 {
	ret := _m.Called()

	if len(ret) == 0 {
		panic("no return value specified for handle")
	}

	var r0 uint32
	if rf, ok := ret.Get(0).(func() uint32); ok {
		r0 = rf()
	} else {
		r0 = ret.Get(0).(uint32)
	}

	return r0
}

// MockAttachFuture_handle_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'handle'
type MockAttachFuture_handle_Call struct {
	*mock.Call
}

// handle is a helper method to define mock.On call
func (_e *MockAttachFuture_Expecter) handle() *MockAttachFuture_handle_Call {
	return &MockAttachFuture_handle_Call{Call: _e.mock.On("handle")}
}

// Run registers a callback invoked on each matching handle call.
func (_c *MockAttachFuture_handle_Call) Run(run func()) *MockAttachFuture_handle_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run()
	})
	return _c
}

// Return sets the value handle returns.
func (_c *MockAttachFuture_handle_Call) Return(_a0 uint32) *MockAttachFuture_handle_Call {
	_c.Call.Return(_a0)
	return _c
}

// RunAndReturn stores run as the return value. handle detects the func and
// calls it to produce its result.
func (_c *MockAttachFuture_handle_Call) RunAndReturn(run func() uint32) *MockAttachFuture_handle_Call {
	_c.Call.Return(run)
	return _c
}

// NewMockAttachFuture creates a new instance of MockAttachFuture. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockAttachFuture(t interface {
	mock.TestingT
	Cleanup(func())
}) *MockAttachFuture {
	mock := &MockAttachFuture{}
	mock.Mock.Test(t)

	t.Cleanup(func() { mock.AssertExpectations(t) })

	return mock
}
126 |
--------------------------------------------------------------------------------
/mocks/mock_RunAsyncFuture.go:
--------------------------------------------------------------------------------
1 | // Code generated by mockery v2.52.2. DO NOT EDIT.
2 |
3 | package mocks
4 |
5 | import mock "github.com/stretchr/testify/mock"
6 |
7 | /* moved to helpers.go
8 | // MockRunAsyncFuture is an autogenerated mock type for the RunAsyncFuture type
9 | type MockRunAsyncFuture struct {
10 | mock.Mock
11 | }
12 | */
13 |
// MockRunAsyncFuture_Expecter exposes typed helpers that register expectations
// on the wrapped mock.Mock. Obtain one via MockRunAsyncFuture.EXPECT().
// The MockRunAsyncFuture struct itself is not declared in this file; its
// generated declaration is commented out with a "moved to helpers.go" note.
type MockRunAsyncFuture_Expecter struct {
	mock *mock.Mock
}

// EXPECT returns an expecter bound to this mock's embedded mock.Mock.
func (_m *MockRunAsyncFuture) EXPECT() *MockRunAsyncFuture_Expecter {
	return &MockRunAsyncFuture_Expecter{mock: &_m.Mock}
}

// Result provides a mock function with given fields: output
//
// It panics if no return value was registered. A registered func(any) error
// (as stored by RunAndReturn) is called with output to compute the result.
func (_m *MockRunAsyncFuture) Result(output any) error {
	ret := _m.Called(output)

	if len(ret) == 0 {
		panic("no return value specified for Result")
	}

	var r0 error
	if rf, ok := ret.Get(0).(func(any) error); ok {
		r0 = rf(output)
	} else {
		r0 = ret.Error(0)
	}

	return r0
}

// MockRunAsyncFuture_Result_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Result'
type MockRunAsyncFuture_Result_Call struct {
	*mock.Call
}

// Result is a helper method to define mock.On call
//   - output any
func (_e *MockRunAsyncFuture_Expecter) Result(output interface{}) *MockRunAsyncFuture_Result_Call {
	return &MockRunAsyncFuture_Result_Call{Call: _e.mock.On("Result", output)}
}

// Run registers a callback invoked with the call's output argument.
func (_c *MockRunAsyncFuture_Result_Call) Run(run func(output any)) *MockRunAsyncFuture_Result_Call {
	_c.Call.Run(func(args mock.Arguments) {
		// NOTE(review): args[0].(any) is a type assertion and panics when the
		// recorded argument is a nil interface value (e.g. Result(nil)).
		run(args[0].(any))
	})
	return _c
}

// Return sets the error Result returns.
func (_c *MockRunAsyncFuture_Result_Call) Return(_a0 error) *MockRunAsyncFuture_Result_Call {
	_c.Call.Return(_a0)
	return _c
}

// RunAndReturn stores run as the return value. Result detects the func and
// calls it with output.
func (_c *MockRunAsyncFuture_Result_Call) RunAndReturn(run func(any) error) *MockRunAsyncFuture_Result_Call {
	_c.Call.Return(run)
	return _c
}

// handle provides a mock function with no fields
//
// NOTE(review): handle is unexported; presumably this lets the mock satisfy an
// unexported method of the mocked interface. Confirm against restatecontext.
func (_m *MockRunAsyncFuture) handle() uint32 {
	ret := _m.Called()

	if len(ret) == 0 {
		panic("no return value specified for handle")
	}

	var r0 uint32
	if rf, ok := ret.Get(0).(func() uint32); ok {
		r0 = rf()
	} else {
		r0 = ret.Get(0).(uint32)
	}

	return r0
}

// MockRunAsyncFuture_handle_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'handle'
type MockRunAsyncFuture_handle_Call struct {
	*mock.Call
}

// handle is a helper method to define mock.On call
func (_e *MockRunAsyncFuture_Expecter) handle() *MockRunAsyncFuture_handle_Call {
	return &MockRunAsyncFuture_handle_Call{Call: _e.mock.On("handle")}
}

// Run registers a callback invoked on each matching handle call.
func (_c *MockRunAsyncFuture_handle_Call) Run(run func()) *MockRunAsyncFuture_handle_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run()
	})
	return _c
}

// Return sets the value handle returns.
func (_c *MockRunAsyncFuture_handle_Call) Return(_a0 uint32) *MockRunAsyncFuture_handle_Call {
	_c.Call.Return(_a0)
	return _c
}

// RunAndReturn stores run as the return value. handle detects the func and
// calls it to produce its result.
func (_c *MockRunAsyncFuture_handle_Call) RunAndReturn(run func() uint32) *MockRunAsyncFuture_handle_Call {
	_c.Call.Return(run)
	return _c
}

// NewMockRunAsyncFuture creates a new instance of MockRunAsyncFuture. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockRunAsyncFuture(t interface {
	mock.TestingT
	Cleanup(func())
}) *MockRunAsyncFuture {
	mock := &MockRunAsyncFuture{}
	mock.Mock.Test(t)

	t.Cleanup(func() { mock.AssertExpectations(t) })

	return mock
}
126 |
--------------------------------------------------------------------------------
/server/lambda.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/base64"
7 | "fmt"
8 | "io"
9 | "net/http"
10 | "strings"
11 |
12 | "github.com/restatedev/sdk-go/internal/identity"
13 | )
14 |
15 | type LambdaRequest struct {
16 | Path string `json:"path"`
17 | RawPath string `json:"rawPath"`
18 | Body string `json:"body"`
19 | IsBase64Encoded bool `json:"isBase64Encoded"`
20 | Headers map[string]string `json:"headers"`
21 | }
22 |
23 | type LambdaResponse struct {
24 | StatusCode int `json:"statusCode"`
25 | Headers map[string]string `json:"headers"`
26 | Body string `json:"body"`
27 | IsBase64Encoded bool `json:"isBase64Encoded"`
28 | }
29 |
30 | type LambdaHandlerFunc func(ctx context.Context, event LambdaRequest) (LambdaResponse, error)
31 |
32 | type lambdaResponseWriter struct {
33 | headers http.Header
34 | body bytes.Buffer
35 | status int
36 | }
37 |
38 | func (r *lambdaResponseWriter) Header() http.Header {
39 | return r.headers
40 | }
41 |
42 | func (r *lambdaResponseWriter) Write(body []byte) (int, error) {
43 | if r.status == -1 {
44 | r.status = http.StatusOK
45 | }
46 |
47 | // if the content type header is not set when we write the body we try to
48 | // detect one and set it by default. If the content type cannot be detected
49 | // it is automatically set to "application/octet-stream" by the
50 | // DetectContentType method
51 | if r.Header().Get("Content-Type") == "" {
52 | r.Header().Add("Content-Type", http.DetectContentType(body))
53 | }
54 |
55 | return (&r.body).Write(body)
56 | }
57 |
58 | func (r *lambdaResponseWriter) WriteHeader(statusCode int) {
59 | r.status = statusCode
60 | }
61 |
62 | func (r *lambdaResponseWriter) Flush() {}
63 |
64 | func (r *lambdaResponseWriter) LambdaResponse() LambdaResponse {
65 | headers := make(map[string]string, len(r.headers))
66 | for k, v := range r.headers {
67 | if len(v) == 0 {
68 | continue
69 | }
70 | headers[k] = v[0]
71 | }
72 |
73 | return LambdaResponse{
74 | Headers: headers,
75 | StatusCode: r.status,
76 | IsBase64Encoded: true,
77 | Body: base64.StdEncoding.EncodeToString(r.body.Bytes()),
78 | }
79 | }
80 |
// LambdaHandler returns a Lambda handler function serving the bound services.
// Lambda only supports request-response communication, so .Bidirectional(false)
// is applied on your behalf. If signing key IDs were configured they are parsed
// here, and a parse failure is returned as an error. Without keys, a warning
// is logged because request signatures will not be checked.
func (r *Restate) LambdaHandler() (LambdaHandlerFunc, error) {
	r.Bidirectional(false)

	if r.keyIDs == nil {
		r.systemLog.Warn("Accepting requests without validating request signatures; Invoke must be restricted")
		return r.lambdaHandler, nil
	}

	keySet, err := identity.ParseKeySetV1(r.keyIDs)
	if err != nil {
		return nil, fmt.Errorf("invalid request identity keys: %w", err)
	}
	r.keySet = keySet
	r.systemLog.Info("Validating requests using signing keys", "keys", r.keyIDs)

	return r.lambdaHandler, nil
}
99 |
// lambdaHandler turns one Lambda event into an *http.Request, serves it through
// r.handler into an in-memory writer, and converts the recorded reply into a
// LambdaResponse. Only request construction can fail; that failure is returned
// together with a 502 status.
func (r *Restate) lambdaHandler(ctx context.Context, event LambdaRequest) (LambdaResponse, error) {
	// Path wins when present; otherwise fall back to RawPath (possibly empty).
	path := event.Path
	if path == "" {
		path = event.RawPath
	}

	// body must stay a nil interface when the event has no body.
	var body io.Reader
	switch {
	case event.Body == "":
		// no request body
	case event.IsBase64Encoded:
		body = base64.NewDecoder(base64.StdEncoding, strings.NewReader(event.Body))
	default:
		body = strings.NewReader(event.Body)
	}

	// The method is never inspected downstream, so POST is used unconditionally.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, path, body)
	if err != nil {
		return LambdaResponse{StatusCode: http.StatusBadGateway}, err
	}
	req.RequestURI = path
	for name, value := range event.Headers {
		req.Header.Add(name, value)
	}

	rw := &lambdaResponseWriter{headers: make(http.Header, 2), status: -1}
	r.handler(rw, req)

	return rw.LambdaResponse(), nil
}
133 |
--------------------------------------------------------------------------------
/internal/restatecontext/promise.go:
--------------------------------------------------------------------------------
1 | package restatecontext
2 |
3 | import (
4 | "fmt"
5 | "github.com/restatedev/sdk-go/encoding"
6 | "github.com/restatedev/sdk-go/internal/errors"
7 | pbinternal "github.com/restatedev/sdk-go/internal/generated"
8 | "github.com/restatedev/sdk-go/internal/options"
9 | "github.com/restatedev/sdk-go/internal/statemachine"
10 | )
11 |
12 | func (restateCtx *ctx) Promise(key string, opts ...options.PromiseOption) DurablePromise {
13 | o := options.PromiseOptions{}
14 | for _, opt := range opts {
15 | opt.BeforePromise(&o)
16 | }
17 | if o.Codec == nil {
18 | o.Codec = encoding.JSONCodec
19 | }
20 |
21 | handle, err := restateCtx.stateMachine.SysPromiseGet(restateCtx, key)
22 | if err != nil {
23 | panic(err)
24 | }
25 | restateCtx.checkStateTransition()
26 |
27 | return &durablePromise{
28 | asyncResult: newAsyncResult(restateCtx, handle),
29 | key: key,
30 | codec: o.Codec,
31 | }
32 | }
33 |
34 | type DurablePromise interface {
35 | Selectable
36 | Result(output any) (err error)
37 | Peek(output any) (ok bool, err error)
38 | Resolve(value any) error
39 | Reject(reason error) error
40 | }
41 |
42 | type durablePromise struct {
43 | asyncResult
44 | key string
45 | codec encoding.Codec
46 | }
47 |
// Result loads the promise's value via pollProgressAndLoadValue and decodes a
// success into output using d.codec. A failure is returned as an error.
// A decode error, or any other value kind, causes a panic.
func (d *durablePromise) Result(output any) (err error) {
	switch result := d.pollProgressAndLoadValue().(type) {
	case statemachine.ValueSuccess:
		if decodeErr := encoding.Unmarshal(d.codec, result.Success, output); decodeErr != nil {
			panic(fmt.Errorf("failed to unmarshal promise result into output: %w", decodeErr))
		}
		return nil
	case statemachine.ValueFailure:
		return errorFromFailure(result)
	default:
		// %v/%T instead of %s: result is not a string, and %s renders %!s(...) noise.
		panic(fmt.Errorf("unexpected value %v (%T)", result, result))
	}
}
64 |
// Peek asks the state machine for the promise's current value without
// waiting on the original handle. A void value yields (false, nil). A success
// is decoded into output and yields (true, nil). A failure yields
// (false, error). State-machine errors, decode errors, and unknown value
// kinds panic.
func (d *durablePromise) Peek(output any) (ok bool, err error) {
	handle, err := d.ctx.stateMachine.SysPromisePeek(d.ctx, d.key)
	if err != nil {
		panic(err)
	}
	d.ctx.checkStateTransition()

	ar := newAsyncResult(d.ctx, handle)
	switch result := ar.pollProgressAndLoadValue().(type) {
	case statemachine.ValueVoid:
		return false, nil
	case statemachine.ValueSuccess:
		if decodeErr := encoding.Unmarshal(d.codec, result.Success, output); decodeErr != nil {
			panic(fmt.Errorf("failed to unmarshal promise result into output: %w", decodeErr))
		}
		return true, nil
	case statemachine.ValueFailure:
		return false, errorFromFailure(result)
	default:
		// %v/%T instead of %s: result is not a string, and %s renders %!s(...) noise.
		panic(fmt.Errorf("unexpected value %v (%T)", result, result))
	}
}
89 |
// Resolve encodes value with d.codec and asks the state machine to complete
// the promise successfully. A void acknowledgement yields nil, and a failure
// is returned as an error. Encoding errors, state-machine errors, and unknown
// value kinds panic.
func (d *durablePromise) Resolve(value any) error {
	bytes, err := encoding.Marshal(d.codec, value)
	if err != nil {
		panic(fmt.Errorf("failed to marshal Promise Resolve value: %w", err))
	}

	input := pbinternal.VmSysPromiseCompleteParameters{}
	input.SetId(d.key)
	input.SetSuccess(bytes)
	handle, err := d.ctx.stateMachine.SysPromiseComplete(d.ctx, &input)
	if err != nil {
		panic(err)
	}
	d.ctx.checkStateTransition()

	ar := newAsyncResult(d.ctx, handle)
	switch result := ar.pollProgressAndLoadValue().(type) {
	case statemachine.ValueVoid:
		return nil
	case statemachine.ValueFailure:
		return errorFromFailure(result)
	default:
		// %v/%T instead of %s: result is not a string, and %s renders %!s(...) noise.
		panic(fmt.Errorf("unexpected value %v (%T)", result, result))
	}
}
115 |
// Reject asks the state machine to complete the promise with a failure built
// from reason: the code comes from errors.ErrorCode and the message from
// reason.Error(). A void acknowledgement yields nil, and a failure is returned
// as an error. State-machine errors and unknown value kinds panic.
// NOTE(review): a nil reason panics in reason.Error().
func (d *durablePromise) Reject(reason error) error {
	failure := pbinternal.Failure{}
	failure.SetCode(uint32(errors.ErrorCode(reason)))
	failure.SetMessage(reason.Error())

	input := pbinternal.VmSysPromiseCompleteParameters{}
	input.SetId(d.key)
	input.SetFailure(&failure)
	handle, err := d.ctx.stateMachine.SysPromiseComplete(d.ctx, &input)
	if err != nil {
		panic(err)
	}
	d.ctx.checkStateTransition()

	ar := newAsyncResult(d.ctx, handle)
	switch result := ar.pollProgressAndLoadValue().(type) {
	case statemachine.ValueVoid:
		return nil
	case statemachine.ValueFailure:
		return errorFromFailure(result)
	default:
		// %v/%T instead of %s: result is not a string, and %s renders %!s(...) noise.
		panic(fmt.Errorf("unexpected value %v (%T)", result, result))
	}
}
140 |
--------------------------------------------------------------------------------
/ingress/mock_server_test.go:
--------------------------------------------------------------------------------
1 | package ingress_test
2 |
3 | import (
4 | "encoding/json"
5 | "io"
6 | "net/http"
7 | "net/http/httptest"
8 | "strings"
9 | "testing"
10 |
11 | "github.com/nsf/jsondiff"
12 | "github.com/stretchr/testify/require"
13 |
14 | "github.com/restatedev/sdk-go/internal/ingress"
15 | )
16 |
17 | type mockIngressServer struct {
18 | URL string
19 | s *httptest.Server
20 | method string
21 | path string
22 | headers map[string]string
23 | body []byte
24 | query map[string]string
25 | }
26 |
27 | func newMockIngressServer() *mockIngressServer {
28 | m := &mockIngressServer{}
29 | m.s = httptest.NewServer(m)
30 | m.URL = m.s.URL
31 | return m
32 | }
33 |
// ServeHTTP records the request (method, path, body, and the first value of
// each header and query parameter). Requests to a path ending in "/send" get
// a JSON-encoded accepted Invocation; every other request gets "OK" as a JSON
// string. Read and write errors are ignored, since this is a test fake.
func (m *mockIngressServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	firstValues := func(multi map[string][]string) map[string]string {
		single := make(map[string]string, len(multi))
		for name, values := range multi {
			single[name] = values[0]
		}
		return single
	}

	m.method, m.path = r.Method, r.URL.Path
	m.body, _ = io.ReadAll(r.Body)
	m.headers = firstValues(r.Header)
	m.query = firstValues(r.URL.Query())

	if !strings.HasSuffix(m.path, "/send") {
		w.Write([]byte("\"OK\""))
		return
	}

	json.NewEncoder(w).Encode(&ingress.Invocation{
		Id:     "inv_1",
		Status: "Accepted",
	})
}
59 |
// AssertPath fails the test unless the last recorded request path equals expectedPath.
func (m *mockIngressServer) AssertPath(t *testing.T, expectedPath string) {
	t.Helper()
	require.Equalf(t, expectedPath, m.path, "expected path %s, got %s", expectedPath, m.path)
}

// AssertMethod fails the test unless the last recorded request method equals expectedMethod.
func (m *mockIngressServer) AssertMethod(t *testing.T, expectedMethod string) {
	t.Helper()
	require.Equalf(t, expectedMethod, m.method, "expected method %s, got %s", expectedMethod, m.method)
}

// AssertContentType fails the test unless a request was recorded and its
// Content-Type header equals contentType.
func (m *mockIngressServer) AssertContentType(t *testing.T, contentType string) {
	t.Helper()
	require.NotNil(t, m.headers)
	require.Equal(t, contentType, m.headers["Content-Type"])
}

// AssertNoContentType fails the test unless a request was recorded without a
// Content-Type header.
func (m *mockIngressServer) AssertNoContentType(t *testing.T) {
	t.Helper()
	require.NotNil(t, m.headers)
	require.Empty(t, m.headers["Content-Type"])
}

// AssertNoBody fails the test unless a request was recorded (m.headers is set
// by ServeHTTP) and its body was empty.
func (m *mockIngressServer) AssertNoBody(t *testing.T) {
	t.Helper()
	require.NotNil(t, m.headers)
	require.Empty(t, m.body)
}
82 |
// AssertHeaders fails the test unless every expected header is present in the
// last recorded request with the given value. Header names are compared
// case-insensitively; extra request headers are allowed. A nil expectation
// with recorded headers fails.
func (m *mockIngressServer) AssertHeaders(t *testing.T, expectedHeaders map[string]string) {
	t.Helper()
	if expectedHeaders == nil && m.headers == nil {
		return
	}
	if expectedHeaders != nil && m.headers == nil {
		require.Fail(t, "expected headers but got none")
	}
	if expectedHeaders == nil && m.headers != nil {
		require.Fail(t, "expected no headers but got some")
	}
	reqHeaders := make(map[string]string, len(m.headers))
	for k, v := range m.headers {
		reqHeaders[strings.ToLower(k)] = v
	}
	for k, v := range expectedHeaders {
		h, ok := reqHeaders[strings.ToLower(k)]
		require.Truef(t, ok, "header %s not found in request", k)
		require.Equalf(t, v, h, "header %s not equal to expected value", k)
	}
}
103 |
// AssertBody fails the test unless the last recorded body has the same length
// as expectedBody and is JSON-equivalent to it according to jsondiff. Two
// empty bodies match.
func (m *mockIngressServer) AssertBody(t *testing.T, expectedBody []byte) {
	t.Helper()
	if len(expectedBody) == 0 && len(m.body) == 0 {
		return
	}
	require.Equalf(t, len(expectedBody), len(m.body), "expected body length %d, got %d", len(expectedBody), len(m.body))

	// Compare's second result is the human-readable diff; the Difference value
	// alone only names the match category.
	diff, explanation := jsondiff.Compare(expectedBody, m.body, &jsondiff.Options{})
	require.Equalf(t, jsondiff.FullMatch, diff, "expected body %s, got %s; diff: %s", string(expectedBody), string(m.body), explanation)
}
113 |
// AssertQuery fails the test unless every expected query parameter is present
// in the last recorded request with the given value (first value only). Extra
// parameters are allowed. A nil expectation fails if any query was recorded.
func (m *mockIngressServer) AssertQuery(t *testing.T, expectedQuery map[string]string) {
	t.Helper()
	if expectedQuery == nil && m.query == nil {
		return
	}
	// len of a nil map is 0, so no separate nil checks are needed.
	if len(expectedQuery) > 0 && m.query == nil {
		require.Fail(t, "expected query but got none")
	}
	if expectedQuery == nil && len(m.query) > 0 {
		require.Fail(t, "expected no query but got some")
	}
	for k, v := range expectedQuery {
		h, ok := m.query[k]
		require.Truef(t, ok, "query %s not found in request", k)
		require.Equalf(t, v, h, "query %s not equal to expected value", k)
	}
}

// Close shuts down the underlying httptest server.
func (m *mockIngressServer) Close() {
	m.s.Close()
}
134 |
--------------------------------------------------------------------------------
/.github/workflows/integration.yaml:
--------------------------------------------------------------------------------
1 | name: Integration
2 |
3 | # Controls when the workflow will run
4 | on:
5 | pull_request:
6 | push:
7 | branches:
8 | - main
9 | schedule:
10 | - cron: "0 */6 * * *" # Every 6 hours
11 | workflow_dispatch:
12 | inputs:
13 | restateCommit:
14 | description: "restate commit"
15 | required: false
16 | default: ""
17 | type: string
18 | restateImage:
19 | description: "restate image, superseded by restate commit"
20 | required: false
21 | default: "ghcr.io/restatedev/restate:main"
22 | type: string
23 | workflow_call:
24 | inputs:
25 | restateCommit:
26 | description: "restate commit"
27 | required: false
28 | default: ""
29 | type: string
30 | restateImage:
31 | description: "restate image, superseded by restate commit"
32 | required: false
33 | default: "ghcr.io/restatedev/restate:main"
34 | type: string
35 |
36 | jobs:
37 | sdk-test-suite:
38 | if: github.repository_owner == 'restatedev'
39 | runs-on: ubuntu-latest
40 | name: "Features integration test"
41 | permissions:
42 | contents: read
43 | issues: read
44 | checks: write
45 | pull-requests: write
46 | actions: read
47 |
48 | steps:
49 | - uses: actions/checkout@v4
50 | with:
51 | repository: restatedev/sdk-go
52 |
      # Support importing the OCI-format restate.tar
54 | - name: Set up Docker containerd snapshotter
55 | uses: docker/setup-docker-action@v4
56 | with:
57 | version: "v28.5.2"
58 | set-host: true
59 | daemon-config: |
60 | {
61 | "features": {
62 | "containerd-snapshotter": true
63 | }
64 | }
65 |
66 | ### Download the Restate container image, if needed
67 | # Setup restate snapshot if necessary
      # Due to https://github.com/actions/upload-artifact/issues/53,
      # we must use download-artifact to get artifacts created during *this* workflow run, i.e. when invoked via workflow_call.
70 | - name: Download restate snapshot from in-progress workflow
71 | if: ${{ inputs.restateCommit != '' && github.event_name != 'workflow_dispatch' }}
72 | uses: actions/download-artifact@v4
73 | with:
74 | name: restate.tar
75 |
76 | # In the workflow dispatch case where the artifact was created in a previous run, we can download as normal
77 | - name: Download restate snapshot from completed workflow
78 | if: ${{ inputs.restateCommit != '' && github.event_name == 'workflow_dispatch' }}
79 | uses: dawidd6/action-download-artifact@v3
80 | with:
81 | repo: restatedev/restate
82 | workflow: ci.yml
83 | commit: ${{ inputs.restateCommit }}
84 | name: restate.tar
85 |
86 | - name: Install restate snapshot
87 | if: ${{ inputs.restateCommit != '' }}
88 | run: |
89 | output=$(docker load --input restate.tar | head -n 1)
90 | docker tag "${output#*: }" "localhost/restatedev/restate-commit-download:latest"
91 | docker image ls -a
92 |
93 | - name: Setup Go
94 | uses: actions/setup-go@v5
95 | with:
96 | go-version: "1.21.x"
97 |
98 | - name: Setup ko
99 | uses: ko-build/setup-ko@v0.6
100 | with:
101 | version: v0.16.0
102 |
103 | - name: Install dependencies
104 | run: go get .
105 |
106 | - name: Build Docker image
107 | run: KO_DOCKER_REPO=restatedev ko build -B -L github.com/restatedev/sdk-go/test-services
108 |
109 | - name: Run test tool
110 | uses: restatedev/sdk-test-suite@v3.4
111 | with:
112 | restateContainerImage: ${{ inputs.restateCommit != '' && 'localhost/restatedev/restate-commit-download:latest' || (inputs.restateImage != '' && inputs.restateImage || 'ghcr.io/restatedev/restate:main') }}
113 | serviceContainerImage: "restatedev/test-services"
114 | exclusionsFile: "test-services/exclusions.yaml"
115 | testArtifactOutput: "sdk-go-integration-test-report"
116 | serviceContainerEnvFile: "test-services/.env"
117 |
--------------------------------------------------------------------------------
/internal/restatecontext/execute_invocation.go:
--------------------------------------------------------------------------------
1 | package restatecontext
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "github.com/restatedev/sdk-go/internal/errors"
7 | pbinternal "github.com/restatedev/sdk-go/internal/generated"
8 | "github.com/restatedev/sdk-go/internal/log"
9 | "github.com/restatedev/sdk-go/internal/statemachine"
10 | "io"
11 | "log/slog"
12 | "runtime/debug"
13 | )
14 |
// ExecuteInvocation reads the invocation input from stateMachine. On success it
// builds a restate context and runs handler through invoke, then returns nil.
//
// If reading the input fails, consumeOutput is still run against conn, and the
// input error is returned. If consuming the output also fails, that error is
// returned instead. Both failures are logged as warnings.
func ExecuteInvocation(ctx context.Context, logger *slog.Logger, stateMachine *statemachine.StateMachine, conn io.ReadWriteCloser, handler Handler, dropReplayLogs bool, logHandler slog.Handler, attemptHeaders map[string][]string) error {
	// Let's read the input entry
	invocationInput, err := stateMachine.SysInput(ctx)
	if err != nil {
		logger.WarnContext(ctx, "Error when reading invocation input", log.Error(err))
		// Keep the input error separate: assigning the consumeOutput result to
		// err would make this path return nil whenever consuming succeeds.
		if consumeErr := consumeOutput(ctx, stateMachine, conn); consumeErr != nil {
			logger.WarnContext(ctx, "Error when consuming output", log.Error(consumeErr))
			return consumeErr
		}
		return err
	}

	// Instantiate the restate context
	restateCtx := newContext(ctx, stateMachine, invocationInput, conn, attemptHeaders, dropReplayLogs, logHandler)

	// Invoke the handler
	invoke(restateCtx, handler, logger)
	return nil
}
34 |
35 | func invoke(restateCtx *ctx, handler Handler, logger *slog.Logger) {
36 | // Run read loop on a goroutine
37 | go func(restateCtx *ctx, logger *slog.Logger) { restateCtx.readInputLoop(logger) }(restateCtx, logger)
38 |
39 | defer func() {
40 | // recover will return a non-nil object
41 | // if there was a panic
42 | //
43 | recovered := recover()
44 |
45 | switch typ := recovered.(type) {
46 | case nil:
47 | // nothing to do, just exit
48 | break
49 | case *statemachine.SuspensionError:
50 | case statemachine.SuspensionError:
51 | restateCtx.internalLogger.LogAttrs(restateCtx, slog.LevelInfo, "Suspending invocation")
52 | break
53 | default:
54 | restateCtx.internalLogger.LogAttrs(restateCtx, slog.LevelError, "Invocation panicked, returning error to Restate", slog.Any("err", typ))
55 |
56 | if err := restateCtx.stateMachine.NotifyError(restateCtx, fmt.Sprint(typ), string(debug.Stack())); err != nil {
57 | restateCtx.internalLogger.WarnContext(restateCtx, "Error when notifying error to state restateContext", log.Error(err))
58 | }
59 |
60 | break
61 | }
62 |
63 | // Consume all the state restateContext output as last step
64 | if err := consumeOutput(restateCtx, restateCtx.stateMachine, restateCtx.conn); err != nil {
65 | restateCtx.internalLogger.WarnContext(restateCtx, "Error when consuming output", log.Error(err))
66 | }
67 | }()
68 |
69 | restateCtx.internalLogger.InfoContext(restateCtx, "Handling invocation")
70 |
71 | var bytes []byte
72 | var err error
73 | bytes, err = handler.Call(restateCtx, restateCtx.request.Body)
74 |
75 | if err != nil && errors.IsTerminalError(err) {
76 | restateCtx.internalLogger.LogAttrs(restateCtx, slog.LevelError, "Invocation returned a terminal failure", log.Error(err))
77 |
78 | failure := pbinternal.Failure{}
79 | failure.SetCode(uint32(errors.ErrorCode(err)))
80 | failure.SetMessage(err.Error())
81 | outputParameters := pbinternal.VmSysWriteOutputParameters{}
82 | outputParameters.SetFailure(&failure)
83 | if err := restateCtx.stateMachine.SysWriteOutput(restateCtx, &outputParameters); err != nil {
84 | // This is handled by the panic catcher above
85 | panic(err)
86 | }
87 | } else if err != nil {
88 | restateCtx.internalLogger.LogAttrs(restateCtx, slog.LevelError, "Invocation returned a non-terminal failure", log.Error(err))
89 |
90 | // This is handled by the panic catcher above
91 | panic(err)
92 | } else {
93 | restateCtx.internalLogger.InfoContext(restateCtx, "Invocation completed successfully")
94 |
95 | outputParameters := pbinternal.VmSysWriteOutputParameters{}
96 | outputParameters.SetSuccess(bytes)
97 | if err := restateCtx.stateMachine.SysWriteOutput(restateCtx, &outputParameters); err != nil {
98 | // This is handled by the panic catcher above
99 | panic(err)
100 | }
101 | }
102 |
103 | // Sys_end the state restateContext
104 | if err := restateCtx.stateMachine.SysEnd(restateCtx); err != nil {
105 | // This is handled by the panic catcher above
106 | panic(err)
107 | }
108 | }
109 |
--------------------------------------------------------------------------------
/test-services/proxy.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "time"
5 |
6 | restate "github.com/restatedev/sdk-go"
7 | "github.com/restatedev/sdk-go/internal/options"
8 | )
9 |
10 | type ProxyRequest struct {
11 | ServiceName string `json:"serviceName"`
12 | VirtualObjectKey *string `json:"virtualObjectKey,omitempty"`
13 | HandlerName string `json:"handlerName"`
14 | // We need to use []int because Golang takes the opinionated choice of treating []byte as Base64
15 | Message []int `json:"message"`
16 | IdempotencyKey *string `json:"idempotencyKey,omitempty"`
17 | DelayMillis *uint64 `json:"delayMillis,omitempty"`
18 | }
19 |
20 | func (req *ProxyRequest) ToTarget(ctx restate.Context) restate.Client[[]byte, []byte] {
21 | if req.VirtualObjectKey != nil {
22 | return restate.WithRequestType[[]byte](restate.Object[[]byte](
23 | ctx,
24 | req.ServiceName,
25 | *req.VirtualObjectKey,
26 | req.HandlerName,
27 | restate.WithBinary))
28 | } else {
29 | return restate.WithRequestType[[]byte](restate.Service[[]byte](
30 | ctx,
31 | req.ServiceName,
32 | req.HandlerName,
33 | restate.WithBinary))
34 | }
35 | }
36 |
37 | type ManyCallRequest struct {
38 | ProxyRequest ProxyRequest `json:"proxyRequest"`
39 | OneWayCall bool `json:"oneWayCall"`
40 | AwaitAtTheEnd bool `json:"awaitAtTheEnd"`
41 | }
42 |
43 | func init() {
44 | REGISTRY.AddDefinition(
45 | restate.NewService("Proxy").
46 | Handler("call", restate.NewServiceHandler(
47 | // We need to use []int because Golang takes the opinionated choice of treating []byte as Base64
48 | func(ctx restate.Context, req ProxyRequest) ([]int, error) {
49 | input := intArrayToByteArray(req.Message)
50 | var opts []options.RequestOption
51 | if req.IdempotencyKey != nil {
52 | opts = append(opts, restate.WithIdempotencyKey(*req.IdempotencyKey))
53 | }
54 | bytes, err := req.ToTarget(ctx).Request(input, opts...)
55 | return byteArrayToIntArray(bytes), err
56 | })).
57 | Handler("oneWayCall", restate.NewServiceHandler(
58 | // We need to use []int because Golang takes the opinionated choice of treating []byte as Base64
59 | func(ctx restate.Context, req ProxyRequest) (string, error) {
60 | input := intArrayToByteArray(req.Message)
61 | var opts []options.SendOption
62 | if req.IdempotencyKey != nil {
63 | opts = append(opts, restate.WithIdempotencyKey(*req.IdempotencyKey))
64 | }
65 | if req.DelayMillis != nil {
66 | opts = append(opts, restate.WithDelay(time.Millisecond*time.Duration(*req.DelayMillis)))
67 | }
68 | return req.ToTarget(ctx).Send(input, opts...).GetInvocationId(), nil
69 | })).
70 | Handler("manyCalls", restate.NewServiceHandler(
71 | // We need to use []int because Golang takes the opinionated choice of treating []byte as Base64
72 | func(ctx restate.Context, requests []ManyCallRequest) (restate.Void, error) {
73 | var toAwait []restate.Selectable
74 |
75 | for _, req := range requests {
76 | input := intArrayToByteArray(req.ProxyRequest.Message)
77 | if req.OneWayCall {
78 | var opts []options.SendOption
79 | if req.ProxyRequest.IdempotencyKey != nil {
80 | opts = append(opts, restate.WithIdempotencyKey(*req.ProxyRequest.IdempotencyKey))
81 | }
82 | if req.ProxyRequest.DelayMillis != nil {
83 | opts = append(opts, restate.WithDelay(time.Millisecond*time.Duration(*req.ProxyRequest.DelayMillis)))
84 | }
85 | req.ProxyRequest.ToTarget(ctx).Send(input, opts...)
86 | } else {
87 | var opts []options.RequestOption
88 | if req.ProxyRequest.IdempotencyKey != nil {
89 | opts = append(opts, restate.WithIdempotencyKey(*req.ProxyRequest.IdempotencyKey))
90 | }
91 | fut := req.ProxyRequest.ToTarget(ctx).RequestFuture(input, opts...)
92 | if req.AwaitAtTheEnd {
93 | toAwait = append(toAwait, fut)
94 | }
95 | }
96 | }
97 |
98 | for fut, err := range restate.Wait(ctx, toAwait...) {
99 | if err != nil {
100 | return restate.Void{}, err
101 | }
102 | if _, err := fut.(restate.ResponseFuture[[]byte]).Response(); err != nil {
103 | return restate.Void{}, err
104 | }
105 | }
106 |
107 | return restate.Void{}, nil
108 | })))
109 | }
110 |
// intArrayToByteArray converts each element of in to a byte. Values outside 0..255
// wrap according to Go's integer conversion rules. The result is never nil.
func intArrayToByteArray(in []int) []byte {
	result := make([]byte, 0, len(in))
	for _, v := range in {
		result = append(result, byte(v))
	}
	return result
}

// byteArrayToIntArray widens each byte of in to an int. The result is never nil.
func byteArrayToIntArray(in []byte) []int {
	result := make([]int, 0, len(in))
	for _, b := range in {
		result = append(result, int(b))
	}
	return result
}
126 |
--------------------------------------------------------------------------------
/examples/ticketreservation/main_test.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "log/slog"
5 | "testing"
6 | "time"
7 |
8 | "github.com/google/uuid"
9 | restate "github.com/restatedev/sdk-go"
10 | "github.com/restatedev/sdk-go/mocks"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func TestPayment(t *testing.T) {
15 | mockCtx := mocks.NewMockContext(t)
16 |
17 | mockCtx.EXPECT().MockRand().UUID().Return(uuid.Max)
18 |
19 | mockCtx.EXPECT().RunAndExpect(mockCtx, true, nil)
20 | mockCtx.EXPECT().Log().Return(slog.Default())
21 |
22 | resp, err := (&checkout{}).Payment(restate.WithMockContext(mockCtx), PaymentRequest{Tickets: []string{"abc"}})
23 | assert.NoError(t, err)
24 | assert.Equal(t, resp, PaymentResponse{ID: "ffffffff-ffff-ffff-ffff-ffffffffffff", Price: 30})
25 | }
26 |
27 | func TestReserve(t *testing.T) {
28 | mockCtx := mocks.NewMockContext(t)
29 |
30 | mockCtx.EXPECT().GetAndReturn("status", TicketAvailable)
31 | mockCtx.EXPECT().Set("status", TicketReserved)
32 |
33 | ok, err := (&ticketService{}).Reserve(restate.WithMockContext(mockCtx), restate.Void{})
34 | assert.NoError(t, err)
35 | assert.True(t, ok)
36 | }
37 |
38 | func TestUnreserve(t *testing.T) {
39 | mockCtx := mocks.NewMockContext(t)
40 |
41 | mockCtx.EXPECT().Key().Return("foo")
42 | mockCtx.EXPECT().Log().Return(slog.Default())
43 | mockCtx.EXPECT().GetAndReturn("status", TicketAvailable)
44 | mockCtx.EXPECT().Clear("status")
45 |
46 | _, err := (&ticketService{}).Unreserve(restate.WithMockContext(mockCtx), restate.Void{})
47 | assert.NoError(t, err)
48 | }
49 |
50 | func TestMarkAsSold(t *testing.T) {
51 | mockCtx := mocks.NewMockContext(t)
52 |
53 | mockCtx.EXPECT().Key().Return("foo")
54 | mockCtx.EXPECT().Log().Return(slog.Default())
55 | mockCtx.EXPECT().GetAndReturn("status", TicketReserved)
56 | mockCtx.EXPECT().Set("status", TicketSold)
57 |
58 | _, err := (&ticketService{}).MarkAsSold(restate.WithMockContext(mockCtx), restate.Void{})
59 | assert.NoError(t, err)
60 | }
61 |
62 | func TestStatus(t *testing.T) {
63 | mockCtx := mocks.NewMockContext(t)
64 |
65 | mockCtx.EXPECT().Key().Return("foo")
66 | mockCtx.EXPECT().Log().Return(slog.Default())
67 | mockCtx.EXPECT().GetAndReturn("status", TicketReserved)
68 |
69 | status, err := (&ticketService{}).Status(restate.WithMockContext(mockCtx), restate.Void{})
70 | assert.NoError(t, err)
71 | assert.Equal(t, status, TicketReserved)
72 | }
73 |
74 | func TestAddTicket(t *testing.T) {
75 | mockCtx := mocks.NewMockContext(t)
76 |
77 | mockCtx.EXPECT().Key().Return("userID")
78 | mockCtx.EXPECT().MockObjectClient(TicketServiceName, "ticket2", "Reserve").RequestAndReturn("userID", true, nil)
79 |
80 | mockCtx.EXPECT().GetAndReturn("tickets", []string{"ticket1"})
81 | mockCtx.EXPECT().Set("tickets", []string{"ticket1", "ticket2"})
82 | mockCtx.EXPECT().MockObjectClient(UserSessionServiceName, "userID", "ExpireTicket").
83 | MockSend("ticket2", restate.WithDelay(15*time.Minute))
84 |
85 | ok, err := (&userSession{}).AddTicket(restate.WithMockContext(mockCtx), "ticket2")
86 | assert.NoError(t, err)
87 | assert.True(t, ok)
88 | }
89 |
90 | func TestExpireTicket(t *testing.T) {
91 | mockCtx := mocks.NewMockContext(t)
92 |
93 | mockCtx.EXPECT().GetAndReturn("tickets", []string{"ticket1", "ticket2"})
94 | mockCtx.EXPECT().Set("tickets", []string{"ticket1"})
95 |
96 | mockCtx.EXPECT().MockObjectClient(TicketServiceName, "ticket2", "Unreserve").MockSend(restate.Void{})
97 |
98 | _, err := (&userSession{}).ExpireTicket(restate.WithMockContext(mockCtx), "ticket2")
99 | assert.NoError(t, err)
100 | }
101 |
102 | func TestCheckout(t *testing.T) {
103 | mockCtx := mocks.NewMockContext(t)
104 |
105 | mockCtx.EXPECT().Key().Return("userID")
106 | mockCtx.EXPECT().GetAndReturn("tickets", []string{"ticket1"})
107 | mockCtx.EXPECT().Log().Return(slog.Default())
108 |
109 | mockAfter := mockCtx.EXPECT().MockAfter(time.Minute)
110 |
111 | mockResponseFuture := mockCtx.EXPECT().MockObjectClient(CheckoutServiceName, "", "Payment").
112 | MockResponseFuture(PaymentRequest{UserID: "userID", Tickets: []string{"ticket1"}})
113 |
114 | mockWaitIter := mockCtx.EXPECT().MockWaitIter(mockAfter, mockResponseFuture)
115 | mockWaitIter.Next().Return(true)
116 | mockWaitIter.Err().Return(nil)
117 | mockWaitIter.Value().Return(mockResponseFuture)
118 |
119 | mockResponseFuture.EXPECT().ResponseAndReturn(PaymentResponse{ID: "paymentID", Price: 30}, nil)
120 |
121 | mockCtx.EXPECT().MockObjectClient(TicketServiceName, "ticket1", "MarkAsSold").MockSend(restate.Void{})
122 |
123 | mockCtx.EXPECT().Clear("tickets")
124 |
125 | ok, err := (&userSession{}).Checkout(restate.WithMockContext(mockCtx), restate.Void{})
126 | assert.NoError(t, err)
127 | assert.True(t, ok)
128 | }
129 |
--------------------------------------------------------------------------------
/examples/codegen/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log/slog"
7 | "os"
8 | "time"
9 |
10 | restate "github.com/restatedev/sdk-go"
11 | helloworld "github.com/restatedev/sdk-go/examples/codegen/proto"
12 | "github.com/restatedev/sdk-go/ingress"
13 | "github.com/restatedev/sdk-go/server"
14 | )
15 |
16 | type greeter struct {
17 | helloworld.UnimplementedGreeterServer
18 | }
19 |
20 | func (greeter) SayHello(ctx restate.Context, req *helloworld.HelloRequest) (*helloworld.HelloResponse, error) {
21 | // Example usage of the generated client between services
22 | counter := helloworld.NewCounterClient(ctx, req.Name)
23 | count, err := counter.Add().
24 | Request(&helloworld.AddRequest{Delta: 1})
25 | if err != nil {
26 | return nil, err
27 | }
28 |
29 | return &helloworld.HelloResponse{
30 | Message: fmt.Sprintf("Hello, %s! Call number: %d", req.Name, count.Value),
31 | }, nil
32 | }
33 |
34 | type counter struct {
35 | helloworld.UnimplementedCounterServer
36 | }
37 |
38 | func (c counter) Add(ctx restate.ObjectContext, req *helloworld.AddRequest) (*helloworld.GetResponse, error) {
39 | count, err := restate.Get[int64](ctx, "counter")
40 | if err != nil {
41 | return nil, err
42 | }
43 |
44 | count += req.Delta
45 | restate.Set(ctx, "counter", count)
46 |
47 | return &helloworld.GetResponse{Value: count}, nil
48 | }
49 |
50 | func (c counter) Get(ctx restate.ObjectSharedContext, _ *helloworld.GetRequest) (*helloworld.GetResponse, error) {
51 | count, err := restate.Get[int64](ctx, "counter")
52 | if err != nil {
53 | return nil, err
54 | }
55 |
56 | return &helloworld.GetResponse{Value: count}, nil
57 | }
58 |
59 | type workflow struct {
60 | helloworld.UnimplementedWorkflowServer
61 | }
62 |
63 | func (workflow) Run(ctx restate.WorkflowContext, _ *helloworld.RunRequest) (*helloworld.RunResponse, error) {
64 | restate.Set(ctx, "status", "waiting")
65 | _, err := restate.Promise[restate.Void](ctx, "promise").Result()
66 | if err != nil {
67 | return nil, err
68 | }
69 | restate.Set(ctx, "status", "finished")
70 | return &helloworld.RunResponse{Status: "finished"}, nil
71 | }
72 |
73 | func (workflow) Finish(ctx restate.WorkflowSharedContext, _ *helloworld.FinishRequest) (*helloworld.FinishResponse, error) {
74 | return nil, restate.Promise[restate.Void](ctx, "promise").Resolve(restate.Void{})
75 | }
76 |
77 | func (workflow) Status(ctx restate.WorkflowSharedContext, _ *helloworld.StatusRequest) (*helloworld.StatusResponse, error) {
78 | status, err := restate.Get[string](ctx, "status")
79 | if err != nil {
80 | return nil, err
81 | }
82 | return &helloworld.StatusResponse{Status: status}, nil
83 | }
84 |
85 | func main() {
86 | server := server.NewRestate().
87 | Bind(helloworld.NewGreeterServer(greeter{})).
88 | Bind(helloworld.NewCounterServer(counter{})).
89 | Bind(helloworld.NewWorkflowServer(workflow{}))
90 |
91 | go func() {
92 | ctx := context.Background()
93 | time.Sleep(15 * time.Second)
94 |
95 | // Example usage of the generated ingress client
96 |
97 | client := ingress.NewClient("http://localhost:8080")
98 |
99 | counterClient := helloworld.NewCounterIngressClient(client, "fra")
100 |
101 | addSendRes, err := counterClient.Add().Send(ctx, &helloworld.AddRequest{Delta: 1}, restate.WithDelay(10*time.Second))
102 | if err != nil {
103 | slog.Error("failed to send request", "err", err.Error())
104 | os.Exit(1)
105 | }
106 | out, err := addSendRes.Attach(ctx)
107 | if err != nil {
108 | slog.Error("failed to attach response", "err", err.Error())
109 | os.Exit(1)
110 | }
111 | slog.Info("client attached response", "out.value", out.Value)
112 |
113 | out, err = counterClient.Get().Request(ctx, &helloworld.GetRequest{})
114 | if err != nil {
115 | slog.Error("failed to get response", "err", err.Error())
116 | os.Exit(1)
117 | }
118 | slog.Info("client get response", "out.value", out.Value)
119 |
120 | wf := helloworld.NewWorkflowIngressClient(client, "123")
121 | submitRes, err := wf.Submit(ctx, &helloworld.RunRequest{})
122 | if err != nil {
123 | slog.Error("failed to submit workflow", "err", err.Error())
124 | os.Exit(1)
125 | }
126 | slog.Info("started wf with invocation id " + addSendRes.Id())
127 |
128 | _, err = wf.Finish().Request(ctx, &helloworld.FinishRequest{})
129 | if err != nil {
130 | slog.Error("failed to finish workflow", "err", err.Error())
131 | os.Exit(1)
132 | }
133 |
134 | wfOut, err := submitRes.Attach(ctx)
135 | if err != nil {
136 | slog.Error("failed to attach response", "err", err.Error())
137 | os.Exit(1)
138 | }
139 | slog.Info("client attached response", "out.value", wfOut.Status)
140 | }()
141 |
142 | if err := server.Start(context.Background(), ":9080"); err != nil {
143 | slog.Error("application exited unexpectedly", "err", err.Error())
144 | os.Exit(1)
145 | }
146 | }
147 |
--------------------------------------------------------------------------------
/ingress/send_requester.go:
--------------------------------------------------------------------------------
1 | package ingress
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/restatedev/sdk-go/encoding"
7 | "github.com/restatedev/sdk-go/internal/ingress"
8 | "github.com/restatedev/sdk-go/internal/options"
9 | )
10 |
11 | // SimpleSendResponse represents the result of a send-only invocation (fire-and-forget).
12 | // It provides the invocation ID and status without requiring the output type parameter.
13 | //
14 | // If you need to attach to the invocation later to retrieve its output, you can:
15 | // 1. Create an InvocationHandle using InvocationById with the Id() from this response, or
16 | // 2. Use Service/Object/Workflow functions instead of ServiceSend/ObjectSend/WorkflowSend,
17 | // which return a full SendResponse[O] that includes an InvocationHandle.
18 | type SimpleSendResponse interface {
19 | Id() string
20 | Status() string
21 | }
22 |
23 | // SendRequester is a simplified version of Requester that only supports Send operations (fire-and-forget).
24 | // Unlike Requester, it does not require specifying the output type parameter, making it useful when you
25 | // don't need to retrieve the invocation result.
26 | //
27 | // If you need to later retrieve the output, use InvocationById with the Id() from SimpleSendResponse,
28 | // or use Service/Object/Workflow functions instead which return SendResponse[O] with an InvocationHandle.
29 | type SendRequester[I any] interface {
30 | Send(ctx context.Context, input I, options ...options.IngressSendOption) (SimpleSendResponse, error)
31 | }
32 |
33 | // ServiceSend gets a send-only ingress client for a Restate service handler.
34 | //
35 | // This is a simplified version of Service that doesn't require the output type generic parameter.
36 | // Use this when you only need to fire-and-forget invocations and don't need to retrieve results.
37 | //
38 | // Example:
39 | //
40 | // requester := ingress.ServiceSend[*MyInput](client, "MyService", "myHandler")
41 | // response, err := requester.Send(ctx, &MyInput{...})
42 | // fmt.Println("Invocation ID:", response.Id())
43 | func ServiceSend[I any](c *Client, serviceName, handlerName string) SendRequester[I] {
44 | return sendRequester[I]{
45 | client: c,
46 | params: ingress.IngressParams{
47 | Service: serviceName,
48 | Handler: handlerName,
49 | },
50 | }
51 | }
52 |
53 | // ObjectSend gets a send-only ingress client for a Restate virtual object handler.
54 | //
55 | // This is a simplified version of Object that doesn't require the output type generic parameter.
56 | // Use this when you only need to fire-and-forget invocations and don't need to retrieve results.
57 | //
58 | // Example:
59 | //
60 | // requester := ingress.ObjectSend[*MyInput](client, "MyObject", "object-123", "myHandler")
61 | // response, err := requester.Send(ctx, &MyInput{...})
62 | // fmt.Println("Invocation ID:", response.Id())
63 | func ObjectSend[I any](c *Client, serviceName, objectKey, handlerName string) SendRequester[I] {
64 | return sendRequester[I]{
65 | client: c,
66 | params: ingress.IngressParams{
67 | Service: serviceName,
68 | Key: objectKey,
69 | Handler: handlerName,
70 | },
71 | }
72 | }
73 |
74 | // WorkflowSend gets a send-only ingress client for a Restate workflow handler.
75 | //
76 | // This is a simplified version of Workflow that doesn't require the output type generic parameter.
77 | // Use this when you only need to fire-and-forget invocations and don't need to retrieve results.
78 | //
79 | // Example:
80 | //
81 | // requester := ingress.WorkflowSend[*MyInput](client, "MyWorkflow", "workflow-123", "myHandler")
82 | // response, err := requester.Send(ctx, &MyInput{...})
83 | // fmt.Println("Invocation ID:", response.Id())
84 | func WorkflowSend[I any](c *Client, serviceName, workflowID, handlerName string) SendRequester[I] {
85 | return sendRequester[I]{
86 | client: c,
87 | params: ingress.IngressParams{
88 | Service: serviceName,
89 | Handler: handlerName,
90 | Key: workflowID,
91 | },
92 | }
93 | }
94 |
95 | type sendRequester[I any] struct {
96 | client *Client
97 | params ingress.IngressParams
98 | codec encoding.PayloadCodec
99 | }
100 |
101 | type simpleSendResponse struct {
102 | ingress.Invocation
103 | }
104 |
105 | func (s simpleSendResponse) Id() string {
106 | return s.Invocation.Id
107 | }
108 |
109 | func (s simpleSendResponse) Status() string {
110 | return s.Invocation.Status
111 | }
112 |
113 | // Send calls the ingress API with the given input and returns an Invocation instance.
114 | func (c sendRequester[I]) Send(ctx context.Context, input I, opts ...options.IngressSendOption) (SimpleSendResponse, error) {
115 | sendOpts := options.IngressSendOptions{}
116 | sendOpts.Codec = c.codec
117 | for _, opt := range opts {
118 | opt.BeforeIngressSend(&sendOpts)
119 | }
120 |
121 | inv, err := c.client.Send(ctx, c.params, input, sendOpts)
122 | if err != nil {
123 | return nil, err
124 | }
125 |
126 | return simpleSendResponse{inv}, nil
127 | }
128 |
--------------------------------------------------------------------------------
/internal/restatecontext/async_results.go:
--------------------------------------------------------------------------------
1 | package restatecontext
2 |
3 | import (
4 | "fmt"
5 | "sync"
6 | "sync/atomic"
7 |
8 | "github.com/restatedev/sdk-go/internal/errors"
9 | pbinternal "github.com/restatedev/sdk-go/internal/generated"
10 | "github.com/restatedev/sdk-go/internal/statemachine"
11 | )
12 |
13 | var CancelledFailureValue = func() statemachine.Value {
14 | failure := pbinternal.Failure{}
15 | failure.SetCode(409)
16 | failure.SetMessage("Cancelled")
17 | return statemachine.ValueFailure{Failure: &failure}
18 | }()
19 |
20 | func errorFromFailure(failure statemachine.ValueFailure) error {
21 | return &errors.CodeError{Inner: &errors.TerminalError{Inner: fmt.Errorf("%s", failure.Failure.GetMessage())}, Code: errors.Code(failure.Failure.GetCode())}
22 | }
23 |
24 | type Selectable interface {
25 | handle() uint32
26 | }
27 |
28 | type asyncResult struct {
29 | ctx *ctx
30 | coreHandle uint32
31 | poll sync.Once
32 | result atomic.Value // statemachine.Value
33 | }
34 |
35 | func newAsyncResult(ctx *ctx, handle uint32) asyncResult {
36 | return asyncResult{
37 | ctx: ctx,
38 | coreHandle: handle,
39 | }
40 | }
41 |
42 | func (a *asyncResult) handle() uint32 {
43 | return a.coreHandle
44 | }
45 |
46 | func (a *asyncResult) pollProgress() {
47 | if a.result.Load() != nil {
48 | return
49 | }
50 | a.poll.Do(func() {
51 | cancelled := a.ctx.pollProgress([]uint32{a.coreHandle})
52 | if cancelled {
53 | a.result.Store(CancelledFailureValue)
54 | } else {
55 | value, err := a.ctx.stateMachine.TakeNotification(a.ctx, a.coreHandle)
56 | if value == nil {
57 | panic("The value should not be nil anymore")
58 | }
59 | if err != nil {
60 | panic(err)
61 | }
62 | a.result.Store(value)
63 | }
64 | })
65 | }
66 |
67 | func (a *asyncResult) mustLoadValue() statemachine.Value {
68 | value := a.result.Load()
69 | if value == nil {
70 | panic("value is not expected to be nil at this point")
71 | }
72 | return value.(statemachine.Value)
73 | }
74 |
75 | func (a *asyncResult) pollProgressAndLoadValue() statemachine.Value {
76 | a.pollProgress()
77 | return a.mustLoadValue()
78 | }
79 |
80 | func (restateCtx *ctx) pollProgress(handles []uint32) bool {
81 | // Pump output once
82 | if err := takeOutputAndWriteOut(restateCtx, restateCtx.stateMachine, restateCtx.conn); err != nil {
83 | panic(err)
84 | }
85 |
86 | for {
87 | progressResult, err := restateCtx.stateMachine.DoProgress(restateCtx, handles)
88 | if err != nil {
89 | panic(err)
90 | }
91 | if _, ok := progressResult.(statemachine.DoProgressAnyCompleted); ok {
92 | return false
93 | }
94 | _, isPendingRun := progressResult.(statemachine.DoProgressWaitingPendingRun)
95 | _, isReadFromInput := progressResult.(statemachine.DoProgressReadFromInput)
96 | if isPendingRun || isReadFromInput {
97 | // Either wait for at least one read or for run proposals
98 | select {
99 | case readRes, ok := <-restateCtx.readChan:
100 | if !ok {
101 | // Got EOF, notify and break
102 | if err = restateCtx.stateMachine.NotifyInputClosed(restateCtx); err != nil {
103 | panic(err)
104 | }
105 | break
106 | }
107 | if err = restateCtx.stateMachine.NotifyInput(restateCtx, readRes.buf[0:readRes.nRead]); err != nil {
108 | panic(err)
109 | }
110 | BufPool.Put(readRes.buf)
111 | break
112 | case proposal := <-restateCtx.runClosureCompletions:
113 | // Propose completion
114 | if err := restateCtx.stateMachine.ProposeRunCompletion(restateCtx, proposal); err != nil {
115 | panic(err)
116 | }
117 |
118 | // Pump output once. This is needed for the run completion to be effectively written
119 | if err := takeOutputAndWriteOut(restateCtx, restateCtx.stateMachine, restateCtx.conn); err != nil {
120 | panic(err)
121 | }
122 | break
123 | case <-restateCtx.Done():
124 | panic(restateCtx.Err())
125 | }
126 | }
127 | if _, ok := progressResult.(statemachine.DoProgressCancelSignalReceived); ok {
128 | // Pump output once. This is needed for cancel commands to be effectively written
129 | if err := takeOutputAndWriteOut(restateCtx, restateCtx.stateMachine, restateCtx.conn); err != nil {
130 | panic(err)
131 | }
132 |
133 | return true
134 | }
135 | if executeRun, ok := progressResult.(statemachine.DoProgressExecuteRun); ok {
136 | closure, ok := restateCtx.runClosures[executeRun.Handle]
137 | if !ok {
138 | panic(fmt.Sprintf("Need to run a Run closure with coreHandle %d, but it doesn't exist. This is an SDK bug.", executeRun.Handle))
139 | }
140 |
141 | // Delete this closure from the running list
142 | delete(restateCtx.runClosures, executeRun.Handle)
143 |
144 | // Run closure in a separate goroutine, proposing the result to runClosureCompletions
145 | go func(runClosureCompletions chan *pbinternal.VmProposeRunCompletionParameters, closure func() *pbinternal.VmProposeRunCompletionParameters) {
146 | runClosureCompletions <- closure()
147 | }(restateCtx.runClosureCompletions, closure)
148 | }
149 | }
150 | }
151 |
--------------------------------------------------------------------------------
/mocks/mock_WaitIterator.go:
--------------------------------------------------------------------------------
1 | // Code generated by mockery v2.52.2. DO NOT EDIT.
2 |
3 | package mocks
4 |
5 | import (
6 | restatecontext "github.com/restatedev/sdk-go/internal/restatecontext"
7 | mock "github.com/stretchr/testify/mock"
8 | )
9 |
10 | // MockWaitIterator is an autogenerated mock of the WaitIterator type; configure expected calls through EXPECT().
11 | type MockWaitIterator struct {
12 | mock.Mock
13 | }
14 |
15 | type MockWaitIterator_Expecter struct {
16 | mock *mock.Mock
17 | }
18 |
19 | func (_m *MockWaitIterator) EXPECT() *MockWaitIterator_Expecter {
20 | return &MockWaitIterator_Expecter{mock: &_m.Mock}
21 | }
22 |
23 | // Err provides a mock function with no fields. It returns the configured error (calling it first if a func() error was configured) and panics if no return value was configured.
24 | func (_m *MockWaitIterator) Err() error {
25 | ret := _m.Called()
26 |
27 | if len(ret) == 0 {
28 | panic("no return value specified for Err")
29 | }
30 |
31 | var r0 error
32 | if rf, ok := ret.Get(0).(func() error); ok {
33 | r0 = rf()
34 | } else {
35 | r0 = ret.Error(0)
36 | }
37 |
38 | return r0
39 | }
40 |
41 | // MockWaitIterator_Err_Call is a *mock.Call that shadows the Run/Return methods with type-explicit versions for method 'Err'.
42 | type MockWaitIterator_Err_Call struct {
43 | *mock.Call
44 | }
45 |
46 | // Err registers an expected call to Err (a typed wrapper around mock.On("Err")).
47 | func (_e *MockWaitIterator_Expecter) Err() *MockWaitIterator_Err_Call {
48 | return &MockWaitIterator_Err_Call{Call: _e.mock.On("Err")}
49 | }
50 |
51 | func (_c *MockWaitIterator_Err_Call) Run(run func()) *MockWaitIterator_Err_Call {
52 | _c.Call.Run(func(args mock.Arguments) {
53 | run()
54 | })
55 | return _c
56 | }
57 |
58 | func (_c *MockWaitIterator_Err_Call) Return(_a0 error) *MockWaitIterator_Err_Call {
59 | _c.Call.Return(_a0)
60 | return _c
61 | }
62 |
63 | func (_c *MockWaitIterator_Err_Call) RunAndReturn(run func() error) *MockWaitIterator_Err_Call {
64 | _c.Call.Return(run)
65 | return _c
66 | }
67 |
68 | // Next provides a mock function with no fields. It returns the configured bool (calling it first if a func() bool was configured) and panics if no return value was configured.
69 | func (_m *MockWaitIterator) Next() bool {
70 | ret := _m.Called()
71 |
72 | if len(ret) == 0 {
73 | panic("no return value specified for Next")
74 | }
75 |
76 | var r0 bool
77 | if rf, ok := ret.Get(0).(func() bool); ok {
78 | r0 = rf()
79 | } else {
80 | r0 = ret.Get(0).(bool)
81 | }
82 |
83 | return r0
84 | }
85 |
86 | // MockWaitIterator_Next_Call is a *mock.Call that shadows the Run/Return methods with type-explicit versions for method 'Next'.
87 | type MockWaitIterator_Next_Call struct {
88 | *mock.Call
89 | }
90 |
91 | // Next registers an expected call to Next (a typed wrapper around mock.On("Next")).
92 | func (_e *MockWaitIterator_Expecter) Next() *MockWaitIterator_Next_Call {
93 | return &MockWaitIterator_Next_Call{Call: _e.mock.On("Next")}
94 | }
95 |
96 | func (_c *MockWaitIterator_Next_Call) Run(run func()) *MockWaitIterator_Next_Call {
97 | _c.Call.Run(func(args mock.Arguments) {
98 | run()
99 | })
100 | return _c
101 | }
102 |
103 | func (_c *MockWaitIterator_Next_Call) Return(_a0 bool) *MockWaitIterator_Next_Call {
104 | _c.Call.Return(_a0)
105 | return _c
106 | }
107 |
108 | func (_c *MockWaitIterator_Next_Call) RunAndReturn(run func() bool) *MockWaitIterator_Next_Call {
109 | _c.Call.Return(run)
110 | return _c
111 | }
112 |
113 | // Value provides a mock function with no fields. It returns the configured Selectable (calling it first if a func was configured), returns nil when nil was configured, and panics if no return value was configured.
114 | func (_m *MockWaitIterator) Value() restatecontext.Selectable {
115 | ret := _m.Called()
116 |
117 | if len(ret) == 0 {
118 | panic("no return value specified for Value")
119 | }
120 |
121 | var r0 restatecontext.Selectable
122 | if rf, ok := ret.Get(0).(func() restatecontext.Selectable); ok {
123 | r0 = rf()
124 | } else {
125 | if ret.Get(0) != nil {
126 | r0 = ret.Get(0).(restatecontext.Selectable)
127 | }
128 | }
129 |
130 | return r0
131 | }
132 |
133 | // MockWaitIterator_Value_Call is a *mock.Call that shadows the Run/Return methods with type-explicit versions for method 'Value'.
134 | type MockWaitIterator_Value_Call struct {
135 | *mock.Call
136 | }
137 |
138 | // Value registers an expected call to Value (a typed wrapper around mock.On("Value")).
139 | func (_e *MockWaitIterator_Expecter) Value() *MockWaitIterator_Value_Call {
140 | return &MockWaitIterator_Value_Call{Call: _e.mock.On("Value")}
141 | }
142 |
143 | func (_c *MockWaitIterator_Value_Call) Run(run func()) *MockWaitIterator_Value_Call {
144 | _c.Call.Run(func(args mock.Arguments) {
145 | run()
146 | })
147 | return _c
148 | }
149 |
150 | func (_c *MockWaitIterator_Value_Call) Return(_a0 restatecontext.Selectable) *MockWaitIterator_Value_Call {
151 | _c.Call.Return(_a0)
152 | return _c
153 | }
154 |
155 | func (_c *MockWaitIterator_Value_Call) RunAndReturn(run func() restatecontext.Selectable) *MockWaitIterator_Value_Call {
156 | _c.Call.Return(run)
157 | return _c
158 | }
159 |
160 | // NewMockWaitIterator creates a new MockWaitIterator. It attaches the testing interface to the mock and registers a cleanup function that asserts the mock's expectations.
161 | // The first argument is typically a *testing.T value.
162 | func NewMockWaitIterator(t interface {
163 | mock.TestingT
164 | Cleanup(func())
165 | }) *MockWaitIterator {
166 | mock := &MockWaitIterator{}
167 | mock.Mock.Test(t)
168 |
169 | t.Cleanup(func() { mock.AssertExpectations(t) })
170 |
171 | return mock
172 | }
173 |
--------------------------------------------------------------------------------
/mocks/mock_AwakeableFuture.go:
--------------------------------------------------------------------------------
1 | // Code generated by mockery v2.52.1. DO NOT EDIT.
2 |
3 | package mocks
4 |
5 | import mock "github.com/stretchr/testify/mock"
6 |
7 | /* moved to helpers.go
8 | // MockAwakeableFuture is an autogenerated mock type for the AwakeableFuture type
9 | type MockAwakeableFuture struct {
10 | mock.Mock
11 | }
12 | */
13 |
14 | type MockAwakeableFuture_Expecter struct {
15 | mock *mock.Mock
16 | }
17 |
18 | func (_m *MockAwakeableFuture) EXPECT() *MockAwakeableFuture_Expecter {
19 | return &MockAwakeableFuture_Expecter{mock: &_m.Mock}
20 | }
21 |
22 | // Id provides a mock function with no fields
23 | func (_m *MockAwakeableFuture) Id() string {
24 | ret := _m.Called()
25 |
26 | if len(ret) == 0 {
27 | panic("no return value specified for Id")
28 | }
29 |
30 | var r0 string
31 | if rf, ok := ret.Get(0).(func() string); ok {
32 | r0 = rf()
33 | } else {
34 | r0 = ret.Get(0).(string)
35 | }
36 |
37 | return r0
38 | }
39 |
40 | // MockAwakeableFuture_Id_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Id'
41 | type MockAwakeableFuture_Id_Call struct {
42 | *mock.Call
43 | }
44 |
45 | // Id is a helper method to define mock.On call
46 | func (_e *MockAwakeableFuture_Expecter) Id() *MockAwakeableFuture_Id_Call {
47 | return &MockAwakeableFuture_Id_Call{Call: _e.mock.On("Id")}
48 | }
49 |
50 | func (_c *MockAwakeableFuture_Id_Call) Run(run func()) *MockAwakeableFuture_Id_Call {
51 | _c.Call.Run(func(args mock.Arguments) {
52 | run()
53 | })
54 | return _c
55 | }
56 |
57 | func (_c *MockAwakeableFuture_Id_Call) Return(_a0 string) *MockAwakeableFuture_Id_Call {
58 | _c.Call.Return(_a0)
59 | return _c
60 | }
61 |
62 | func (_c *MockAwakeableFuture_Id_Call) RunAndReturn(run func() string) *MockAwakeableFuture_Id_Call {
63 | _c.Call.Return(run)
64 | return _c
65 | }
66 |
67 | // Result provides a mock function with given fields: output
68 | func (_m *MockAwakeableFuture) Result(output any) error {
69 | ret := _m.Called(output)
70 |
71 | if len(ret) == 0 {
72 | panic("no return value specified for Result")
73 | }
74 |
75 | var r0 error
76 | if rf, ok := ret.Get(0).(func(any) error); ok {
77 | r0 = rf(output)
78 | } else {
79 | r0 = ret.Error(0)
80 | }
81 |
82 | return r0
83 | }
84 |
85 | // MockAwakeableFuture_Result_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Result'
86 | type MockAwakeableFuture_Result_Call struct {
87 | *mock.Call
88 | }
89 |
90 | // Result is a helper method to define mock.On call
91 | // - output any
92 | func (_e *MockAwakeableFuture_Expecter) Result(output interface{}) *MockAwakeableFuture_Result_Call {
93 | return &MockAwakeableFuture_Result_Call{Call: _e.mock.On("Result", output)}
94 | }
95 |
96 | func (_c *MockAwakeableFuture_Result_Call) Run(run func(output any)) *MockAwakeableFuture_Result_Call {
97 | _c.Call.Run(func(args mock.Arguments) {
98 | run(args[0].(any))
99 | })
100 | return _c
101 | }
102 |
103 | func (_c *MockAwakeableFuture_Result_Call) Return(_a0 error) *MockAwakeableFuture_Result_Call {
104 | _c.Call.Return(_a0)
105 | return _c
106 | }
107 |
108 | func (_c *MockAwakeableFuture_Result_Call) RunAndReturn(run func(any) error) *MockAwakeableFuture_Result_Call {
109 | _c.Call.Return(run)
110 | return _c
111 | }
112 |
113 | // handle provides a mock function with no fields
114 | func (_m *MockAwakeableFuture) handle() uint32 {
115 | ret := _m.Called()
116 |
117 | if len(ret) == 0 {
118 | panic("no return value specified for handle")
119 | }
120 |
121 | var r0 uint32
122 | if rf, ok := ret.Get(0).(func() uint32); ok {
123 | r0 = rf()
124 | } else {
125 | r0 = ret.Get(0).(uint32)
126 | }
127 |
128 | return r0
129 | }
130 |
131 | // MockAwakeableFuture_handle_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'handle'
132 | type MockAwakeableFuture_handle_Call struct {
133 | *mock.Call
134 | }
135 |
136 | // handle is a helper method to define mock.On call
137 | func (_e *MockAwakeableFuture_Expecter) handle() *MockAwakeableFuture_handle_Call {
138 | return &MockAwakeableFuture_handle_Call{Call: _e.mock.On("handle")}
139 | }
140 |
141 | func (_c *MockAwakeableFuture_handle_Call) Run(run func()) *MockAwakeableFuture_handle_Call {
142 | _c.Call.Run(func(args mock.Arguments) {
143 | run()
144 | })
145 | return _c
146 | }
147 |
148 | func (_c *MockAwakeableFuture_handle_Call) Return(_a0 uint32) *MockAwakeableFuture_handle_Call {
149 | _c.Call.Return(_a0)
150 | return _c
151 | }
152 |
153 | func (_c *MockAwakeableFuture_handle_Call) RunAndReturn(run func() uint32) *MockAwakeableFuture_handle_Call {
154 | _c.Call.Return(run)
155 | return _c
156 | }
157 |
158 | // NewMockAwakeableFuture creates a new instance of MockAwakeableFuture. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
159 | // The first argument is typically a *testing.T value.
160 | func NewMockAwakeableFuture(t interface {
161 | mock.TestingT
162 | Cleanup(func())
163 | }) *MockAwakeableFuture {
164 | mock := &MockAwakeableFuture{}
165 | mock.Mock.Test(t)
166 |
167 | t.Cleanup(func() { mock.AssertExpectations(t) })
168 |
169 | return mock
170 | }
171 |
--------------------------------------------------------------------------------
/router.go:
--------------------------------------------------------------------------------
1 | package restate
2 |
3 | import (
4 | "github.com/restatedev/sdk-go/encoding"
5 | "github.com/restatedev/sdk-go/internal"
6 | "github.com/restatedev/sdk-go/internal/options"
7 | "github.com/restatedev/sdk-go/internal/restatecontext"
8 | )
9 |
10 | // ServiceDefinition is the set of methods implemented by both services and virtual objects
11 | type ServiceDefinition interface {
12 | // Name returns the name of the service described in this definition
13 | Name() string
14 | // Type returns the type of this service definition (Service or Virtual Object)
15 | Type() internal.ServiceType
16 | // Handlers returns the set of handlers associated with this service definition
17 | Handlers() map[string]restatecontext.Handler
18 | // GetOptions returns the configured options
19 | GetOptions() *options.ServiceDefinitionOptions
20 | // ConfigureHandler lets you customize the handler configuration, adding per handler options.
21 | // Panics if the handler doesn't exist.
22 | ConfigureHandler(name string, opts ...options.HandlerOption) ServiceDefinition
23 | }
24 |
25 | // serviceDefinition stores a list of handlers under a named service
26 | type serviceDefinition struct {
27 | name string
28 | handlers map[string]restatecontext.Handler
29 | options options.ServiceDefinitionOptions
30 | typ internal.ServiceType
31 | }
32 |
33 | var _ ServiceDefinition = &serviceDefinition{}
34 |
35 | func (r *serviceDefinition) Name() string {
36 | return r.name
37 | }
38 |
39 | func (r *serviceDefinition) Handlers() map[string]restatecontext.Handler {
40 | return r.handlers
41 | }
42 |
43 | func (r *serviceDefinition) GetOptions() *options.ServiceDefinitionOptions {
44 | return &r.options
45 | }
46 |
47 | func (r *serviceDefinition) Type() internal.ServiceType {
48 | return r.typ
49 | }
50 |
51 | func (r *serviceDefinition) ConfigureHandler(name string, opts ...options.HandlerOption) ServiceDefinition {
52 | handler := r.handlers[name]
53 | if handler == nil {
54 | panic("handler not found: " + name)
55 | }
56 | handlerOpts := handler.GetOptions()
57 | for _, opt := range opts {
58 | opt.BeforeHandler(handlerOpts)
59 | }
60 | return r
61 | }
62 |
63 | type service struct {
64 | serviceDefinition
65 | }
66 |
67 | // NewService creates a new named Service
68 | func NewService(name string, opts ...options.ServiceDefinitionOption) *service {
69 | o := options.ServiceDefinitionOptions{}
70 | for _, opt := range opts {
71 | opt.BeforeServiceDefinition(&o)
72 | }
73 | if o.DefaultCodec == nil {
74 | o.DefaultCodec = encoding.JSONCodec
75 | }
76 | if o.WorkflowRetention != nil {
77 | panic("Workflow retention can be set only for workflows")
78 | }
79 | return &service{
80 | serviceDefinition: serviceDefinition{
81 | name: name,
82 | handlers: make(map[string]restatecontext.Handler),
83 | options: o,
84 | typ: internal.ServiceType_SERVICE,
85 | },
86 | }
87 | }
88 |
89 | // Handler registers a new Service handler by name
90 | func (r *service) Handler(name string, handler restatecontext.Handler) *service {
91 | if handler.GetOptions().Codec == nil {
92 | handler.GetOptions().Codec = r.options.DefaultCodec
93 | }
94 | r.handlers[name] = handler
95 | return r
96 | }
97 |
98 | type object struct {
99 | serviceDefinition
100 | }
101 |
102 | // NewObject creates a new named Virtual Object
103 | func NewObject(name string, opts ...options.ServiceDefinitionOption) *object {
104 | o := options.ServiceDefinitionOptions{}
105 | for _, opt := range opts {
106 | opt.BeforeServiceDefinition(&o)
107 | }
108 | if o.DefaultCodec == nil {
109 | o.DefaultCodec = encoding.JSONCodec
110 | }
111 | if o.WorkflowRetention != nil {
112 | panic("Workflow retention can be set only for workflows")
113 | }
114 | return &object{
115 | serviceDefinition: serviceDefinition{
116 | name: name,
117 | handlers: make(map[string]restatecontext.Handler),
118 | options: o,
119 | typ: internal.ServiceType_VIRTUAL_OBJECT,
120 | },
121 | }
122 | }
123 |
124 | // Handler registers a new Virtual Object handler by name
125 | func (r *object) Handler(name string, handler restatecontext.Handler) *object {
126 | if handler.GetOptions().Codec == nil {
127 | handler.GetOptions().Codec = r.options.DefaultCodec
128 | }
129 | r.handlers[name] = handler
130 | return r
131 | }
132 |
133 | type workflow struct {
134 | serviceDefinition
135 | }
136 |
137 | // NewWorkflow creates a new named Workflow
138 | func NewWorkflow(name string, opts ...options.ServiceDefinitionOption) *workflow {
139 | o := options.ServiceDefinitionOptions{}
140 | for _, opt := range opts {
141 | opt.BeforeServiceDefinition(&o)
142 | }
143 | if o.DefaultCodec == nil {
144 | o.DefaultCodec = encoding.JSONCodec
145 | }
146 | return &workflow{
147 | serviceDefinition: serviceDefinition{
148 | name: name,
149 | handlers: make(map[string]restatecontext.Handler),
150 | options: o,
151 | typ: internal.ServiceType_WORKFLOW,
152 | },
153 | }
154 | }
155 |
156 | // Handler registers a new Workflow handler by name
157 | func (r *workflow) Handler(name string, handler restatecontext.Handler) *workflow {
158 | if handler.GetOptions().Codec == nil {
159 | handler.GetOptions().Codec = r.options.DefaultCodec
160 | }
161 | r.handlers[name] = handler
162 | return r
163 | }
164 |
--------------------------------------------------------------------------------
/mocks/mock_ResponseFuture.go:
--------------------------------------------------------------------------------
1 | // Code generated by mockery v2.52.1. DO NOT EDIT.
2 |
3 | package mocks
4 |
5 | import mock "github.com/stretchr/testify/mock"
6 |
7 | /* moved to helpers.go
8 | // MockResponseFuture is an autogenerated mock type for the ResponseFuture type
9 | type MockResponseFuture struct {
10 | mock.Mock
11 | }
12 | */
13 |
14 | type MockResponseFuture_Expecter struct {
15 | mock *mock.Mock
16 | }
17 |
18 | func (_m *MockResponseFuture) EXPECT() *MockResponseFuture_Expecter {
19 | return &MockResponseFuture_Expecter{mock: &_m.Mock}
20 | }
21 |
22 | // GetInvocationId provides a mock function with no fields
23 | func (_m *MockResponseFuture) GetInvocationId() string {
24 | ret := _m.Called()
25 |
26 | if len(ret) == 0 {
27 | panic("no return value specified for GetInvocationId")
28 | }
29 |
30 | var r0 string
31 | if rf, ok := ret.Get(0).(func() string); ok {
32 | r0 = rf()
33 | } else {
34 | r0 = ret.Get(0).(string)
35 | }
36 |
37 | return r0
38 | }
39 |
40 | // MockResponseFuture_GetInvocationId_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetInvocationId'
41 | type MockResponseFuture_GetInvocationId_Call struct {
42 | *mock.Call
43 | }
44 |
45 | // GetInvocationId is a helper method to define mock.On call
46 | func (_e *MockResponseFuture_Expecter) GetInvocationId() *MockResponseFuture_GetInvocationId_Call {
47 | return &MockResponseFuture_GetInvocationId_Call{Call: _e.mock.On("GetInvocationId")}
48 | }
49 |
50 | func (_c *MockResponseFuture_GetInvocationId_Call) Run(run func()) *MockResponseFuture_GetInvocationId_Call {
51 | _c.Call.Run(func(args mock.Arguments) {
52 | run()
53 | })
54 | return _c
55 | }
56 |
57 | func (_c *MockResponseFuture_GetInvocationId_Call) Return(_a0 string) *MockResponseFuture_GetInvocationId_Call {
58 | _c.Call.Return(_a0)
59 | return _c
60 | }
61 |
62 | func (_c *MockResponseFuture_GetInvocationId_Call) RunAndReturn(run func() string) *MockResponseFuture_GetInvocationId_Call {
63 | _c.Call.Return(run)
64 | return _c
65 | }
66 |
67 | // Response provides a mock function with given fields: output
68 | func (_m *MockResponseFuture) Response(output any) error {
69 | ret := _m.Called(output)
70 |
71 | if len(ret) == 0 {
72 | panic("no return value specified for Response")
73 | }
74 |
75 | var r0 error
76 | if rf, ok := ret.Get(0).(func(any) error); ok {
77 | r0 = rf(output)
78 | } else {
79 | r0 = ret.Error(0)
80 | }
81 |
82 | return r0
83 | }
84 |
85 | // MockResponseFuture_Response_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Response'
86 | type MockResponseFuture_Response_Call struct {
87 | *mock.Call
88 | }
89 |
90 | // Response is a helper method to define mock.On call
91 | // - output any
92 | func (_e *MockResponseFuture_Expecter) Response(output interface{}) *MockResponseFuture_Response_Call {
93 | return &MockResponseFuture_Response_Call{Call: _e.mock.On("Response", output)}
94 | }
95 |
96 | func (_c *MockResponseFuture_Response_Call) Run(run func(output any)) *MockResponseFuture_Response_Call {
97 | _c.Call.Run(func(args mock.Arguments) {
98 | run(args[0].(any))
99 | })
100 | return _c
101 | }
102 |
103 | func (_c *MockResponseFuture_Response_Call) Return(_a0 error) *MockResponseFuture_Response_Call {
104 | _c.Call.Return(_a0)
105 | return _c
106 | }
107 |
108 | func (_c *MockResponseFuture_Response_Call) RunAndReturn(run func(any) error) *MockResponseFuture_Response_Call {
109 | _c.Call.Return(run)
110 | return _c
111 | }
112 |
113 | // handle provides a mock function with no fields
114 | func (_m *MockResponseFuture) handle() uint32 {
115 | ret := _m.Called()
116 |
117 | if len(ret) == 0 {
118 | panic("no return value specified for handle")
119 | }
120 |
121 | var r0 uint32
122 | if rf, ok := ret.Get(0).(func() uint32); ok {
123 | r0 = rf()
124 | } else {
125 | r0 = ret.Get(0).(uint32)
126 | }
127 |
128 | return r0
129 | }
130 |
131 | // MockResponseFuture_handle_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'handle'
132 | type MockResponseFuture_handle_Call struct {
133 | *mock.Call
134 | }
135 |
136 | // handle is a helper method to define mock.On call
137 | func (_e *MockResponseFuture_Expecter) handle() *MockResponseFuture_handle_Call {
138 | return &MockResponseFuture_handle_Call{Call: _e.mock.On("handle")}
139 | }
140 |
141 | func (_c *MockResponseFuture_handle_Call) Run(run func()) *MockResponseFuture_handle_Call {
142 | _c.Call.Run(func(args mock.Arguments) {
143 | run()
144 | })
145 | return _c
146 | }
147 |
148 | func (_c *MockResponseFuture_handle_Call) Return(_a0 uint32) *MockResponseFuture_handle_Call {
149 | _c.Call.Return(_a0)
150 | return _c
151 | }
152 |
153 | func (_c *MockResponseFuture_handle_Call) RunAndReturn(run func() uint32) *MockResponseFuture_handle_Call {
154 | _c.Call.Return(run)
155 | return _c
156 | }
157 |
158 | // NewMockResponseFuture creates a new instance of MockResponseFuture. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
159 | // The first argument is typically a *testing.T value.
160 | func NewMockResponseFuture(t interface {
161 | mock.TestingT
162 | Cleanup(func())
163 | }) *MockResponseFuture {
164 | mock := &MockResponseFuture{}
165 | mock.Mock.Test(t)
166 |
167 | t.Cleanup(func() { mock.AssertExpectations(t) })
168 |
169 | return mock
170 | }
171 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
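[![Go Reference](https://pkg.go.dev/badge/github.com/restatedev/sdk-go.svg)](https://pkg.go.dev/github.com/restatedev/sdk-go)
[![Go](https://github.com/restatedev/sdk-go/actions/workflows/test.yaml/badge.svg)](https://github.com/restatedev/sdk-go/actions/workflows/test.yaml)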
3 |
4 | # Restate Go SDK
5 |
6 | [Restate](https://restate.dev/) is a system for easily building resilient applications using *distributed durable async/await*. This repository contains the Restate SDK for writing services in **Golang**.
7 |
8 | ## Community
9 |
10 | * 🤗️ [Join our online community](https://discord.gg/skW3AZ6uGd) for help, sharing feedback and talking to the community.
11 | * 📖 [Check out our documentation](https://docs.restate.dev) to get quickly started!
12 | * 📣 [Follow us on Twitter](https://twitter.com/restatedev) for staying up to date.
* 🙋 [Create a GitHub issue](https://github.com/restatedev/sdk-go/issues) for requesting a new feature or reporting a problem.
14 | * 🏠 [Visit our GitHub org](https://github.com/restatedev) for exploring other repositories.
15 |
16 | ## Prerequisites
17 | - Go: >= 1.24.0
18 |
19 | ## Examples
20 |
This repo contains an [example](examples/ticketreservation) based on the [Ticket Reservation Service](https://github.com/restatedev/examples/tree/main/tutorials/tour-of-restate-go).
22 |
23 | You can also check a list of examples available here: https://github.com/restatedev/examples?tab=readme-ov-file#go
24 |
25 | ### How to use the example
26 |
Download and run Restate (v1.x), as described on the [releases page](https://github.com/restatedev/restate/releases/):
28 |
29 | ```bash
30 | restate-server
31 | ```
32 |
In another terminal, run the example from the repository root:

```bash
cd examples/ticketreservation
go run .
```
39 |
In a third terminal, register the deployment with Restate:
41 |
42 | ```bash
43 | restate deployments register http://localhost:9080
44 | ```
45 |
46 | And do the following steps
47 |
48 | - Add tickets to basket
49 |
50 | ```bash
51 | curl -v localhost:8080/UserSession/azmy/AddTicket \
52 | -H 'content-type: application/json' \
53 | -d '"ticket-1"'
54 |
55 | # true
56 | curl -v localhost:8080/UserSession/azmy/AddTicket \
57 | -H 'content-type: application/json' \
58 | -d '"ticket-2"'
59 | # true
60 | ```
61 |
Trying to add the same tickets again should return `false`, since they are already reserved. If you don't check out the tickets within 15 minutes, they are released again (if you are impatient, change the delay in the code to make it shorter).
63 |
64 | - Check out
65 |
66 | ```bash
67 | curl localhost:8080/UserSession/azmy/Checkout
68 | # true
69 | ```
70 |
71 | ## Ingress SDK
72 |
73 | When you need to call restate handlers or attach to invocations from outside the restate context,
74 | use the [ingress SDK](examples/client/main.go).
75 |
76 | ## Versions
77 |
78 | This library follows [Semantic Versioning](https://semver.org/).
79 |
80 | The compatibility with Restate is described in the following table:
81 |
82 | | Restate Server\sdk-go | < 0.16 | 0.16 - 0.17 | 0.18 - 0.19 | 0.20 - 0.21 |
83 | |-----------------------|------------------|-------------|------------------|------------------|
84 | | < 1.3 | ✅ | ❌ | ❌ | ❌ |
85 | | 1.3 | ✅ | ✅ | ✅ (1) | ✅ (2) |
86 | | 1.4 | ✅ | ✅ | ✅ | ✅ (2) |
87 | | 1.5 | ⚠ (3) | ✅ | ✅ | ✅ |
88 |
89 | (1) **Note** `WithAbortTimeout`, `WithEnableLazyState`, `WithIdempotencyRetention`, `WithInactivityTimeout`, `WithIngressPrivate`, `WithJournalRetention` and `WithWorkflowRetention` work only from Restate 1.4 onward. Check the in-code documentation for more details.
90 |
(2) **Note** `WithInvocationRetryPolicy` works only from Restate 1.5 onward. Check the in-code documentation for more details.
92 |
93 | (3) **Warning** SDK versions < 0.16 are deprecated, and cannot be registered anymore. Check the [Restate 1.5 release notes](https://github.com/restatedev/restate/releases/tag/v1.5.0) for more info.
94 |
95 | ## Contributing
96 |
97 | We’re excited if you join the Restate community and start contributing!
98 | Whether it is feature requests, bug reports, ideas & feedback or PRs, we appreciate any and all contributions.
99 | We know that your time is precious and, therefore, deeply value any effort to contribute!
100 |
101 | ### Internal core
102 |
103 | To rebuild the internal core:
104 |
105 | ```shell
106 | cd shared-core
107 | cargo build --release
108 | mv target/wasm32-unknown-unknown/release/shared_core_golang_wasm_binding.wasm ../internal/statemachine
109 | ```
110 |
111 | To regenerate the protobuf contract between core and SDK:
112 |
113 | ```shell
114 | buf generate --template internal.buf.gen.yaml
115 | ```
116 |
117 | ## Mockery mocks
The `mocks` package is mostly autogenerated, but requires some light manual editing. To regenerate it, run `mockery` in the root of this repo, then check the git diff.
Certain structs (`MockAfterFuture`, `MockAttachFuture`, `MockAwakeableFuture`, `MockDurablePromise`, `MockResponseFuture`) and functions (`NewMockClient`, `NewMockContext`)
have been redefined in `helpers.go` and commented out in their respective files. Please keep them commented out whenever you regenerate the mocks; otherwise the build
will fail.
122 |
--------------------------------------------------------------------------------
/ingress/requester.go:
--------------------------------------------------------------------------------
1 | package ingress
2 |
3 | import (
4 | "context"
5 |
6 | restate "github.com/restatedev/sdk-go"
7 | "github.com/restatedev/sdk-go/encoding"
8 | "github.com/restatedev/sdk-go/internal/ingress"
9 | "github.com/restatedev/sdk-go/internal/options"
10 | )
11 |
12 | // Requester provides both synchronous (Request) and asynchronous (Send) invocation methods for Restate handlers.
13 | // It requires both input (I) and output (O) type parameters.
14 | //
15 | // Use Request to make a call and wait for the result.
16 | // Use Send to make a fire-and-forget call that returns immediately with a SendResponse
17 | // containing an InvocationHandle to retrieve the result later.
18 | type Requester[I any, O any] interface {
19 | // Request makes a synchronous invocation and blocks until the result is available.
20 | Request(ctx context.Context, input I, options ...options.IngressRequestOption) (O, error)
21 | // Send makes an asynchronous invocation and returns immediately with a handle to retrieve the result later.
22 | Send(ctx context.Context, input I, options ...options.IngressSendOption) (SendResponse[O], error)
23 | }
24 |
25 | // SendResponse is returned by Requester.Send and combines both SimpleSendResponse (for invocation metadata)
26 | // and InvocationHandle (for retrieving the output).
27 | //
28 | // You can use the embedded InvocationHandle methods (Attach/Output) to retrieve the invocation result,
29 | // or use the SimpleSendResponse methods (Id/Status) to get invocation metadata.
30 | type SendResponse[O any] interface {
31 | InvocationHandle[O]
32 | SimpleSendResponse
33 | }
34 |
35 | // Service gets an ingress client for a Restate service handler.
36 | // This returns a Requester that supports both Request and Send operations.
37 | //
38 | // Example:
39 | //
40 | // requester := ingress.Service[*MyInput, *MyOutput](client, "MyService", "myHandler")
41 | // // Call and wait for response:
42 | // output, err := requester.Request(ctx, &MyInput{...})
43 | // // Send request:
44 | // response, err := requester.Send(ctx, &MyInput{...})
45 | func Service[I any, O any](c *Client, serviceName, handlerName string) Requester[I, O] {
46 | return requester[I, O]{
47 | client: c,
48 | params: ingress.IngressParams{
49 | Service: serviceName,
50 | Handler: handlerName,
51 | },
52 | }
53 | }
54 |
55 | // Object gets an ingress client for a Restate virtual object handler.
56 | // This returns a Requester that supports both Request and Send operations.
57 | //
58 | // Example:
59 | //
60 | // requester := ingress.Object[*MyInput, *MyOutput](client, "MyObject", "object-123", "myHandler")
61 | // // Call and wait for response:
62 | // output, err := requester.Request(ctx, &MyInput{...})
63 | // // Send request:
64 | // response, err := requester.Send(ctx, &MyInput{...})
65 | func Object[I any, O any](c *Client, serviceName, objectKey, handlerName string) Requester[I, O] {
66 | return requester[I, O]{
67 | client: c,
68 | params: ingress.IngressParams{
69 | Service: serviceName,
70 | Key: objectKey,
71 | Handler: handlerName,
72 | },
73 | }
74 | }
75 |
76 | // Workflow gets an ingress client for a Restate workflow handler.
77 | // This returns a Requester that supports both Request and Send operations.
78 | //
79 | // Example:
80 | //
81 | // requester := ingress.Workflow[*MyInput, *MyOutput](client, "MyWorkflow", "workflow-123", "myHandler")
82 | // // Call and wait for response:
83 | // output, err := requester.Request(ctx, &MyInput{...})
84 | // // Send request:
85 | // response, err := requester.Send(ctx, &MyInput{...})
86 | func Workflow[I any, O any](c *Client, serviceName, workflowID, handlerName string) Requester[I, O] {
87 | return requester[I, O]{
88 | client: c,
89 | params: ingress.IngressParams{
90 | Service: serviceName,
91 | Handler: handlerName,
92 | Key: workflowID,
93 | },
94 | }
95 | }
96 |
97 | type requester[I any, O any] struct {
98 | client *Client
99 | params ingress.IngressParams
100 | codec encoding.PayloadCodec
101 | }
102 |
103 | func NewRequester[I any, O any](c *Client, serviceName, handlerName string, key *string, codec *encoding.PayloadCodec) Requester[I, O] {
104 | req := requester[I, O]{
105 | client: c,
106 | params: ingress.IngressParams{
107 | Service: serviceName,
108 | Handler: handlerName,
109 | },
110 | }
111 | if key != nil {
112 | req.params.Key = *key
113 | }
114 | if codec != nil {
115 | req.codec = *codec
116 | }
117 | return req
118 | }
119 |
120 | // Request calls the ingress API with the given input and returns the result.
121 | func (c requester[I, O]) Request(ctx context.Context, input I, opts ...options.IngressRequestOption) (O, error) {
122 | reqOpts := options.IngressRequestOptions{}
123 | reqOpts.Codec = c.codec
124 | for _, opt := range opts {
125 | opt.BeforeIngressRequest(&reqOpts)
126 | }
127 |
128 | var output O
129 | err := c.client.Request(ctx, c.params, input, &output, reqOpts)
130 | if err != nil {
131 | return output, err
132 | }
133 | return output, nil
134 | }
135 |
136 | type sendResponse[O any] struct {
137 | InvocationHandle[O]
138 | invocation ingress.Invocation
139 | }
140 |
141 | func (s *sendResponse[O]) Id() string {
142 | return s.invocation.Id
143 | }
144 |
145 | func (s *sendResponse[O]) Status() string {
146 | return s.invocation.Status
147 | }
148 |
149 | // Send calls the ingress API with the given input and returns an Invocation instance.
150 | func (c requester[I, O]) Send(ctx context.Context, input I, opts ...options.IngressSendOption) (SendResponse[O], error) {
151 | sendOpts := options.IngressSendOptions{}
152 | sendOpts.Codec = c.codec
153 | for _, opt := range opts {
154 | opt.BeforeIngressSend(&sendOpts)
155 | }
156 |
157 | inv, err := c.client.Send(ctx, c.params, input, sendOpts)
158 | if err != nil {
159 | return nil, err
160 | }
161 |
162 | return &sendResponse[O]{
163 | invocation: inv,
164 | InvocationHandle: InvocationById[O](c.client, inv.Id, restate.WithPayloadCodec(c.codec)),
165 | }, nil
166 | }
167 |
--------------------------------------------------------------------------------
/internal/restatecontext/run.go:
--------------------------------------------------------------------------------
1 | package restatecontext
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "github.com/restatedev/sdk-go/encoding"
7 | "github.com/restatedev/sdk-go/internal/errors"
8 | pbinternal "github.com/restatedev/sdk-go/internal/generated"
9 | "github.com/restatedev/sdk-go/internal/options"
10 | "github.com/restatedev/sdk-go/internal/statemachine"
11 | "log/slog"
12 | "runtime/debug"
13 | "time"
14 | )
15 |
16 | func (restateCtx *ctx) Run(fn func(ctx RunContext) (any, error), output any, opts ...options.RunOption) error {
17 | return restateCtx.RunAsync(fn, opts...).Result(output)
18 | }
19 |
20 | func (restateCtx *ctx) RunAsync(fn func(ctx RunContext) (any, error), opts ...options.RunOption) RunAsyncFuture {
21 | o := options.RunOptions{}
22 | for _, opt := range opts {
23 | opt.BeforeRun(&o)
24 | }
25 | if o.Codec == nil {
26 | o.Codec = encoding.JSONCodec
27 | }
28 |
29 | params := pbinternal.VmSysRunParameters{}
30 | params.SetName(o.Name)
31 |
32 | handle, err := restateCtx.stateMachine.SysRun(restateCtx, o.Name)
33 | if err != nil {
34 | panic(err)
35 | }
36 | restateCtx.checkStateTransition()
37 |
38 | restateCtx.runClosures[handle] = func() *pbinternal.VmProposeRunCompletionParameters {
39 | now := time.Now()
40 |
41 | // Run the user closure
42 | output, err := runWrapPanic(fn)(runContext{Context: restateCtx, log: restateCtx.userLogger, request: &restateCtx.request})
43 |
44 | // Let's prepare the proposal of the run completion
45 | proposal := pbinternal.VmProposeRunCompletionParameters{}
46 | proposal.SetHandle(handle)
47 | proposal.SetAttemptDurationMillis(uint64(time.Now().Sub(now).Milliseconds()))
48 |
49 | // Set retry policy if any of the retry policy config options are set
50 | if o.MaxRetryAttempts != nil || o.MaxRetryInterval != nil || o.MaxRetryDuration != nil || o.RetryIntervalFactor != nil || o.InitialRetryInterval != nil {
51 | retryPolicy := pbinternal.VmProposeRunCompletionParameters_RetryPolicy{}
52 | retryPolicy.SetInitialInternalMillis(50)
53 | retryPolicy.SetFactor(2)
54 | retryPolicy.SetMaxIntervalMillis(2000)
55 |
56 | if o.MaxRetryDuration != nil {
57 | retryPolicy.SetMaxDurationMillis(uint64((*o.MaxRetryDuration).Milliseconds()))
58 | }
59 | if o.MaxRetryInterval != nil {
60 | retryPolicy.SetMaxIntervalMillis(uint64((*o.MaxRetryInterval).Milliseconds()))
61 | }
62 | if o.RetryIntervalFactor != nil {
63 | retryPolicy.SetFactor(*o.RetryIntervalFactor)
64 | }
65 | if o.MaxRetryAttempts != nil {
66 | retryPolicy.SetMaxAttempts(uint32(*o.MaxRetryAttempts))
67 | }
68 | if o.InitialRetryInterval != nil {
69 | retryPolicy.SetInitialInternalMillis(uint64((*o.InitialRetryInterval).Milliseconds()))
70 | }
71 | proposal.SetRetryPolicy(&retryPolicy)
72 | }
73 |
74 | if errors.IsTerminalError(err) {
75 | // Terminal error
76 | failure := pbinternal.Failure{}
77 | failure.SetCode(uint32(errors.ErrorCode(err)))
78 | failure.SetMessage(err.Error())
79 | proposal.SetTerminalFailure(&failure)
80 | } else if err != nil {
81 | // Retryable error
82 | failure := pbinternal.FailureWithStacktrace{}
83 | failure.SetCode(uint32(errors.ErrorCode(err)))
84 | failure.SetMessage(err.Error())
85 | proposal.SetRetryableFailure(&failure)
86 | } else {
87 | // Success
88 | bytes, err := encoding.Marshal(o.Codec, output)
89 | if err != nil {
90 | panic(fmt.Errorf("failed to marshal Run output: %w", err))
91 | }
92 |
93 | proposal.SetSuccess(bytes)
94 | }
95 |
96 | return &proposal
97 | }
98 |
99 | return &runAsyncFuture{
100 | asyncResult: newAsyncResult(restateCtx, handle),
101 | codec: o.Codec,
102 | }
103 | }
104 |
105 | func runWrapPanic(fn func(ctx RunContext) (any, error)) func(ctx RunContext) (any, error) {
106 | return func(ctx RunContext) (res any, err error) {
107 | defer func() {
108 | recovered := recover()
109 |
110 | switch typ := recovered.(type) {
111 | case nil:
112 | // nothing to do, just exit
113 | break
114 | case *statemachine.SuspensionError:
115 | case statemachine.SuspensionError:
116 | err = typ
117 | break
118 | default:
119 | err = fmt.Errorf("panic occurred while executing Run: %s\nStack: %s", fmt.Sprint(typ), string(debug.Stack()))
120 | break
121 | }
122 | }()
123 | res, err = fn(ctx)
124 | return
125 | }
126 | }
127 |
128 | type RunAsyncFuture interface {
129 | Selectable
130 | Result(output any) error
131 | }
132 |
133 | type runAsyncFuture struct {
134 | asyncResult
135 | codec encoding.Codec
136 | }
137 |
138 | func (d *runAsyncFuture) Result(output any) error {
139 | switch result := d.pollProgressAndLoadValue().(type) {
140 | case statemachine.ValueSuccess:
141 | {
142 | if err := encoding.Unmarshal(d.codec, result.Success, output); err != nil {
143 | panic(fmt.Errorf("failed to unmarshal runAsync result into output: %w", err))
144 | }
145 | return nil
146 | }
147 | case statemachine.ValueFailure:
148 | return errorFromFailure(result)
149 | default:
150 | panic(fmt.Errorf("unexpected value %s", result))
151 |
152 | }
153 | }
154 |
155 | // RunContext is passed to [Run] closures and provides the limited set of Restate operations that are safe to use there.
156 | type RunContext interface {
157 | context.Context
158 |
159 | // Log obtains a coreHandle on a slog.Logger which already has some useful fields (invocationID and method)
160 | // By default, this logger will not output messages if the invocation is currently replaying
161 | // The log handler can be set with `.WithLogger()` on the server object
162 | Log() *slog.Logger
163 |
164 | // Request gives extra information about the request that started this invocation
165 | Request() *Request
166 | }
167 |
168 | type runContext struct {
169 | context.Context
170 | log *slog.Logger
171 | request *Request
172 | }
173 |
174 | func (r runContext) Log() *slog.Logger { return r.log }
175 | func (r runContext) Request() *Request { return r.request }
176 |
--------------------------------------------------------------------------------
/proto/internal.proto:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2025 - Restate Software, Inc., Restate GmbH
3 | *
4 | * This file is part of the Restate SDK for Go,
5 | * which is released under the MIT license.
6 | *
7 | * You can find a copy of the license in file LICENSE in the root
8 | * directory of this repository or package, or at
 * https://github.com/restatedev/sdk-go/blob/main/LICENSE
10 | */
11 |
12 | syntax = "proto3";
13 |
14 | option go_package = "github.com/restatedev/sdk-go/internal/generated";
15 |
// Empty carries no data. It is the payload of oneof variants whose presence alone
// is the signal (e.g. VmTakeOutputReturn.EOF, Value.void, GenericEmptyReturn.ok).
message Empty {}

// Header is a single key/value pair. Used by VmNewParameters,
// VmGetResponseHeadReturn, VmSysInputReturn.Input and the call/send parameters.
message Header {
  string key = 1;
  string value = 2;
}

// Failure describes an error as a numeric code plus a message. The Go RunAsync
// path fills it from errors.ErrorCode(err) and err.Error() for terminal run failures.
message Failure {
  uint32 code = 1;
  string message = 2;
}

// FailureWithStacktrace is a Failure with an extra stacktrace field. It is used
// for VmProposeRunCompletionParameters.retryable_failure.
// NOTE(review): the Go RunAsync path sets only code and message, never stacktrace.
message FailureWithStacktrace {
  uint32 code = 1;
  string message = 2;
  string stacktrace = 3;
}
33 |
// VmNewParameters holds the inputs for creating a VM instance.
message VmNewParameters {
  // Headers handed to the new VM (presumably those of the incoming request; confirm).
  repeated Header headers = 1;
}

// VmNewReturn is the outcome of creating a VM instance: a pointer or a failure.
message VmNewReturn {
  oneof result {
    // NOTE(review): presumably a reference to the created VM instance (e.g. a
    // location in the WASM module's memory); confirm against the binding.
    uint32 pointer = 1;
    Failure failure = 2;
  }
}

// VmGetResponseHeadReturn carries the status code and headers of the response head.
message VmGetResponseHeadReturn {
  uint32 status_code = 1;
  repeated Header headers = 2;
}

// VmNotifyError describes an error (message and stacktrace). Judging by its name
// it is used to notify the VM of an error; confirm the direction.
message VmNotifyError {
  string message = 1;
  string stacktrace = 2;
}

// VmTakeOutputReturn is the result of taking pending output: either a chunk of
// bytes, or EOF (presumably: no further output will be produced).
message VmTakeOutputReturn {
  oneof result {
    bytes bytes = 1;
    Empty EOF = 2;
  }
}

// VmIsReadyToExecuteReturn reports whether the VM is ready to execute, or a failure.
message VmIsReadyToExecuteReturn {
  oneof result {
    bool ready = 1;
    Failure failure = 2;
  }
}
68 |
// VmDoProgressParameters lists the handles to make progress on (presumably the
// async results the caller is currently awaiting; confirm).
message VmDoProgressParameters {
  repeated uint32 handles = 1;
}

// VmDoProgressReturn reports the outcome of one progress step. Exactly one
// variant is set.
message VmDoProgressReturn {
  oneof result {
    Empty any_completed = 1;
    Empty read_from_input = 2;
    Empty waiting_pending_run = 3;
    // NOTE(review): presumably the handle of a run closure the SDK should execute
    // now; the Go SDK keys its pending run closures (runClosures) by run handle.
    uint32 execute_run = 4;
    Empty cancel_signal_received = 5;
    Empty suspended = 6;
    Failure failure = 7;
  }
}

// Value is the payload of a completed notification (see
// VmTakeNotificationReturn.value). Exactly one variant is set.
// NOTE(review): presumably surfaced in Go as statemachine.ValueSuccess /
// ValueFailure etc.; confirm the mapping.
message Value {
  // StateKeys is a list of state key names.
  message StateKeys {
    repeated string keys = 1;
  }

  oneof value {
    // Completed without a payload.
    Empty void = 1;
    bytes success = 2;
    Failure failure = 3;
    StateKeys state_keys = 4;
    string invocation_id = 5;
  }
}

// VmTakeNotificationReturn is the result of taking a notification. It is one of:
// not ready yet, a completed Value, suspended, or a failure.
message VmTakeNotificationReturn {
  oneof result {
    Empty not_ready = 1;
    Value value = 2;
    Empty suspended = 3;
    Failure failure = 4;
  }
}
107 |
// VmSysInputReturn returns the invocation input, or a failure.
message VmSysInputReturn {
  // Input describes the current invocation and its input payload.
  message Input {
    string invocation_id = 1;
    // Presumably the object/workflow key, empty for unkeyed services; confirm.
    string key = 2;
    repeated Header headers = 3;
    // Raw input payload bytes.
    bytes input = 4;
    // NOTE(review): presumably only meaningful when should_use_random_seed is
    // true; confirm before relying on it.
    uint64 random_seed = 5;
    bool should_use_random_seed = 6;
  }

  oneof result {
    Input ok = 1;
    Failure failure = 2;
  }
}

// VmSysStateGetParameters names the state key to read.
message VmSysStateGetParameters {
  string key = 1;
}

// VmSysStateSetParameters stores value (raw bytes) under key.
message VmSysStateSetParameters {
  string key = 1;
  bytes value = 2;
}

// VmSysStateClearParameters names the state key to clear.
message VmSysStateClearParameters {
  string key = 1;
}

// VmSysSleepParameters describes a sleep. Both timestamps are milliseconds since
// the Unix epoch. The current time is passed in explicitly rather than read by
// the VM (presumably for determinism; confirm).
message VmSysSleepParameters {
  uint64 wake_up_time_since_unix_epoch_millis = 1;
  uint64 now_since_unix_epoch_millis = 2;
  // Name of the sleep; proto3 string, so empty when unset.
  string name = 3;
}
142 |
// VmSysAwakeableReturn returns a newly created awakeable, or a failure.
message VmSysAwakeableReturn {
  // Awakeable pairs the awakeable's id with a handle (presumably the handle of
  // the async result that resolves when it is completed; confirm).
  message Awakeable {
    string id = 1;
    uint32 handle = 2;
  }

  oneof result {
    Awakeable ok = 1;
    Failure failure = 2;
  }
}

// VmSysCompleteAwakeableParameters completes the awakeable identified by id with
// either a success payload or a failure.
message VmSysCompleteAwakeableParameters {
  string id = 1;

  oneof result {
    bytes success = 2;
    Failure failure = 3;
  }
}
163 |
// VmSysCallParameters describes a call to another service handler.
message VmSysCallParameters {
  string service = 1;
  string handler = 2;
  // Target key (presumably set only for keyed services; confirm).
  optional string key = 3;
  optional string idempotency_key = 4;
  repeated Header headers = 5;

  // Request payload as raw bytes.
  bytes input = 6;
}

// VmSysCallReturn returns the handles for a call, or a failure.
message VmSysCallReturn {
  // Two handles for one call. Judging by their names, one resolves to the callee's
  // invocation id and the other to the call's result.
  message CallHandles {
    uint32 invocation_id_handle = 1;
    uint32 result_handle = 2;
  }

  oneof result {
    CallHandles ok = 1;
    Failure failure = 2;
  }
}

// VmSysSendParameters describes sending a request to another service handler.
// Fields 1-6 mirror VmSysCallParameters.
message VmSysSendParameters {
  string service = 1;
  string handler = 2;
  optional string key = 3;
  optional string idempotency_key = 4;
  repeated Header headers = 5;

  bytes input = 6;

  // NOTE(review): presumably schedules a delayed send when present (milliseconds
  // since the Unix epoch), and absent means send immediately; confirm.
  optional uint64 execution_time_since_unix_epoch_millis = 7;
}

// VmSysCancelInvocation identifies the invocation to cancel.
message VmSysCancelInvocation {
  string invocation_id = 1;
}

// VmSysAttachInvocation identifies the invocation to attach to.
message VmSysAttachInvocation {
  string invocation_id = 1;
}
205 |
// VmSysPromiseGetParameters names the promise to get.
message VmSysPromiseGetParameters {
  string key = 1;
}

// VmSysPromisePeekParameters names the promise to peek at.
message VmSysPromisePeekParameters {
  string key = 1;
}

// VmSysPromiseCompleteParameters completes a promise with a success payload or a failure.
message VmSysPromiseCompleteParameters {
  // NOTE(review): named `id` here while get/peek use `key`; presumably the same
  // promise key. Confirm before relying on it.
  string id = 1;

  oneof result {
    bytes success = 2;
    Failure failure = 3;
  }
}
222 |
// VmSysRunParameters carries the name of a run (side effect).
// NOTE(review): the visible Go RunAsync passes the run name to SysRun as a plain
// string rather than through this message.
message VmSysRunParameters {
  string name = 1;
}

// VmProposeRunCompletionParameters reports the outcome of one execution attempt
// of a run closure.
message VmProposeRunCompletionParameters {
  // Retry configuration for the run. The Go RunAsync path sends it only when at
  // least one retry option is set. Unset fields get defaults of 50 ms initial
  // interval, factor 2 and 2000 ms max interval.
  message RetryPolicy {
    // Initial retry interval in milliseconds. NOTE(review): "internal" looks like
    // a typo for "interval", but renaming it would change the generated Go
    // accessors (e.g. SetInitialInternalMillis), so it is kept.
    uint64 initial_internal_millis = 1;
    // Multiplier applied to the retry interval (presumably per attempt; confirm).
    float factor = 2;
    optional uint64 max_interval_millis = 3;
    optional uint32 max_attempts = 4;
    optional uint64 max_duration_millis = 5;
  }

  // Handle of the run, as returned by SysRun.
  uint32 handle = 1;

  oneof result {
    // Output of the closure, marshaled with the run's codec.
    bytes success = 2;
    // Set for terminal errors (presumably not retried; confirm).
    Failure terminal_failure = 3;
    // Set for any other error, including recovered panics.
    FailureWithStacktrace retryable_failure = 4;
  }

  RetryPolicy retry_policy = 5;
  // Duration of this attempt in milliseconds, as measured by the SDK.
  uint64 attempt_duration_millis = 6;
}
247 |
// VmSysWriteOutputParameters carries output to write: a success payload or a
// failure (presumably the handler's final result; confirm).
message VmSysWriteOutputParameters {
  oneof result {
    bytes success = 1;
    Failure failure = 2;
  }
}

// SimpleSysAsyncResultReturn yields a single async-result handle, or a failure.
// Presumably shared by syscalls whose only result is one handle; confirm.
message SimpleSysAsyncResultReturn {
  oneof result {
    uint32 handle = 1;
    Failure failure = 2;
  }
}

// GenericEmptyReturn yields `ok` with no payload, or a failure. Presumably shared
// by syscalls that return nothing on success; confirm.
message GenericEmptyReturn {
  oneof result {
    Empty ok = 1;
    Failure failure = 2;
  }
}
--------------------------------------------------------------------------------