├── .gitignore ├── .mapping.json ├── .piglet-meta.json ├── AUTHORS.md ├── CONTRIBUTING.md ├── CREDITS.md ├── Dockerfile ├── LICENSE ├── README.md ├── docker ├── auto-setup.sh └── config │ └── config_template.yaml ├── go.mod ├── go.sum ├── persistence └── pkg │ ├── base │ ├── errors │ │ └── errors.go │ ├── executor │ │ └── factory.go │ ├── mss │ │ └── mss.go │ ├── rows │ │ ├── const.go │ │ └── rows.go │ └── tokens │ │ └── tokens.go │ ├── cache │ ├── basecache.go │ ├── cache.go │ ├── errors.go │ ├── factory.go │ ├── metrics.go │ ├── range.go │ ├── scheduledcache.go │ └── tasks.go │ └── ydb │ ├── cluster_metadata_store.go │ ├── config │ └── config.go │ ├── conn │ ├── client.go │ ├── const.go │ ├── errors.go │ ├── executor.go │ ├── log │ │ ├── adapter.go │ │ └── traces.go │ ├── metrics │ │ └── metrics.go │ ├── query.go │ ├── ts.go │ └── util.go │ ├── const.go │ ├── execution_store.go │ ├── factory.go │ ├── history_store.go │ ├── matching_task_store.go │ ├── metadata_store.go │ ├── mirroring_cluster_metadata_store.go │ ├── mirroring_matching_task_store.go │ ├── mirroring_metadata_store.go │ ├── mutable_state_store.go │ ├── mutable_state_task_store.go │ ├── nexus_store.go │ ├── queue_store.go │ ├── queue_store_v2.go │ ├── rows │ ├── assertions.go │ ├── factory.go │ ├── helper.go │ └── transaction.go │ └── shard_store.go ├── schema └── temporal │ ├── 0001_initial.sql │ ├── 0002_nexus.sql │ └── 0003_queue_v2.sql └── temporal-over-ydb └── cmd ├── migrator ├── main.go ├── ping.go ├── root.go └── up.go └── server └── main.go /.gitignore: -------------------------------------------------------------------------------- 1 | /temporal 2 | -------------------------------------------------------------------------------- /.mapping.json: -------------------------------------------------------------------------------- 1 | { 2 | ".":"infra/temporal/temporal-over-ydb/github_toplevel", 3 | ".gitignore":"infra/temporal/temporal-over-ydb/github_toplevel/.gitignore", 4 | 
"AUTHORS.md":"infra/temporal/temporal-over-ydb/github_toplevel/AUTHORS.md", 5 | "CONTRIBUTING.md":"infra/temporal/temporal-over-ydb/github_toplevel/CONTRIBUTING.md", 6 | "CREDITS.md":"infra/temporal/temporal-over-ydb/github_toplevel/CREDITS.md", 7 | "Dockerfile":"infra/temporal/temporal-over-ydb/github_toplevel/Dockerfile", 8 | "LICENSE":"infra/temporal/temporal-over-ydb/github_toplevel/LICENSE", 9 | "README.md":"infra/temporal/temporal-over-ydb/github_toplevel/README.md", 10 | "docker/auto-setup.sh":"infra/temporal/temporal-over-ydb/github_toplevel/docker/auto-setup.sh", 11 | "docker/config/config_template.yaml":"infra/temporal/temporal-over-ydb/github_toplevel/docker/config/config_template.yaml", 12 | "go.mod":"", 13 | "go.sum":"", 14 | "persistence/pkg/base/errors/errors.go":"infra/temporal/persistence/pkg/base/errors/errors.go", 15 | "persistence/pkg/base/executor/factory.go":"infra/temporal/persistence/pkg/base/executor/factory.go", 16 | "persistence/pkg/base/mss/mss.go":"infra/temporal/persistence/pkg/base/mss/mss.go", 17 | "persistence/pkg/base/rows/const.go":"infra/temporal/persistence/pkg/base/rows/const.go", 18 | "persistence/pkg/base/rows/rows.go":"infra/temporal/persistence/pkg/base/rows/rows.go", 19 | "persistence/pkg/base/tokens/tokens.go":"infra/temporal/persistence/pkg/base/tokens/tokens.go", 20 | "persistence/pkg/cache/basecache.go":"infra/temporal/persistence/pkg/cache/basecache.go", 21 | "persistence/pkg/cache/cache.go":"infra/temporal/persistence/pkg/cache/cache.go", 22 | "persistence/pkg/cache/errors.go":"infra/temporal/persistence/pkg/cache/errors.go", 23 | "persistence/pkg/cache/factory.go":"infra/temporal/persistence/pkg/cache/factory.go", 24 | "persistence/pkg/cache/metrics.go":"infra/temporal/persistence/pkg/cache/metrics.go", 25 | "persistence/pkg/cache/range.go":"infra/temporal/persistence/pkg/cache/range.go", 26 | "persistence/pkg/cache/scheduledcache.go":"infra/temporal/persistence/pkg/cache/scheduledcache.go", 27 | 
"persistence/pkg/cache/tasks.go":"infra/temporal/persistence/pkg/cache/tasks.go", 28 | "persistence/pkg/ydb/cluster_metadata_store.go":"infra/temporal/persistence/pkg/ydb/cluster_metadata_store.go", 29 | "persistence/pkg/ydb/config/config.go":"infra/temporal/persistence/pkg/ydb/config/config.go", 30 | "persistence/pkg/ydb/conn/client.go":"infra/temporal/persistence/pkg/ydb/conn/client.go", 31 | "persistence/pkg/ydb/conn/const.go":"infra/temporal/persistence/pkg/ydb/conn/const.go", 32 | "persistence/pkg/ydb/conn/errors.go":"infra/temporal/persistence/pkg/ydb/conn/errors.go", 33 | "persistence/pkg/ydb/conn/executor.go":"infra/temporal/persistence/pkg/ydb/conn/executor.go", 34 | "persistence/pkg/ydb/conn/log/adapter.go":"infra/temporal/persistence/pkg/ydb/conn/log/adapter.go", 35 | "persistence/pkg/ydb/conn/log/traces.go":"infra/temporal/persistence/pkg/ydb/conn/log/traces.go", 36 | "persistence/pkg/ydb/conn/metrics/metrics.go":"infra/temporal/persistence/pkg/ydb/conn/metrics/metrics.go", 37 | "persistence/pkg/ydb/conn/query.go":"infra/temporal/persistence/pkg/ydb/conn/query.go", 38 | "persistence/pkg/ydb/conn/ts.go":"infra/temporal/persistence/pkg/ydb/conn/ts.go", 39 | "persistence/pkg/ydb/conn/util.go":"infra/temporal/persistence/pkg/ydb/conn/util.go", 40 | "persistence/pkg/ydb/const.go":"infra/temporal/persistence/pkg/ydb/const.go", 41 | "persistence/pkg/ydb/execution_store.go":"infra/temporal/persistence/pkg/ydb/execution_store.go", 42 | "persistence/pkg/ydb/factory.go":"infra/temporal/persistence/pkg/ydb/factory.go", 43 | "persistence/pkg/ydb/history_store.go":"infra/temporal/persistence/pkg/ydb/history_store.go", 44 | "persistence/pkg/ydb/matching_task_store.go":"infra/temporal/persistence/pkg/ydb/matching_task_store.go", 45 | "persistence/pkg/ydb/metadata_store.go":"infra/temporal/persistence/pkg/ydb/metadata_store.go", 46 | "persistence/pkg/ydb/mirroring_cluster_metadata_store.go":"infra/temporal/persistence/pkg/ydb/mirroring_cluster_metadata_store.go", 47 | 
"persistence/pkg/ydb/mirroring_matching_task_store.go":"infra/temporal/persistence/pkg/ydb/mirroring_matching_task_store.go", 48 | "persistence/pkg/ydb/mirroring_metadata_store.go":"infra/temporal/persistence/pkg/ydb/mirroring_metadata_store.go", 49 | "persistence/pkg/ydb/mutable_state_store.go":"infra/temporal/persistence/pkg/ydb/mutable_state_store.go", 50 | "persistence/pkg/ydb/mutable_state_task_store.go":"infra/temporal/persistence/pkg/ydb/mutable_state_task_store.go", 51 | "persistence/pkg/ydb/nexus_store.go":"infra/temporal/persistence/pkg/ydb/nexus_store.go", 52 | "persistence/pkg/ydb/queue_store.go":"infra/temporal/persistence/pkg/ydb/queue_store.go", 53 | "persistence/pkg/ydb/queue_store_v2.go":"infra/temporal/persistence/pkg/ydb/queue_store_v2.go", 54 | "persistence/pkg/ydb/rows/assertions.go":"infra/temporal/persistence/pkg/ydb/rows/assertions.go", 55 | "persistence/pkg/ydb/rows/factory.go":"infra/temporal/persistence/pkg/ydb/rows/factory.go", 56 | "persistence/pkg/ydb/rows/helper.go":"infra/temporal/persistence/pkg/ydb/rows/helper.go", 57 | "persistence/pkg/ydb/rows/transaction.go":"infra/temporal/persistence/pkg/ydb/rows/transaction.go", 58 | "persistence/pkg/ydb/shard_store.go":"infra/temporal/persistence/pkg/ydb/shard_store.go", 59 | "persistence/schemas/ydb/goose/0001_initial.sql":"infra/temporal/persistence/schemas/ydb/goose/0001_initial.sql", 60 | "persistence/schemas/ydb/goose/0002_nexus.sql":"infra/temporal/persistence/schemas/ydb/goose/0002_nexus.sql", 61 | "persistence/schemas/ydb/goose/0003_queue_v2.sql":"infra/temporal/persistence/schemas/ydb/goose/0003_queue_v2.sql", 62 | "schema/temporal":"infra/temporal/persistence/schemas/ydb/goose", 63 | "schema/temporal/0001_initial.sql":"infra/temporal/persistence/schemas/ydb/goose/0001_initial.sql", 64 | "schema/temporal/0002_nexus.sql":"infra/temporal/persistence/schemas/ydb/goose/0002_nexus.sql", 65 | 
"schema/temporal/0003_queue_v2.sql":"infra/temporal/persistence/schemas/ydb/goose/0003_queue_v2.sql", 66 | "temporal-over-ydb/cmd/migrator/main.go":"infra/temporal/temporal-over-ydb/cmd/migrator/main.go", 67 | "temporal-over-ydb/cmd/migrator/ping.go":"infra/temporal/temporal-over-ydb/cmd/migrator/ping.go", 68 | "temporal-over-ydb/cmd/migrator/root.go":"infra/temporal/temporal-over-ydb/cmd/migrator/root.go", 69 | "temporal-over-ydb/cmd/migrator/up.go":"infra/temporal/temporal-over-ydb/cmd/migrator/up.go", 70 | "temporal-over-ydb/cmd/server/main.go":"infra/temporal/temporal-over-ydb/cmd/server/main.go", 71 | "temporal-over-ydb/github_toplevel/.gitignore":"infra/temporal/temporal-over-ydb/github_toplevel/.gitignore", 72 | "temporal-over-ydb/github_toplevel/AUTHORS.md":"infra/temporal/temporal-over-ydb/github_toplevel/AUTHORS.md", 73 | "temporal-over-ydb/github_toplevel/CONTRIBUTING.md":"infra/temporal/temporal-over-ydb/github_toplevel/CONTRIBUTING.md", 74 | "temporal-over-ydb/github_toplevel/CREDITS.md":"infra/temporal/temporal-over-ydb/github_toplevel/CREDITS.md", 75 | "temporal-over-ydb/github_toplevel/Dockerfile":"infra/temporal/temporal-over-ydb/github_toplevel/Dockerfile", 76 | "temporal-over-ydb/github_toplevel/LICENSE":"infra/temporal/temporal-over-ydb/github_toplevel/LICENSE", 77 | "temporal-over-ydb/github_toplevel/README.md":"infra/temporal/temporal-over-ydb/github_toplevel/README.md", 78 | "temporal-over-ydb/github_toplevel/docker/auto-setup.sh":"infra/temporal/temporal-over-ydb/github_toplevel/docker/auto-setup.sh", 79 | "temporal-over-ydb/github_toplevel/docker/config/config_template.yaml":"infra/temporal/temporal-over-ydb/github_toplevel/docker/config/config_template.yaml" 80 | } -------------------------------------------------------------------------------- /.piglet-meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "project":"temporal-over-ydb", 3 | "repository":"arcadia" 4 | } 
-------------------------------------------------------------------------------- /AUTHORS.md: -------------------------------------------------------------------------------- 1 | The following is an inevitably incomplete list of authors who have created 2 | the source code of "temporal-over-ydb" published and distributed by YANDEX LLC 3 | as the owner: 4 | 5 | Anton Romanovich 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Notice to external contributors 2 | 3 | 4 | ## General info 5 | 6 | Hello! In order for us (YANDEX LLC) to accept patches and other contributions from you, you will have to adopt our Yandex Contributor License Agreement (the CLA). The current version of the CLA can be found here: 7 | 1) https://yandex.ru/legal/cla/?lang=en (in English) and 8 | 2) https://yandex.ru/legal/cla/?lang=ru (in Russian). 9 | 10 | By adopting the CLA, you state the following: 11 | 12 | * You obviously wish and are willingly licensing your contributions to us for our open source projects under the terms of the CLA, 13 | * You have read the terms and conditions of the CLA and agree with them in full, 14 | * You are legally able to provide and license your contributions as stated, 15 | * We may use your contributions for our open source projects and for any other our project too, 16 | * We rely on your assurances concerning the rights of third parties in relation to your contributions. 17 | 18 | If you agree with these principles, please read and adopt our CLA. By providing us your contributions, you hereby declare that you have already read and adopt our CLA, and we may freely merge your contributions with our corresponding open source project and use it in further in accordance with terms and conditions of the CLA. 
19 | 20 | ## Provide contributions 21 | 22 | If you have already adopted terms and conditions of the CLA, you are able to provide your contributions. When you submit your pull request, please add the following information into it: 23 | 24 | ``` 25 | I hereby agree to the terms of the CLA available at: [link] 26 | ``` 27 | 28 | Replace the bracketed text as follows: 29 | * [link] is the link to the current version of the CLA: https://yandex.ru/legal/cla/?lang=en (in English) or https://yandex.ru/legal/cla/?lang=ru (in Russian). 30 | 31 | It is enough to provide us such notification once. 32 | 33 | ## Other questions 34 | 35 | If you have any questions, please mail us at opensource-support@yandex-team.ru. 36 | -------------------------------------------------------------------------------- /CREDITS.md: -------------------------------------------------------------------------------- 1 | # Credits 2 | 3 | **temporal-over-ydb** is heavily inspired by persistence layer implementations existing in 4 | [Cadence](https://github.com/uber/cadence/tree/master/common/persistence) and 5 | [Temporal](https://github.com/temporalio/temporal/tree/main/common/persistence) codebases, 6 | which are distributed under MIT license. 7 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG BUILDER_IMAGE=golang:1.23-alpine3.20 2 | ARG BASE_SERVER_IMAGE=temporalio/server:1.26.2 3 | ARG BASE_AUTO_SETUP_IMAGE=temporalio/auto-setup:1.26.2 4 | 5 | ##### Builder ##### 6 | FROM ${BUILDER_IMAGE} AS builder 7 | 8 | WORKDIR /build 9 | 10 | # build 11 | COPY go.mod go.sum ./ 12 | RUN go mod download 13 | 14 | COPY . 
./ 15 | 16 | RUN CGO_ENABLED=0 GOOS=linux go build -o /ydb-migrator ./temporal-over-ydb/cmd/migrator 17 | RUN CGO_ENABLED=0 GOOS=linux go build -o /temporal-server ./temporal-over-ydb/cmd/server 18 | 19 | ##### Temporal server ##### 20 | FROM ${BASE_SERVER_IMAGE} AS temporal-server-ydb 21 | 22 | WORKDIR /etc/temporal 23 | 24 | # binaries 25 | COPY --from=builder /temporal-server /usr/local/bin 26 | 27 | # configs 28 | COPY ./docker/config/config_template.yaml /etc/temporal/config/config_template.yaml 29 | 30 | 31 | ### Server auto-setup image ### 32 | FROM ${BASE_AUTO_SETUP_IMAGE} AS temporal-server-ydb-auto-setup 33 | 34 | WORKDIR /etc/temporal 35 | 36 | # binaries 37 | # temporal-ydb binary 38 | COPY --from=builder /temporal-server /usr/local/bin 39 | # goose binary 40 | COPY --from=builder /ydb-migrator /usr/local/bin 41 | 42 | USER temporal 43 | 44 | # configs 45 | COPY ./docker/config/config_template.yaml /etc/temporal/config/config_template.yaml 46 | 47 | # schema 48 | COPY --chown=temporal:temporal ./schema /etc/temporal/schema/ydb 49 | 50 | ## scripts 51 | COPY ./docker/auto-setup.sh /etc/temporal/auto-setup.sh 52 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2025 YANDEX LLC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![license](https://img.shields.io/github/license/yandex/temporal-over-ydb)](https://github.com/yandex/temporal-over-ydb/blob/main/LICENSE) 2 | ![GitHub go.mod Go version](https://img.shields.io/github/go-mod/go-version/yandex/temporal-over-ydb) 3 | ![stability-wip](https://img.shields.io/badge/stability-wip-lightgrey.svg) 4 | 5 | # temporal-over-ydb 6 | 7 | **temporal-over-ydb** is an implementation of a custom [Temporal](https://temporal.io) persistence layer using [YDB](https://ydb.tech), 8 | a distributed SQL DBMS. 9 | 10 | It is still a work in active progress and is not ready for production use. 
11 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/yandex/temporal-over-ydb 2 | 3 | go 1.23.8 4 | 5 | require ( 6 | github.com/mitchellh/mapstructure v1.5.1-0.20220423185008-bf980b35cac4 7 | github.com/pborman/uuid v1.2.1 8 | github.com/pressly/goose/v3 v3.24.3 9 | github.com/prometheus/client_golang v1.21.1 10 | github.com/spf13/cobra v1.9.1 11 | github.com/urfave/cli/v2 v2.27.5 12 | github.com/ydb-platform/ydb-go-genproto v0.0.0-20241112172322-ea1f63298f77 13 | github.com/ydb-platform/ydb-go-sdk-auth-environ v0.2.0 14 | github.com/ydb-platform/ydb-go-sdk-metrics v0.18.0 15 | github.com/ydb-platform/ydb-go-sdk/v3 v3.108.1 16 | go.temporal.io/api v1.44.1 17 | go.temporal.io/server v1.26.2 18 | go.uber.org/atomic v1.11.0 19 | go.uber.org/zap v1.27.0 20 | golang.org/x/sync v0.14.0 21 | golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da 22 | ) 23 | 24 | require ( 25 | cel.dev/expr v0.19.1 // indirect 26 | cloud.google.com/go v0.118.0 // indirect 27 | cloud.google.com/go/auth v0.14.0 // indirect 28 | cloud.google.com/go/auth/oauth2adapt v0.2.7 // indirect 29 | cloud.google.com/go/compute/metadata v0.6.0 // indirect 30 | cloud.google.com/go/iam v1.3.1 // indirect 31 | cloud.google.com/go/monitoring v1.23.0 // indirect 32 | cloud.google.com/go/storage v1.50.0 // indirect 33 | filippo.io/edwards25519 v1.1.0 // indirect 34 | github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.25.0 // indirect 35 | github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.48.1 // indirect 36 | github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.48.1 // indirect 37 | github.com/apache/thrift v0.20.0 // indirect 38 | github.com/aws/aws-sdk-go v1.54.12 // indirect 39 | github.com/benbjohnson/clock v1.3.5 // indirect 40 | github.com/beorn7/perks v1.0.1 // indirect 41 
| github.com/bitly/go-hostpool v0.1.0 // indirect 42 | github.com/blang/semver/v4 v4.0.0 // indirect 43 | github.com/cactus/go-statsd-client/statsd v0.0.0-20200423205355-cb0885a1018c // indirect 44 | github.com/cactus/go-statsd-client/v5 v5.1.0 // indirect 45 | github.com/cenkalti/backoff/v4 v4.3.0 // indirect 46 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 47 | github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3 // indirect 48 | github.com/cpuguy83/go-md2man/v2 v2.0.6 // indirect 49 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 50 | github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da // indirect 51 | github.com/dustin/go-humanize v1.0.1 // indirect 52 | github.com/emirpasic/gods v1.18.1 // indirect 53 | github.com/envoyproxy/go-control-plane/envoy v1.32.4 // indirect 54 | github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect 55 | github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a // indirect 56 | github.com/felixge/httpsnoop v1.0.4 // indirect 57 | github.com/go-logr/logr v1.4.2 // indirect 58 | github.com/go-logr/stdr v1.2.2 // indirect 59 | github.com/go-sql-driver/mysql v1.9.2 // indirect 60 | github.com/gocql/gocql v1.6.0 // indirect 61 | github.com/gogo/protobuf v1.3.2 // indirect 62 | github.com/golang-jwt/jwt/v4 v4.5.2 // indirect 63 | github.com/golang/mock v1.7.0-rc.1 // indirect 64 | github.com/golang/snappy v0.0.4 // indirect 65 | github.com/google/s2a-go v0.1.9 // indirect 66 | github.com/google/uuid v1.6.0 // indirect 67 | github.com/googleapis/enterprise-certificate-proxy v0.3.5 // indirect 68 | github.com/googleapis/gax-go/v2 v2.14.1 // indirect 69 | github.com/gorilla/mux v1.8.1 // indirect 70 | github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect 71 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect 72 | github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect 73 | github.com/iancoleman/strcase v0.3.0 // indirect 74 | 
github.com/inconshreveable/mousetrap v1.1.0 // indirect 75 | github.com/jackc/pgpassfile v1.0.0 // indirect 76 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect 77 | github.com/jackc/pgx/v5 v5.7.4 // indirect 78 | github.com/jackc/puddle/v2 v2.2.2 // indirect 79 | github.com/jmespath/go-jmespath v0.4.0 // indirect 80 | github.com/jmoiron/sqlx v1.3.5 // indirect 81 | github.com/jonboulle/clockwork v0.5.0 // indirect 82 | github.com/josharian/intern v1.0.0 // indirect 83 | github.com/klauspost/compress v1.18.0 // indirect 84 | github.com/lib/pq v1.10.9 // indirect 85 | github.com/mailru/easyjson v0.7.7 // indirect 86 | github.com/mattn/go-isatty v0.0.20 // indirect 87 | github.com/mattn/go-sqlite3 v2.0.1+incompatible // indirect 88 | github.com/mfridman/interpolate v0.0.2 // indirect 89 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect 90 | github.com/ncruces/go-strftime v0.1.9 // indirect 91 | github.com/nexus-rpc/sdk-go v0.1.0 // indirect 92 | github.com/olivere/elastic/v7 v7.0.32 // indirect 93 | github.com/opentracing/opentracing-go v1.2.0 // indirect 94 | github.com/pkg/errors v0.9.1 // indirect 95 | github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect 96 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 97 | github.com/prometheus/client_model v0.6.1 // indirect 98 | github.com/prometheus/common v0.62.0 // indirect 99 | github.com/prometheus/procfs v0.16.1 // indirect 100 | github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect 101 | github.com/rekby/fixenv v0.7.0 // indirect 102 | github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect 103 | github.com/robfig/cron v1.2.0 // indirect 104 | github.com/robfig/cron/v3 v3.0.1 // indirect 105 | github.com/russross/blackfriday/v2 v2.1.0 // indirect 106 | github.com/sethvargo/go-retry v0.3.0 // indirect 107 | github.com/sirupsen/logrus v1.9.3 // indirect 
108 | github.com/sony/gobreaker v1.0.0 // indirect 109 | github.com/spf13/pflag v1.0.6 // indirect 110 | github.com/stretchr/objx v0.5.2 // indirect 111 | github.com/stretchr/testify v1.10.0 // indirect 112 | github.com/temporalio/ringpop-go v0.0.0-20241119001152-e505ebd8f887 // indirect 113 | github.com/temporalio/sqlparser v0.0.0-20231115171017-f4060bcfa6cb // indirect 114 | github.com/temporalio/tchannel-go v1.22.1-0.20240528171429-1db37fdea938 // indirect 115 | github.com/twmb/murmur3 v1.1.8 // indirect 116 | github.com/uber-common/bark v1.3.0 // indirect 117 | github.com/uber-go/tally/v4 v4.1.17-0.20240412215630-22fe011f5ff0 // indirect 118 | github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect 119 | github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect 120 | github.com/yandex-cloud/go-genproto v0.0.0-20240425114406-68c9b49389a1 // indirect 121 | github.com/ydb-platform/ydb-go-yc v0.12.1 // indirect 122 | github.com/ydb-platform/ydb-go-yc-metadata v0.6.1 // indirect 123 | go.opentelemetry.io/auto/sdk v1.1.0 // indirect 124 | go.opentelemetry.io/contrib/detectors/gcp v1.34.0 // indirect 125 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 // indirect 126 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect 127 | go.opentelemetry.io/otel v1.35.0 // indirect 128 | go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.31.0 // indirect 129 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.35.0 // indirect 130 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.35.0 // indirect 131 | go.opentelemetry.io/otel/exporters/prometheus v0.57.0 // indirect 132 | go.opentelemetry.io/otel/metric v1.35.0 // indirect 133 | go.opentelemetry.io/otel/sdk v1.35.0 // indirect 134 | go.opentelemetry.io/otel/sdk/metric v1.35.0 // indirect 135 | go.opentelemetry.io/otel/trace v1.35.0 // indirect 136 | go.opentelemetry.io/proto/otlp v1.5.0 // indirect 
137 | go.temporal.io/sdk v1.32.1 // indirect 138 | go.temporal.io/version v0.3.0 // indirect 139 | go.uber.org/dig v1.18.0 // indirect 140 | go.uber.org/fx v1.23.0 // indirect 141 | go.uber.org/mock v0.5.0 // indirect 142 | go.uber.org/multierr v1.11.0 // indirect 143 | golang.org/x/crypto v0.38.0 // indirect 144 | golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 // indirect 145 | golang.org/x/net v0.40.0 // indirect 146 | golang.org/x/oauth2 v0.27.0 // indirect 147 | golang.org/x/sys v0.33.0 // indirect 148 | golang.org/x/text v0.25.0 // indirect 149 | golang.org/x/time v0.9.0 // indirect 150 | google.golang.org/api v0.217.0 // indirect 151 | google.golang.org/genproto v0.0.0-20250124145028-65684f501c47 // indirect 152 | google.golang.org/genproto/googleapis/api v0.0.0-20250303144028-a0af3efb3deb // indirect 153 | google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463 // indirect 154 | google.golang.org/grpc v1.71.0 // indirect 155 | google.golang.org/protobuf v1.36.6 // indirect 156 | gopkg.in/go-jose/go-jose.v2 v2.6.3 // indirect 157 | gopkg.in/inf.v0 v0.9.1 // indirect 158 | gopkg.in/validator.v2 v2.0.1 // indirect 159 | gopkg.in/yaml.v3 v3.0.1 // indirect 160 | modernc.org/libc v1.65.0 // indirect 161 | modernc.org/mathutil v1.7.1 // indirect 162 | modernc.org/memory v1.10.0 // indirect 163 | modernc.org/sqlite v1.37.0 // indirect 164 | ) 165 | 166 | exclude github.com/keybase/go.dbus v0.0.0-20220506165403-5aa21ea2c23a 167 | 168 | replace github.com/insomniacslk/dhcp => github.com/insomniacslk/dhcp v0.0.0-20210120172423-cc9239ac6294 169 | 170 | replace cloud.google.com/go/pubsub => cloud.google.com/go/pubsub v1.30.0 171 | 172 | replace go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc => go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.22.0 173 | 174 | replace go.temporal.io/api => go.temporal.io/api v1.43.2 175 | 176 | replace go.temporal.io/server => go.temporal.io/server v1.26.2 177 | 178 | replace 
go.temporal.io/sdk => go.temporal.io/sdk v1.31.0 179 | 180 | replace github.com/jackc/pgtype => github.com/jackc/pgtype v1.12.0 181 | 182 | replace github.com/aws/aws-sdk-go => github.com/aws/aws-sdk-go v1.46.7 183 | 184 | replace k8s.io/api => k8s.io/api v0.26.1 185 | 186 | replace k8s.io/apiextensions-apiserver => k8s.io/apiextensions-apiserver v0.26.1 187 | 188 | replace k8s.io/apimachinery => k8s.io/apimachinery v0.26.1 189 | 190 | replace k8s.io/apiserver => k8s.io/apiserver v0.26.1 191 | 192 | replace k8s.io/cli-runtime => k8s.io/cli-runtime v0.26.1 193 | 194 | replace k8s.io/client-go => k8s.io/client-go v0.26.1 195 | 196 | replace k8s.io/cloud-provider => k8s.io/cloud-provider v0.26.1 197 | 198 | replace k8s.io/cluster-bootstrap => k8s.io/cluster-bootstrap v0.26.1 199 | 200 | replace k8s.io/code-generator => k8s.io/code-generator v0.26.1 201 | 202 | replace k8s.io/component-base => k8s.io/component-base v0.26.1 203 | 204 | replace k8s.io/cri-api => k8s.io/cri-api v0.23.5 205 | 206 | replace k8s.io/csi-translation-lib => k8s.io/csi-translation-lib v0.26.1 207 | 208 | replace k8s.io/dynamic-resource-allocation => k8s.io/dynamic-resource-allocation v0.26.1 209 | 210 | replace k8s.io/kube-controller-manager => k8s.io/kube-controller-manager v0.26.1 211 | 212 | replace k8s.io/kube-proxy => k8s.io/kube-proxy v0.26.1 213 | 214 | replace k8s.io/kube-scheduler => k8s.io/kube-scheduler v0.26.1 215 | 216 | replace k8s.io/kubelet => k8s.io/kubelet v0.26.1 217 | 218 | replace k8s.io/legacy-cloud-providers => k8s.io/legacy-cloud-providers v0.26.1 219 | 220 | replace k8s.io/mount-utils => k8s.io/mount-utils v0.26.2-rc.0 221 | 222 | replace k8s.io/pod-security-admission => k8s.io/pod-security-admission v0.26.1 223 | 224 | replace k8s.io/sample-apiserver => k8s.io/sample-apiserver v0.26.1 225 | 226 | replace github.com/temporalio/features => github.com/temporalio/features v0.0.0-20231218231852-27c681667dae 227 | 228 | replace github.com/temporalio/features/features => 
github.com/temporalio/features/features v0.0.0-20231218231852-27c681667dae 229 | 230 | replace github.com/temporalio/features/harness/go => github.com/temporalio/features/harness/go v0.0.0-20231218231852-27c681667dae 231 | 232 | replace github.com/temporalio/omes => github.com/temporalio/omes v0.0.0-20240701113332-211647aa9dae 233 | 234 | replace github.com/aleroyer/rsyslog_exporter => github.com/prometheus-community/rsyslog_exporter v1.1.0 235 | 236 | replace github.com/prometheus/client_model => github.com/prometheus/client_model v0.6.1 237 | 238 | replace github.com/prometheus/common => github.com/prometheus/common v0.62.0 239 | 240 | replace github.com/distribution/reference => github.com/distribution/reference v0.5.0 241 | 242 | replace github.com/jackc/pgconn => github.com/jackc/pgconn v1.14.0 243 | 244 | replace github.com/jackc/pgproto3/v2 => github.com/jackc/pgproto3/v2 v2.3.2 245 | 246 | replace github.com/mattn/go-sqlite3 => github.com/mattn/go-sqlite3 v1.14.24 247 | 248 | replace github.com/docker/docker => github.com/docker/docker v25.0.6+incompatible 249 | 250 | replace github.com/testcontainers/testcontainers-go => github.com/testcontainers/testcontainers-go v0.31.0 251 | -------------------------------------------------------------------------------- /persistence/pkg/base/errors/errors.go: -------------------------------------------------------------------------------- 1 | package errors 2 | 3 | import ( 4 | "fmt" 5 | 6 | "go.temporal.io/api/serviceerror" 7 | ) 8 | 9 | func NewInternal(msg string) error { 10 | return serviceerror.NewInternal(msg) 11 | } 12 | 13 | func NewInternalF(format string, args ...interface{}) error { 14 | return serviceerror.NewInternal(fmt.Sprintf(format, args...)) 15 | } 16 | 17 | func NewInvalidArgumentF(format string, args ...interface{}) error { 18 | return serviceerror.NewInvalidArgument(fmt.Sprintf(format, args...)) 19 | } 20 | 21 | func NewNotFoundF(format string, args ...interface{}) error { 22 | return 
serviceerror.NewNotFound(fmt.Sprintf(format, args...)) 23 | } 24 | 25 | func NewUnavailableF(format string, args ...interface{}) error { 26 | return serviceerror.NewUnavailable(fmt.Sprintf(format, args...)) 27 | } 28 | -------------------------------------------------------------------------------- /persistence/pkg/base/executor/factory.go: -------------------------------------------------------------------------------- 1 | package executor 2 | 3 | import ( 4 | "context" 5 | 6 | commonpb "go.temporal.io/api/common/v1" 7 | "go.temporal.io/server/api/enums/v1" 8 | p "go.temporal.io/server/common/persistence" 9 | "go.temporal.io/server/common/primitives" 10 | "go.temporal.io/server/service/history/tasks" 11 | ) 12 | 13 | type CurrentRunIDAndLastWriteVersion struct { 14 | LastWriteVersion int64 15 | CurrentRunID primitives.UUID 16 | } 17 | 18 | type Query[S any] interface { 19 | ToStmt(prefix string, conditionID primitives.UUID) S 20 | } 21 | 22 | type TransactionFactory interface { 23 | NewTransaction(shardID int32) Transaction 24 | } 25 | 26 | type EventsCache interface { 27 | Invalidate(ctx context.Context, shardID int32) error 28 | Put(ctx context.Context, shardID int32, tasks ...map[tasks.Category][]p.InternalHistoryTask) error 29 | } 30 | 31 | type Transaction interface { 32 | AssertShard(lockForUpdate bool, rangeIDEqualTo int64) 33 | AssertCurrentWorkflow( 34 | lockForUpdate bool, namespaceID primitives.UUID, workflowID string, 35 | currentRunIDNotEqualTo primitives.UUID, currentRunIDEqualTo primitives.UUID, 36 | mustNotExist bool, currentRunIDAndLastWriteVersionEqualTo *CurrentRunIDAndLastWriteVersion, 37 | ) 38 | AssertWorkflowExecution( 39 | lockForUpdate bool, namespaceID primitives.UUID, workflowID string, runID primitives.UUID, 40 | recordVersionEqualTo *int64, mustNotExist bool, 41 | ) 42 | 43 | UpsertShard(rangeID int64, shardInfo *commonpb.DataBlob) 44 | UpsertCurrentWorkflow( 45 | namespaceID primitives.UUID, workflowID string, currentRunID 
primitives.UUID, 46 | executionStateBlob *commonpb.DataBlob, lastWriteVersion int64, state enums.WorkflowExecutionState) 47 | HandleWorkflowSnapshot(workflowSnapshot *p.InternalWorkflowSnapshot) 48 | HandleWorkflowMutation(workflowMutation *p.InternalWorkflowMutation) 49 | InsertHistoryTasks(insertTasks map[tasks.Category][]p.InternalHistoryTask) 50 | // UpsertHistoryTasks is used for migration purposes 51 | UpsertHistoryTasks(tasks.Category, []p.InternalHistoryTask) error 52 | 53 | DeleteBufferedEvents(namespaceID primitives.UUID, workflowID string, runID primitives.UUID) 54 | DeleteStateItems(namespaceID primitives.UUID, workflowID string, runID primitives.UUID) 55 | 56 | Execute(ctx context.Context) error 57 | } 58 | -------------------------------------------------------------------------------- /persistence/pkg/base/mss/mss.go: -------------------------------------------------------------------------------- 1 | package mss 2 | 3 | import ( 4 | "context" 5 | 6 | commonpb "go.temporal.io/api/common/v1" 7 | "go.temporal.io/server/api/enums/v1" 8 | persistencespb "go.temporal.io/server/api/persistence/v1" 9 | p "go.temporal.io/server/common/persistence" 10 | "go.temporal.io/server/common/persistence/serialization" 11 | "go.temporal.io/server/common/primitives" 12 | "go.temporal.io/server/service/history/tasks" 13 | 14 | baseerrors "github.com/yandex/temporal-over-ydb/persistence/pkg/base/errors" 15 | "github.com/yandex/temporal-over-ydb/persistence/pkg/base/executor" 16 | ) 17 | 18 | type ( 19 | BaseMutableStateStore struct { 20 | cache executor.EventsCache 21 | } 22 | ) 23 | 24 | func NewBaseMutableStateStore( 25 | cache executor.EventsCache, 26 | ) *BaseMutableStateStore { 27 | return &BaseMutableStateStore{ 28 | cache: cache, 29 | } 30 | } 31 | 32 | func (d *BaseMutableStateStore) CreateWorkflowExecution( 33 | ctx context.Context, 34 | request *p.InternalCreateWorkflowExecutionRequest, 35 | tf executor.TransactionFactory, 36 | ) (resp 
// CreateWorkflowExecutionWithinTransaction stages a CreateWorkflowExecution
// request onto the supplied transaction and executes it. On success the new
// workflow's history tasks are put into the events cache (best effort); on
// any error the shard's cache is invalidated.
func (d *BaseMutableStateStore) CreateWorkflowExecutionWithinTransaction(
	ctx context.Context,
	request *p.InternalCreateWorkflowExecutionRequest,
	transaction executor.Transaction,
) (resp *p.InternalCreateWorkflowExecutionResponse, err error) {
	defer func() {
		// TODO narrow this condition: we should invalidate only if we actually lost shard ownership
		if err != nil {
			_ = d.cache.Invalidate(ctx, request.ShardID)
		}
	}()

	shardID := request.ShardID
	newWorkflow := request.NewWorkflowSnapshot
	lastWriteVersion := newWorkflow.LastWriteVersion
	namespaceID := primitives.MustParseUUID(newWorkflow.NamespaceID)
	workflowID := newWorkflow.WorkflowID
	runID := primitives.MustParseUUID(newWorkflow.RunID)

	// Every staged write is conditional on the shard still holding this range ID.
	transaction.AssertShard(false, request.RangeID)

	switch request.Mode {
	case p.CreateWorkflowModeBypassCurrent:
		// noop
	case p.CreateWorkflowModeUpdateCurrent:
		// The current-workflow pointer must still reference the expected
		// previous run (and its last write version) before being repointed.
		transaction.AssertCurrentWorkflow(
			true,
			namespaceID,
			workflowID,
			nil,
			nil,
			false,
			&executor.CurrentRunIDAndLastWriteVersion{
				LastWriteVersion: request.PreviousLastWriteVersion,
				CurrentRunID:     primitives.MustParseUUID(request.PreviousRunID),
			},
		)
		transaction.UpsertCurrentWorkflow(
			namespaceID, workflowID, runID,
			newWorkflow.ExecutionStateBlob, lastWriteVersion, newWorkflow.ExecutionState.State,
		)
	case p.CreateWorkflowModeBrandNew:
		// No current-workflow row may exist yet for this workflow ID.
		transaction.AssertCurrentWorkflow(
			false,
			namespaceID,
			workflowID,
			nil,
			nil,
			true,
			nil,
		)
		transaction.UpsertCurrentWorkflow(
			namespaceID, workflowID, runID,
			newWorkflow.ExecutionStateBlob, lastWriteVersion, newWorkflow.ExecutionState.State,
		)
	default:
		return nil, baseerrors.NewInternalF("unknown mode: %v", request.Mode)
	}
	transaction.HandleWorkflowSnapshot(&newWorkflow)
	// The execution row itself must not exist yet in any mode.
	transaction.AssertWorkflowExecution(
		false,
		namespaceID,
		workflowID,
		runID,
		nil,
		true,
	)

	if err = transaction.Execute(ctx); err != nil {
		return nil, err
	} else {
		// Best effort: failures here only reduce cache effectiveness.
		_ = d.cache.Put(ctx, shardID, newWorkflow.Tasks)
	}
	return &p.InternalCreateWorkflowExecutionResponse{}, nil
}
updateWorkflow.ExecutionState.Status); err != nil { 151 | return err 152 | } 153 | 154 | shardID := request.ShardID 155 | namespaceID := primitives.MustParseUUID(updateWorkflow.NamespaceID) 156 | 157 | transaction.AssertShard(false, request.RangeID) 158 | 159 | switch request.Mode { 160 | case p.UpdateWorkflowModeBypassCurrent: 161 | transaction.AssertCurrentWorkflow( 162 | true, 163 | namespaceID, 164 | updateWorkflow.WorkflowID, 165 | primitives.MustParseUUID(updateWorkflow.RunID), 166 | nil, 167 | false, 168 | nil, 169 | ) 170 | case p.UpdateWorkflowModeUpdateCurrent: 171 | if newWorkflow != nil { 172 | transaction.UpsertCurrentWorkflow( 173 | namespaceID, newWorkflow.WorkflowID, primitives.MustParseUUID(newWorkflow.RunID), 174 | newWorkflow.ExecutionStateBlob, newWorkflow.LastWriteVersion, newWorkflow.ExecutionState.State, 175 | ) 176 | 177 | transaction.AssertCurrentWorkflow( 178 | true, 179 | namespaceID, 180 | newWorkflow.WorkflowID, 181 | nil, 182 | primitives.MustParseUUID(updateWorkflow.RunID), 183 | false, 184 | nil, 185 | ) 186 | } else { 187 | executionStateBlob, err := serialization.WorkflowExecutionStateToBlob(updateWorkflow.ExecutionState) 188 | if err != nil { 189 | return err 190 | } 191 | transaction.UpsertCurrentWorkflow( 192 | namespaceID, updateWorkflow.WorkflowID, primitives.MustParseUUID(updateWorkflow.RunID), 193 | executionStateBlob, updateWorkflow.LastWriteVersion, updateWorkflow.ExecutionState.State, 194 | ) 195 | 196 | transaction.AssertCurrentWorkflow( 197 | true, 198 | namespaceID, 199 | updateWorkflow.WorkflowID, 200 | nil, 201 | primitives.MustParseUUID(updateWorkflow.RunID), 202 | false, 203 | nil, 204 | ) 205 | } 206 | default: 207 | return baseerrors.NewInternalF("unknown mode: %v", request.Mode) 208 | } 209 | 210 | expectedDBRecordVersion := updateWorkflow.DBRecordVersion - 1 211 | 212 | transaction.AssertWorkflowExecution( 213 | true, 214 | namespaceID, 215 | updateWorkflow.WorkflowID, 216 | 
primitives.MustParseUUID(updateWorkflow.RunID), 217 | &expectedDBRecordVersion, 218 | false, 219 | ) 220 | 221 | if newWorkflow != nil { 222 | transaction.AssertWorkflowExecution( 223 | false, 224 | namespaceID, 225 | newWorkflow.WorkflowID, 226 | primitives.MustParseUUID(newWorkflow.RunID), 227 | nil, 228 | true, 229 | ) 230 | } 231 | 232 | transaction.HandleWorkflowMutation(&updateWorkflow) 233 | if newWorkflow != nil { 234 | transaction.HandleWorkflowSnapshot(newWorkflow) 235 | } 236 | 237 | if updateWorkflow.ClearBufferedEvents { 238 | transaction.DeleteBufferedEvents(namespaceID, updateWorkflow.WorkflowID, primitives.MustParseUUID(updateWorkflow.RunID)) 239 | } 240 | err = transaction.Execute(ctx) 241 | if err == nil { 242 | taskMaps := []map[tasks.Category][]p.InternalHistoryTask{updateWorkflow.Tasks} 243 | _ = d.cache.Put(ctx, shardID, updateWorkflow.Tasks) 244 | if newWorkflow != nil { 245 | taskMaps = append(taskMaps, newWorkflow.Tasks) 246 | } 247 | _ = d.cache.Put(ctx, shardID, taskMaps...) 
// ConflictResolveWorkflowExecutionWithinTransaction stages a conflict-resolve
// request (a reset snapshot, plus optional current-run mutation and new-run
// snapshot) onto the supplied transaction and executes it. On success all
// written history tasks are put into the events cache (best effort); on any
// error the shard's cache is invalidated.
func (d *BaseMutableStateStore) ConflictResolveWorkflowExecutionWithinTransaction(
	ctx context.Context,
	request *p.InternalConflictResolveWorkflowExecutionRequest,
	transaction executor.Transaction,
) (err error) {
	defer func() {
		// TODO narrow this condition: we should invalidate only if we actually lost shard ownership
		if err != nil {
			_ = d.cache.Invalidate(ctx, request.ShardID)
		}
	}()

	currentWorkflow := request.CurrentWorkflowMutation
	resetWorkflow := request.ResetWorkflowSnapshot
	newWorkflow := request.NewWorkflowSnapshot

	if err = p.ValidateUpdateWorkflowStateStatus(resetWorkflow.ExecutionState.State, resetWorkflow.ExecutionState.Status); err != nil {
		return err
	}

	shardID := request.ShardID
	namespaceID := primitives.MustParseUUID(resetWorkflow.NamespaceID)
	workflowID := resetWorkflow.WorkflowID

	// Every staged write is conditional on the shard still holding this range ID.
	transaction.AssertShard(false, request.RangeID)

	if request.Mode == p.ConflictResolveWorkflowModeBypassCurrent {
		// The reset run must NOT be the current run.
		transaction.AssertCurrentWorkflow(
			true,
			namespaceID,
			workflowID,
			primitives.MustParseUUID(resetWorkflow.ExecutionState.RunId),
			nil,
			false,
			nil,
		)
	} else if request.Mode == p.ConflictResolveWorkflowModeUpdateCurrent {
		// The execution state that becomes current comes from newWorkflow when
		// present, otherwise from the reset workflow.
		executionState := resetWorkflow.ExecutionState
		lastWriteVersion := resetWorkflow.LastWriteVersion
		if newWorkflow != nil {
			executionState = newWorkflow.ExecutionState
			lastWriteVersion = newWorkflow.LastWriteVersion
		}
		runID := executionState.RunId
		createRequestID := executionState.CreateRequestId
		state := executionState.State
		status := executionState.Status

		executionStateBlob, err := serialization.WorkflowExecutionStateToBlob(&persistencespb.WorkflowExecutionState{
			RunId:           runID,
			CreateRequestId: createRequestID,
			State:           state,
			Status:          status,
		})
		if err != nil {
			return baseerrors.NewInternalF("WorkflowExecutionStateToBlob failed: %v", err)
		}

		// The current-workflow pointer must still reference the run we believe
		// is current: the mutated current run if one was supplied, else the
		// reset run itself.
		var currentRunID string
		if currentWorkflow != nil {
			currentRunID = currentWorkflow.ExecutionState.RunId
		} else {
			currentRunID = resetWorkflow.ExecutionState.RunId
		}
		transaction.AssertCurrentWorkflow(
			true,
			namespaceID,
			workflowID,
			nil,
			primitives.MustParseUUID(currentRunID),
			false,
			nil,
		)
		transaction.UpsertCurrentWorkflow(
			namespaceID, workflowID, primitives.MustParseUUID(runID),
			executionStateBlob, lastWriteVersion, state,
		)
	} else {
		return baseerrors.NewInternalF("unknown mode: %v", request.Mode)
	}

	// Optimistic concurrency on the reset run's record version.
	expectedResetWfDBRecordVersion := resetWorkflow.DBRecordVersion - 1
	transaction.AssertWorkflowExecution(
		true,
		namespaceID,
		resetWorkflow.WorkflowID,
		primitives.MustParseUUID(resetWorkflow.RunID),
		&expectedResetWfDBRecordVersion,
		false,
	)
	if currentWorkflow != nil {
		expectedCurrentWfDBRecordVersion := currentWorkflow.DBRecordVersion - 1
		// NOTE(review): this pairs resetWorkflow.WorkflowID with
		// currentWorkflow.RunID — presumably both target the same workflow ID;
		// confirm against callers before relying on it.
		transaction.AssertWorkflowExecution(
			true,
			namespaceID,
			resetWorkflow.WorkflowID,
			primitives.MustParseUUID(currentWorkflow.RunID),
			&expectedCurrentWfDBRecordVersion,
			false,
		)
	}

	transaction.HandleWorkflowSnapshot(&resetWorkflow)
	if newWorkflow != nil {
		transaction.HandleWorkflowSnapshot(newWorkflow)
	}
// SetWorkflowExecutionWithinTransaction stages an overwrite of a single run's
// snapshot onto the supplied transaction and executes it, guarded by the
// shard range ID and by the run's DB record version.
func (d *BaseMutableStateStore) SetWorkflowExecutionWithinTransaction(
	ctx context.Context,
	request *p.InternalSetWorkflowExecutionRequest,
	transaction executor.Transaction,
) (err error) {
	setSnapshot := request.SetWorkflowSnapshot
	namespaceID := primitives.MustParseUUID(setSnapshot.NamespaceID)
	runID := primitives.MustParseUUID(setSnapshot.RunID)

	if err = p.ValidateUpdateWorkflowStateStatus(setSnapshot.ExecutionState.State, setSnapshot.ExecutionState.Status); err != nil {
		return err
	}

	// Optimistic concurrency: the stored record version must be exactly one
	// behind the snapshot's DBRecordVersion.
	expectedDBRecordVersion := setSnapshot.DBRecordVersion - 1

	transaction.AssertShard(false, request.RangeID)
	transaction.AssertWorkflowExecution(
		true,
		namespaceID,
		setSnapshot.WorkflowID,
		runID,
		&expectedDBRecordVersion,
		false,
	)

	transaction.HandleWorkflowSnapshot(&setSnapshot)
	// Clear existing state items so only the snapshot's contents remain
	// (presumably the snapshot re-stages its own items — confirm in
	// HandleWorkflowSnapshot's implementation).
	transaction.DeleteStateItems(namespaceID, setSnapshot.WorkflowID, runID)
	return transaction.Execute(ctx)
}
} 477 | 478 | expectedDBRecordVersion := setSnapshot.DBRecordVersion 479 | 480 | transaction.AssertShard(false, request.RangeID) 481 | transaction.AssertWorkflowExecution( 482 | true, 483 | namespaceID, 484 | setSnapshot.WorkflowID, 485 | runID, 486 | &expectedDBRecordVersion, 487 | false, 488 | ) 489 | if asCurrent { 490 | transaction.AssertCurrentWorkflow( 491 | true, 492 | namespaceID, 493 | setSnapshot.WorkflowID, 494 | nil, 495 | primitives.MustParseUUID(request.SetWorkflowSnapshot.RunID), 496 | false, 497 | nil, 498 | ) 499 | transaction.UpsertCurrentWorkflow( 500 | primitives.MustParseUUID(request.SetWorkflowSnapshot.NamespaceID), 501 | request.SetWorkflowSnapshot.WorkflowID, 502 | primitives.MustParseUUID(request.SetWorkflowSnapshot.RunID), 503 | request.SetWorkflowSnapshot.ExecutionStateBlob, 504 | request.SetWorkflowSnapshot.LastWriteVersion, 505 | request.SetWorkflowSnapshot.ExecutionState.State, 506 | ) 507 | } 508 | 509 | transaction.HandleWorkflowSnapshot(&setSnapshot) 510 | transaction.DeleteStateItems(namespaceID, setSnapshot.WorkflowID, runID) 511 | return transaction.Execute(ctx) 512 | } 513 | -------------------------------------------------------------------------------- /persistence/pkg/base/rows/const.go: -------------------------------------------------------------------------------- 1 | package rows 2 | 3 | const ( 4 | ItemTypeActivity = iota 5 | ItemTypeRequestCancel 6 | ItemTypeSignal 7 | ItemTypeTimer 8 | ItemTypeChildExecution 9 | ItemTypeBufferedEvent 10 | ItemTypeSignalRequested 11 | ) 12 | -------------------------------------------------------------------------------- /persistence/pkg/base/rows/rows.go: -------------------------------------------------------------------------------- 1 | package rows 2 | 3 | import ( 4 | "time" 5 | 6 | v1 "go.temporal.io/api/enums/v1" 7 | "go.temporal.io/server/common/primitives" 8 | ) 9 | 10 | type IDEventKey struct { 11 | ShardID int32 12 | NamespaceID primitives.UUID 13 | WorkflowID string 14 | 
// NameEventKey addresses a name-keyed mutable-state item within a shard.
type NameEventKey struct {
	ShardID     int32
	NamespaceID primitives.UUID
	WorkflowID  string
	RunID       primitives.UUID
	ItemType    int32
	ItemName    string
}

// ShardRow holds a shard's serialized info blob together with its range ID.
type ShardRow struct {
	ShardID       int32
	RangeID       int64
	Shard         []byte
	ShardEncoding v1.EncodingType
}

// CurrentExecutionRow is the per-workflow-ID pointer to the current run,
// including the serialized execution state and last write version.
type CurrentExecutionRow struct {
	ShardID                int32
	NamespaceID            primitives.UUID
	WorkflowID             string
	CurrentRunID           primitives.UUID
	ExecutionState         []byte
	ExecutionStateEncoding v1.EncodingType
	LastWriteVersion       int64
	State                  int32
}

// WorkflowExecutionRow holds the serialized mutable state of a single run,
// with its checksum, next event ID, and DB record version.
type WorkflowExecutionRow struct {
	ShardID                int32
	NamespaceID            primitives.UUID
	WorkflowID             string
	RunID                  primitives.UUID
	Execution              []byte
	ExecutionEncoding      v1.EncodingType
	ExecutionState         []byte
	ExecutionStateEncoding v1.EncodingType
	Checksum               []byte
	ChecksumEncoding       v1.EncodingType
	NextEventID            int64
	DBRecordVersion        int64
}

// IdentifiedItemRow is a mutable-state item addressed by numeric item ID.
type IdentifiedItemRow struct {
	ShardID      int32
	NamespaceID  primitives.UUID
	WorkflowID   string
	RunID        primitives.UUID
	ItemType     int32
	ItemID       int64
	Data         []byte
	DataEncoding v1.EncodingType
}

// NamedItemRow is a mutable-state item addressed by name.
type NamedItemRow struct {
	ShardID      int32
	NamespaceID  primitives.UUID
	WorkflowID   string
	RunID        primitives.UUID
	ItemType     int32
	ItemName     string
	Data         []byte
	DataEncoding v1.EncodingType
}

// SignalRequestedItemRow records a requested signal (by name) for a run.
type SignalRequestedItemRow struct {
	ShardID     int32
	NamespaceID primitives.UUID
	WorkflowID  string
	RunID       primitives.UUID
	ItemName    string
}

// ImmediateTaskRow is a serialized history task keyed by (category, ID).
type ImmediateTaskRow struct {
	ShardID      int32
	CategoryID   int32
	ID           int64
	Data         []byte
	DataEncoding v1.EncodingType
}

// ScheduledTaskRow is a serialized history task that additionally carries a
// visibility timestamp.
type ScheduledTaskRow struct {
	ShardID      int32
	CategoryID   int32
	ID           int64
	VisibilityTS time.Time
	Data         []byte
	DataEncoding v1.EncodingType
}

// NodeRow is a history tree node keyed by (tree, branch, node) with its
// previous and current transaction IDs.
type NodeRow struct {
	ShardID      int32
	TreeID       string
	BranchID     string
	NodeID       int64
	PrevTxnID    int64
	TxnID        int64
	Data         []byte
	DataEncoding v1.EncodingType
}

// TreeRow holds a serialized branch record of a history tree.
type TreeRow struct {
	ShardID        int32
	TreeID         string
	BranchID       string
	Branch         []byte
	BranchEncoding v1.EncodingType
}
len(payload) > 0 { 60 | return json.Unmarshal(payload, pt) 61 | } else { 62 | return nil 63 | } 64 | } 65 | 66 | type TaskPageToken struct { 67 | TaskID int64 68 | } 69 | 70 | func (t *TaskPageToken) Serialize() ([]byte, error) { 71 | return json.Marshal(t) 72 | } 73 | 74 | func (t *TaskPageToken) Deserialize(payload []byte) error { 75 | if len(payload) > 0 { 76 | return json.Unmarshal(payload, t) 77 | } else { 78 | return nil 79 | } 80 | } 81 | 82 | type ScheduledTaskPageToken struct { 83 | TaskID int64 84 | Timestamp time.Time 85 | } 86 | 87 | func (t *ScheduledTaskPageToken) Serialize() ([]byte, error) { 88 | return json.Marshal(t) 89 | } 90 | 91 | func (t *ScheduledTaskPageToken) Deserialize(payload []byte) error { 92 | if len(payload) > 0 { 93 | return json.Unmarshal(payload, t) 94 | } else { 95 | return nil 96 | } 97 | } 98 | 99 | func GetImmediateTaskNextPageToken(lastTaskID int64, exclusiveMaxTaskID int64) ([]byte, error) { 100 | nextTaskID := lastTaskID + 1 101 | if nextTaskID < exclusiveMaxTaskID { 102 | token := TaskPageToken{TaskID: nextTaskID} 103 | return token.Serialize() 104 | } 105 | return nil, nil 106 | } 107 | 108 | func GetImmediateTaskReadRange(request *p.GetHistoryTasksRequest) (inclusiveMinTaskID int64, exclusiveMaxTaskID int64, err error) { 109 | inclusiveMinTaskID = request.InclusiveMinTaskKey.TaskID 110 | if len(request.NextPageToken) > 0 { 111 | var token TaskPageToken 112 | if err = token.Deserialize(request.NextPageToken); err != nil { 113 | return 0, 0, err 114 | } 115 | inclusiveMinTaskID = token.TaskID 116 | } 117 | return inclusiveMinTaskID, request.ExclusiveMaxTaskKey.TaskID, nil 118 | } 119 | 120 | type TaskQueueUserDataPageToken struct { 121 | LastTaskQueueName string 122 | } 123 | 124 | func (pt *TaskQueueUserDataPageToken) Serialize() ([]byte, error) { 125 | return json.Marshal(pt) 126 | } 127 | 128 | func (pt *TaskQueueUserDataPageToken) Deserialize(payload []byte) error { 129 | if len(payload) > 0 { 130 | return 
json.Unmarshal(payload, pt) 131 | } else { 132 | return nil 133 | } 134 | } 135 | 136 | type MatchingTaskPageToken struct { 137 | TaskID int64 138 | } 139 | 140 | func (pt *MatchingTaskPageToken) Serialize() ([]byte, error) { 141 | return json.Marshal(pt) 142 | } 143 | 144 | func (pt *MatchingTaskPageToken) Deserialize(payload []byte) error { 145 | if len(payload) > 0 { 146 | return json.Unmarshal(payload, pt) 147 | } else { 148 | return nil 149 | } 150 | } 151 | 152 | type ClusterMembersPageToken struct { 153 | LastSeenHostID primitives.UUID 154 | } 155 | 156 | func (pt *ClusterMembersPageToken) Serialize() ([]byte, error) { 157 | return []byte(pt.LastSeenHostID.String()), nil 158 | } 159 | 160 | func (pt *ClusterMembersPageToken) Deserialize(payload []byte) error { 161 | if lastSeenHostID, err := primitives.ParseUUID(string(payload)); err != nil { 162 | return err 163 | } else { 164 | pt.LastSeenHostID = lastSeenHostID 165 | return nil 166 | } 167 | } 168 | 169 | type QueueV2PageToken struct { 170 | LastReadQueueName string 171 | } 172 | 173 | func (qt *QueueV2PageToken) Serialize() ([]byte, error) { 174 | return []byte(qt.LastReadQueueName), nil 175 | } 176 | 177 | func (qt *QueueV2PageToken) Deserialize(payload []byte) error { 178 | qt.LastReadQueueName = string(payload) 179 | return nil 180 | } 181 | 182 | type NexusEndpointsPageToken struct { 183 | LastSeenEndpointID string 184 | } 185 | 186 | func (pt *NexusEndpointsPageToken) Serialize() []byte { 187 | return []byte(pt.LastSeenEndpointID) 188 | } 189 | 190 | func (pt *NexusEndpointsPageToken) Deserialize(payload []byte) { 191 | pt.LastSeenEndpointID = string(payload) 192 | } 193 | -------------------------------------------------------------------------------- /persistence/pkg/cache/basecache.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | "time" 7 | 8 | "go.temporal.io/server/common/log" 9 | 
"go.temporal.io/server/common/metrics" 10 | "go.temporal.io/server/service/history/tasks" 11 | ) 12 | 13 | type ( 14 | baseTaskCacheImpl struct { 15 | logger log.Logger 16 | metricsHandler metrics.Handler 17 | capacity int 18 | 19 | mu sync.Mutex 20 | 21 | validRange *keyRange 22 | tasks SortedTasks 23 | pendingTasks Tasks 24 | futureTasks Tasks 25 | } 26 | ) 27 | 28 | func (c *baseTaskCacheImpl) totalTasksLocked() int { 29 | return len(c.pendingTasks) + len(c.futureTasks) + len(c.tasks) 30 | } 31 | 32 | func (c *baseTaskCacheImpl) invalidateLocked() { 33 | c.validRange = nil 34 | c.tasks = nil 35 | c.pendingTasks = nil 36 | c.futureTasks = nil 37 | } 38 | 39 | func (c *baseTaskCacheImpl) Invalidate() { 40 | c.mu.Lock() 41 | defer c.mu.Unlock() 42 | 43 | c.invalidateLocked() 44 | } 45 | 46 | func (c *baseTaskCacheImpl) statisticsLocked() Statistics { 47 | return Statistics{ 48 | Capacity: c.capacity, 49 | TotalTasksCount: c.totalTasksLocked(), 50 | PendingTasksCount: len(c.pendingTasks), 51 | FutureTasksCount: len(c.futureTasks), 52 | TasksCount: len(c.tasks), 53 | ValidRange: c.validRange, 54 | } 55 | } 56 | 57 | func (c *baseTaskCacheImpl) Statistics() Statistics { 58 | c.mu.Lock() 59 | defer c.mu.Unlock() 60 | 61 | return c.statisticsLocked() 62 | } 63 | 64 | // putLocked puts the tasks into the cache. 65 | // It assumes that all tasks are within or after the valid range, never before. 
66 | func (c *baseTaskCacheImpl) putLocked(tasks Tasks) error { 67 | var pendingCount int64 = 0 68 | var futureCount int64 = 0 69 | for _, task := range tasks { 70 | pos := c.validRange.compareTo(task.Key) 71 | if pos == -1 { 72 | c.futureTasks = append(c.futureTasks, task) 73 | futureCount++ 74 | } else { 75 | c.pendingTasks = append(c.pendingTasks, task) 76 | pendingCount++ 77 | } 78 | } 79 | c.metricsHandler.Counter(tasksPutCounter.Name()).Record(pendingCount) 80 | c.metricsHandler.Counter(futureTasksPutCounter.Name()).Record(futureCount) 81 | 82 | if c.totalTasksLocked() > c.capacity { 83 | c.compactLocked() 84 | 85 | if len(c.futureTasks) > len(c.tasks) { 86 | err := newTooManyFutureTasks(c.statisticsLocked()) 87 | c.invalidateLocked() 88 | return err 89 | } 90 | } 91 | 92 | return nil 93 | } 94 | 95 | func (c *baseTaskCacheImpl) Expand(inclusiveFromKey, inclusiveToKey tasks.Key, newTasks Tasks) (err error) { 96 | defer func() { 97 | if err != nil { 98 | c.logger.Error(err.Error()) 99 | c.metricsHandler.Counter(writeErrorCounter.Name()).Record(1) 100 | } 101 | }() 102 | 103 | expansionRange := newKeyRange(inclusiveFromKey, inclusiveToKey) 104 | 105 | if expansionRange.inclusiveFrom.CompareTo(expansionRange.inclusiveTo) == 1 { 106 | return newInvalidRequestError(fmt.Sprintf("invalid expansion range: inclusive from (%s) > exclusive to (%s)", 107 | formatKey(expansionRange.inclusiveFrom), formatKey(expansionRange.inclusiveTo))) 108 | } 109 | 110 | tasksRange := newTasks.getRange() 111 | if !expansionRange.contains(tasksRange) { 112 | return newInvalidRequestError(fmt.Sprintf("invalid expansion range: %s does not contain tasks range %s", 113 | expansionRange.String(), tasksRange.String())) 114 | } 115 | 116 | c.mu.Lock() 117 | defer c.mu.Unlock() 118 | 119 | // First we update the valid range, as putLocked separates current tasks from future tasks 120 | if c.validRange == nil { 121 | c.validRange = &expansionRange 122 | } else { 123 | validRangePlusOne := 
newKeyRange(c.validRange.inclusiveFrom, c.validRange.inclusiveTo.Next()) 124 | if validRangePlusOne.compareTo(expansionRange.inclusiveFrom) != 0 { 125 | return newInvalidRequestError( 126 | fmt.Sprintf("invalid expansion range: inclusive from key %s is not in valid range and does not touch it from the right: %s", 127 | formatKey(expansionRange.inclusiveFrom), c.validRange.String())) 128 | } 129 | c.validRange.inclusiveTo = tasks.MaxKey(c.validRange.inclusiveTo, expansionRange.inclusiveTo) 130 | } 131 | 132 | // Take the portion of future tasks that is covered by inclusiveToKey 133 | currTasks, futureTasks := c.futureTasks.sort().splitBy(expansionRange.inclusiveTo) 134 | c.futureTasks = Tasks(futureTasks) 135 | 136 | // Put the tasks 137 | return c.putLocked(append(newTasks, currTasks...)) 138 | } 139 | 140 | func (c *baseTaskCacheImpl) Put(tasks Tasks) (err error) { 141 | if len(tasks) == 0 { 142 | return nil 143 | } 144 | c.mu.Lock() 145 | defer c.mu.Unlock() 146 | 147 | if c.validRange == nil { 148 | return newInvalidCacheError() 149 | } 150 | 151 | defer func() { 152 | if err != nil { 153 | c.logger.Error(err.Error()) 154 | c.metricsHandler.Counter(writeErrorCounter.Name()).Record(1) 155 | } 156 | }() 157 | 158 | for _, task := range tasks { 159 | if c.validRange.compareTo(task.Key) == 1 { 160 | return newInvalidRequestError(fmt.Sprintf("task %s is below the valid range: %s", 161 | formatKey(task.Key), c.validRange)) 162 | } 163 | } 164 | 165 | return c.putLocked(tasks) 166 | } 167 | 168 | func (c *baseTaskCacheImpl) compactLocked() { 169 | mergedTasks := c.tasks.merge(c.pendingTasks.sort()) 170 | capLeft := c.capacity - len(mergedTasks) - len(c.futureTasks) 171 | if capLeft < 0 { 172 | s := min(-capLeft, len(mergedTasks)) 173 | c.validRange.inclusiveFrom = mergedTasks[s-1].Key.Next() 174 | mergedTasks = mergedTasks[s:] 175 | } 176 | c.tasks = mergedTasks 177 | c.pendingTasks = nil 178 | } 179 | 180 | func (c *baseTaskCacheImpl) list(requestedRange keyRange, 
// Remove drops all cached tasks with key strictly below exclusiveMaxTaskKey.
// Removing a key entirely below the valid range is a no-op; removing past the
// upper end of the valid range is an error (the caller would be deleting tasks
// the cache has never seen). The valid range itself is not shrunk: it still
// truthfully covers the removed (now empty) span.
func (c *baseTaskCacheImpl) Remove(exclusiveMaxTaskKey tasks.Key) (err error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.validRange == nil {
		return newInvalidCacheError()
	}

	defer func() {
		if err != nil {
			c.logger.Error(err.Error())
			c.metricsHandler.Counter(writeErrorCounter.Name()).Record(1)
		}
	}()

	// fold pending tasks into c.tasks so splitBy below sees everything
	c.compactLocked()

	inclusiveMaxTaskKey := exclusiveMaxTaskKey.Prev()
	pos := c.validRange.compareTo(inclusiveMaxTaskKey)
	if pos == -1 {
		// whole valid range is below the key: caller asked past what we cover
		return newInvalidRequestError(
			fmt.Sprintf("exclusive max key %s is above the valid range: %s",
				formatKey(exclusiveMaxTaskKey), c.validRange.String()))
	} else if pos == 1 {
		// key is below the valid range: nothing cached there, nothing to do
		return nil
	} else {
		// keep only the tasks with key > inclusiveMaxTaskKey
		_, c.tasks = c.tasks.splitBy(inclusiveMaxTaskKey)
	}
	return nil
}
// NewNoopImmediateTaskCache returns an ImmediateTaskCache that discards all
// writes and never serves a read (List always reports a miss), effectively
// disabling caching.
func NewNoopImmediateTaskCache() ImmediateTaskCache {
	return noopImmediateTaskCacheImpl{}
}

// Invalidate is a no-op.
func (n noopImmediateTaskCacheImpl) Invalidate() {
}

// Put discards the tasks and reports success.
func (n noopImmediateTaskCacheImpl) Put(tasks Tasks) error {
	return nil
}

// List always reports a cache miss.
func (n noopImmediateTaskCacheImpl) List(inclusiveMinTaskID, exclusiveMaxTaskID int64, batchSize int32) (bool, SortedTasks, error) {
	return false, nil, nil
}

// Remove is a no-op.
func (n noopImmediateTaskCacheImpl) Remove(exclusiveMaxTaskKey tasks.Key) error {
	return nil
}

// Expand discards the tasks and reports success.
func (n noopImmediateTaskCacheImpl) Expand(inclusiveFromKey, inclusiveToKey tasks.Key, tasks Tasks) error {
	return nil
}

// SuggestedExpansion never suggests an expansion.
func (n noopImmediateTaskCacheImpl) SuggestedExpansion() *ImmediateCacheSuggestion {
	return nil
}

// Statistics returns zero-valued statistics.
func (n noopImmediateTaskCacheImpl) Statistics() Statistics {
	return Statistics{}
}
capacity int) ImmediateTaskCache { 68 | if capacity < 1 { 69 | panic(fmt.Sprintf("invalid capacity: %d", capacity)) 70 | } 71 | return &immediateTaskCacheImpl{ 72 | baseTaskCacheImpl: baseTaskCacheImpl{ 73 | logger: logger, 74 | metricsHandler: metricsHandler, 75 | capacity: capacity, 76 | }, 77 | } 78 | } 79 | 80 | func (c *immediateTaskCacheImpl) List(inclusiveMinTaskID, exclusiveMaxTaskID int64, batchSize int32) (bool, SortedTasks, error) { 81 | if inclusiveMinTaskID > exclusiveMaxTaskID { 82 | return false, nil, newInvalidRequestError(fmt.Sprintf( 83 | "invalid requested range: inclusive min (%d) > exclusive max (%d)", 84 | inclusiveMinTaskID, exclusiveMaxTaskID)) 85 | } else if inclusiveMinTaskID == exclusiveMaxTaskID { 86 | // for some reason persistence gets this kind of requests: inclusive min == exclusive max, always empty 87 | c.metricsHandler.Counter(emptyHitCounter.Name()).Record(1) 88 | return true, nil, nil 89 | } 90 | 91 | requestedRange := newKeyRange( 92 | tasks.NewImmediateKey(inclusiveMinTaskID), 93 | tasks.NewImmediateKey(exclusiveMaxTaskID).Prev(), 94 | ) 95 | return c.list(requestedRange, batchSize) 96 | } 97 | 98 | func (c *immediateTaskCacheImpl) SuggestedExpansion() *ImmediateCacheSuggestion { 99 | c.mu.Lock() 100 | defer c.mu.Unlock() 101 | 102 | if c.validRange == nil { 103 | return nil 104 | } 105 | 106 | capLeft := c.capacity - c.totalTasksLocked() 107 | batch := min(defaultExpansionBatch, capLeft) 108 | if batch < minExpansionBatch { 109 | return nil 110 | } 111 | 112 | var exclusiveMaxTaskID int64 113 | // probably not the best logic, just adapting to a scheduled cache mechanics 114 | if len(c.futureTasks) == 0 { 115 | exclusiveMaxTaskID = c.validRange.inclusiveTo.TaskID + 1000 116 | } else { 117 | exclusiveMaxTaskID = c.futureTasks[len(c.futureTasks)-1].Key.TaskID + 1 118 | } 119 | return &ImmediateCacheSuggestion{ 120 | InclusiveMinTaskID: c.validRange.inclusiveTo.TaskID, 121 | ExclusiveMaxTaskID: exclusiveMaxTaskID, 122 | 
BatchSize: int32(batch), 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /persistence/pkg/cache/errors.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | type TooManyFutureTasks struct { 8 | stats Statistics 9 | } 10 | 11 | func newTooManyFutureTasks(stats Statistics) *TooManyFutureTasks { 12 | return &TooManyFutureTasks{ 13 | stats: stats, 14 | } 15 | } 16 | 17 | func (e *TooManyFutureTasks) Error() string { 18 | return fmt.Sprintf("too many future tasks (%s)", e.stats.String()) 19 | } 20 | 21 | type InvalidCacheError struct { 22 | } 23 | 24 | func newInvalidCacheError() *InvalidCacheError { 25 | return &InvalidCacheError{} 26 | } 27 | 28 | func (e *InvalidCacheError) Error() string { 29 | return "cache is not valid" 30 | } 31 | 32 | type InvalidRequestError struct { 33 | message string 34 | } 35 | 36 | func newInvalidRequestError(message string) *InvalidRequestError { 37 | return &InvalidRequestError{ 38 | message: message, 39 | } 40 | } 41 | 42 | func (e *InvalidRequestError) Error() string { 43 | return e.message 44 | } 45 | -------------------------------------------------------------------------------- /persistence/pkg/cache/factory.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "sync/atomic" 8 | "time" 9 | 10 | "go.temporal.io/server/common/log" 11 | "go.temporal.io/server/common/metrics" 12 | "go.temporal.io/server/common/persistence" 13 | "go.temporal.io/server/service/history/tasks" 14 | 15 | "github.com/yandex/temporal-over-ydb/persistence/pkg/base/executor" 16 | ) 17 | 18 | type ( 19 | TaskCacheFactory interface { 20 | GetOrCreateImmediateTaskCache(shardID int32, category tasks.Category) ImmediateTaskCache 21 | GetOrCreateScheduledTaskCache(shardID int32, category tasks.Category) ScheduledTaskCache 
22 | InvalidateAllCaches(shardID int32) 23 | Stop() 24 | } 25 | 26 | taskCacheFactoryImpl struct { 27 | logger log.Logger 28 | metricsHandler metrics.Handler 29 | cacheCapacity int 30 | transferTaskCaches map[int32]ImmediateTaskCache 31 | visibilityTaskCaches map[int32]ImmediateTaskCache 32 | replicationTaskCaches map[int32]ImmediateTaskCache 33 | timerTaskCaches map[int32]ScheduledTaskCache 34 | m sync.Mutex 35 | 36 | metricsReporter cacheMetricsReporter 37 | } 38 | noopTaskCacheFactoryImpl struct { 39 | } 40 | ) 41 | 42 | func NewNoopTaskCacheFactory() TaskCacheFactory { 43 | return noopTaskCacheFactoryImpl{} 44 | } 45 | 46 | func (n noopTaskCacheFactoryImpl) GetOrCreateImmediateTaskCache(shardID int32, category tasks.Category) ImmediateTaskCache { 47 | return NewNoopImmediateTaskCache() 48 | } 49 | 50 | func (n noopTaskCacheFactoryImpl) GetOrCreateScheduledTaskCache(shardID int32, category tasks.Category) ScheduledTaskCache { 51 | return NewNoopScheduledTaskCache() 52 | } 53 | 54 | func (n noopTaskCacheFactoryImpl) Stop() { 55 | } 56 | 57 | func (n noopTaskCacheFactoryImpl) InvalidateAllCaches(shardID int32) { 58 | } 59 | 60 | func NewTaskCacheFactory(logger log.Logger, metricsHandler metrics.Handler, cacheCapacity int) TaskCacheFactory { 61 | f := &taskCacheFactoryImpl{ 62 | logger: logger, 63 | metricsHandler: metricsHandler, 64 | cacheCapacity: cacheCapacity, 65 | transferTaskCaches: make(map[int32]ImmediateTaskCache), 66 | visibilityTaskCaches: make(map[int32]ImmediateTaskCache), 67 | replicationTaskCaches: make(map[int32]ImmediateTaskCache), 68 | timerTaskCaches: make(map[int32]ScheduledTaskCache), 69 | } 70 | metricsReporter := cacheMetricsReporter{ 71 | handler: metricsHandler, 72 | reportInterval: 30 * time.Second, 73 | quit: make(chan struct{}), 74 | logger: logger, 75 | impl: f, 76 | } 77 | f.metricsReporter = metricsReporter 78 | metricsReporter.Start() 79 | return f 80 | } 81 | 82 | func (f *taskCacheFactoryImpl) GetOrCreateImmediateTaskCache(shardID 
int32, category tasks.Category) ImmediateTaskCache { 83 | f.m.Lock() 84 | defer f.m.Unlock() 85 | 86 | var lookup map[int32]ImmediateTaskCache 87 | switch category { 88 | case tasks.CategoryTransfer: 89 | lookup = f.transferTaskCaches 90 | case tasks.CategoryVisibility: 91 | lookup = f.visibilityTaskCaches 92 | case tasks.CategoryReplication: 93 | lookup = f.visibilityTaskCaches 94 | default: 95 | panic("unknown immediate task category") 96 | } 97 | if _, ok := lookup[shardID]; !ok { 98 | lookup[shardID] = NewImmediateTaskCache(f.logger, f.metricsHandler.WithTags(metrics.TaskCategoryTag(category.Name())), f.cacheCapacity) 99 | } 100 | return lookup[shardID] 101 | } 102 | 103 | func (f *taskCacheFactoryImpl) GetOrCreateScheduledTaskCache(shardID int32, category tasks.Category) ScheduledTaskCache { 104 | f.m.Lock() 105 | defer f.m.Unlock() 106 | 107 | var lookup map[int32]ScheduledTaskCache 108 | switch category { 109 | case tasks.CategoryTimer: 110 | lookup = f.timerTaskCaches 111 | default: 112 | panic("unknown scheduled task category") 113 | } 114 | if _, ok := lookup[shardID]; !ok { 115 | lookup[shardID] = NewScheduledTaskCache(f.logger, f.metricsHandler.WithTags(metrics.TaskCategoryTag(category.Name())), f.cacheCapacity) 116 | } 117 | return lookup[shardID] 118 | } 119 | 120 | func (f *taskCacheFactoryImpl) InvalidateAllCaches(shardID int32) { 121 | f.m.Lock() 122 | defer f.m.Unlock() 123 | 124 | if c, ok := f.transferTaskCaches[shardID]; ok { 125 | c.Invalidate() 126 | } 127 | if c, ok := f.visibilityTaskCaches[shardID]; ok { 128 | c.Invalidate() 129 | } 130 | if c, ok := f.replicationTaskCaches[shardID]; ok { 131 | c.Invalidate() 132 | } 133 | if c, ok := f.timerTaskCaches[shardID]; ok { 134 | c.Invalidate() 135 | } 136 | } 137 | 138 | func (f *taskCacheFactoryImpl) Stop() { 139 | f.metricsReporter.Stop() 140 | } 141 | 142 | func (f *taskCacheFactoryImpl) getStatistics() Statistics { 143 | f.m.Lock() 144 | defer f.m.Unlock() 145 | 146 | total := Statistics{} 
147 | add := func(s Statistics) { 148 | total.TotalTasksCount += s.TotalTasksCount 149 | total.PendingTasksCount += s.PendingTasksCount 150 | total.FutureTasksCount += s.FutureTasksCount 151 | total.TasksCount += s.TasksCount 152 | } 153 | for _, c := range f.transferTaskCaches { 154 | add(c.Statistics()) 155 | } 156 | for _, c := range f.visibilityTaskCaches { 157 | add(c.Statistics()) 158 | } 159 | for _, c := range f.replicationTaskCaches { 160 | add(c.Statistics()) 161 | } 162 | for _, c := range f.timerTaskCaches { 163 | add(c.Statistics()) 164 | } 165 | return total 166 | } 167 | 168 | type cacheMetricsReporter struct { 169 | handler metrics.Handler 170 | reportInterval time.Duration 171 | started int32 172 | quit chan struct{} 173 | logger log.Logger 174 | impl *taskCacheFactoryImpl 175 | } 176 | 177 | func (r *cacheMetricsReporter) Start() { 178 | if !atomic.CompareAndSwapInt32(&r.started, 0, 1) { 179 | return 180 | } 181 | r.report() 182 | go func() { 183 | ticker := time.NewTicker(r.reportInterval) 184 | for { 185 | select { 186 | case <-ticker.C: 187 | r.report() 188 | case <-r.quit: 189 | ticker.Stop() 190 | return 191 | } 192 | } 193 | }() 194 | r.logger.Info("cacheMetricsReporter started") 195 | } 196 | 197 | func (r *cacheMetricsReporter) Stop() { 198 | close(r.quit) 199 | r.logger.Info("cacheMetricsReporter stopped") 200 | } 201 | 202 | func (r *cacheMetricsReporter) report() { 203 | s := r.impl.getStatistics() 204 | if r.handler != nil { 205 | r.handler.Gauge(futureTasksGauge.Name()).Record(float64(s.FutureTasksCount)) 206 | r.handler.Gauge(pendingTasksGauge.Name()).Record(float64(s.PendingTasksCount)) 207 | r.handler.Gauge(totalTasksGauge.Name()).Record(float64(s.TotalTasksCount)) 208 | } else { 209 | fmt.Printf("handler: %s, STATISTICS: %+v\n", r.handler, s) 210 | } 211 | } 212 | 213 | type eventsCache struct { 214 | taskCacheFactory TaskCacheFactory 215 | } 216 | 217 | func (d *eventsCache) Invalidate(ctx context.Context, shardID int32) error { 
218 | d.taskCacheFactory.InvalidateAllCaches(shardID) 219 | return nil 220 | } 221 | 222 | func (d *eventsCache) Put(ctx context.Context, shardID int32, tsss ...map[tasks.Category][]persistence.InternalHistoryTask) error { 223 | res := make(map[tasks.Category][]persistence.InternalHistoryTask) 224 | 225 | for _, tss := range tsss { 226 | for category, ts := range tss { 227 | res[category] = append(res[category], ts...) 228 | } 229 | } 230 | 231 | for category, ts := range res { 232 | switch category.Type() { 233 | case tasks.CategoryTypeImmediate: 234 | _ = d.taskCacheFactory.GetOrCreateImmediateTaskCache(shardID, category).Put(ts) 235 | case tasks.CategoryTypeScheduled: 236 | _ = d.taskCacheFactory.GetOrCreateScheduledTaskCache(shardID, category).Put(ts) 237 | } 238 | } 239 | return nil 240 | } 241 | 242 | func NewEventsCache(taskCacheFactory TaskCacheFactory) executor.EventsCache { 243 | return &eventsCache{taskCacheFactory: taskCacheFactory} 244 | } 245 | -------------------------------------------------------------------------------- /persistence/pkg/cache/metrics.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import "go.temporal.io/server/common/metrics" 4 | 5 | var prefix = "tasks_cache_" 6 | var hitCounter = metrics.NewCounterDef(prefix + "hit") 7 | var emptyHitCounter = metrics.NewCounterDef(prefix + "empty_hit") 8 | var missCounter = metrics.NewCounterDef(prefix + "miss") 9 | var tasksPutCounter = metrics.NewCounterDef(prefix + "tasks_put") 10 | var futureTasksPutCounter = metrics.NewCounterDef(prefix + "future_tasks_put") 11 | var tasksReadCounter = metrics.NewCounterDef(prefix + "tasks_read") 12 | var writeErrorCounter = metrics.NewCounterDef(prefix + "write_error") 13 | var readErrorCounter = metrics.NewCounterDef(prefix + "read_error") 14 | 15 | var futureTasksGauge = metrics.NewGaugeDef(prefix + "future_tasks") 16 | var pendingTasksGauge = metrics.NewGaugeDef(prefix + "pending_tasks") 17 
// compareTo positions the range relative to the key k:
//
//	-1 if the whole range is below k (k is above inclusiveTo)
//	 0 if the range contains k
//	 1 if the whole range is above k (k is below inclusiveFrom)
func (r *keyRange) compareTo(k tasks.Key) int {
	if k.CompareTo(r.inclusiveFrom) == -1 {
		// . k . [from . . . to] . . .
		return 1
	}
	if k.CompareTo(r.inclusiveTo) == 1 {
		// . . . [from . . . to] . . . k
		return -1
	}
	return 0
}
65 | 66 | func (n noopScheduledTaskCacheImpl) Expand(inclusiveFromKey, inclusiveToKey tasks.Key, tasks Tasks) error { 67 | return nil 68 | } 69 | 70 | func (n noopScheduledTaskCacheImpl) SuggestedExpansion() *ScheduledCacheSuggestion { 71 | return nil 72 | } 73 | 74 | func (n noopScheduledTaskCacheImpl) Statistics() Statistics { 75 | return Statistics{} 76 | } 77 | 78 | func (s *Statistics) String() string { 79 | validRange := "nil" 80 | if s.ValidRange != nil { 81 | validRange = s.ValidRange.String() 82 | } 83 | return fmt.Sprintf("cap: %d, pending tasks: %d, future tasks: %d, tasks: %d, valid range: %s", 84 | s.Capacity, s.PendingTasksCount, s.FutureTasksCount, s.TasksCount, validRange) 85 | } 86 | 87 | func NewNoopScheduledTaskCache() ScheduledTaskCache { 88 | return &noopScheduledTaskCacheImpl{} 89 | } 90 | 91 | func NewScheduledTaskCache(logger log.Logger, metricsHandler metrics.Handler, capacity int) ScheduledTaskCache { 92 | if capacity < 1 { 93 | panic(fmt.Sprintf("invalid capacity: %d", capacity)) 94 | } 95 | return &scheduledTaskCacheImpl{ 96 | baseTaskCacheImpl: baseTaskCacheImpl{ 97 | logger: logger, 98 | metricsHandler: metricsHandler, 99 | capacity: capacity, 100 | }, 101 | } 102 | } 103 | 104 | func (c *scheduledTaskCacheImpl) List(inclusiveMinTaskID int64, inclusiveMinVisibilityTS, exclusiveMaxVisibilityTS time.Time, batchSize int32) (ok bool, ts SortedTasks, err error) { 105 | if inclusiveMinVisibilityTS.Compare(exclusiveMaxVisibilityTS) == 1 { 106 | return false, nil, newInvalidRequestError(fmt.Sprintf( 107 | "invalid requested range: inclusive min (%s) > exclusive max (%s)", 108 | inclusiveMinVisibilityTS, exclusiveMaxVisibilityTS)) 109 | } else if inclusiveMinVisibilityTS.Compare(exclusiveMaxVisibilityTS) == 0 { 110 | // for some reason persistence gets this kind of requests: inclusive min == exclusive max, always empty 111 | c.metricsHandler.Counter(emptyHitCounter.Name()).Record(1) 112 | return true, nil, nil 113 | } 114 | 115 | requestedRange 
:= newKeyRange( 116 | tasks.NewKey(inclusiveMinVisibilityTS, inclusiveMinTaskID), 117 | tasks.NewKey(exclusiveMaxVisibilityTS, 0).Prev(), 118 | ) 119 | return c.list(requestedRange, batchSize) 120 | } 121 | 122 | func (c *scheduledTaskCacheImpl) SuggestedExpansion() *ScheduledCacheSuggestion { 123 | c.mu.Lock() 124 | defer c.mu.Unlock() 125 | 126 | if c.validRange == nil { 127 | return &ScheduledCacheSuggestion{ 128 | InclusiveMinKey: tasks.NewKey(time.Now(), 0), 129 | BatchSize: defaultExpansionBatch, 130 | } 131 | } 132 | 133 | capLeft := c.capacity - c.totalTasksLocked() 134 | batch := min(defaultExpansionBatch, capLeft) 135 | if batch < minExpansionBatch { 136 | return nil 137 | } 138 | if time.Until(c.validRange.inclusiveTo.FireTime).Seconds() > cacheAheadPeriod.Seconds() { 139 | return nil 140 | } 141 | randomDuration := time.Minute * time.Duration(5+rand.Intn(15)) // between 5 and 20 minutes 142 | exclusiveMaxVisibilityTS := c.validRange.inclusiveTo.FireTime.Add(randomDuration) 143 | return &ScheduledCacheSuggestion{ 144 | InclusiveMinKey: c.validRange.inclusiveTo.Next(), 145 | ExclusiveMaxVisibilityTS: exclusiveMaxVisibilityTS, 146 | BatchSize: int32(batch), 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /persistence/pkg/cache/tasks.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "slices" 5 | "sort" 6 | 7 | p "go.temporal.io/server/common/persistence" 8 | "go.temporal.io/server/service/history/tasks" 9 | ) 10 | 11 | type Tasks []p.InternalHistoryTask 12 | 13 | // sort sorts the tasks by key 14 | func (ts Tasks) sort() SortedTasks { 15 | slices.SortFunc(ts, func(a, b p.InternalHistoryTask) int { 16 | return a.Key.CompareTo(b.Key) 17 | }) 18 | return SortedTasks(ts) 19 | } 20 | 21 | // getRange returns the range of keys in the tasks. 22 | // Must not be called on an empty list. 
23 | func (ts Tasks) getRange() keyRange { 24 | minKey := tasks.MaximumKey 25 | maxKey := tasks.MinimumKey 26 | for _, t := range ts { 27 | minKey = tasks.MinKey(minKey, t.Key) 28 | maxKey = tasks.MaxKey(maxKey, t.Key) 29 | } 30 | return newKeyRange(minKey, maxKey) 31 | } 32 | 33 | func (ts Tasks) String() string { 34 | return formatTasks(ts) 35 | } 36 | 37 | type SortedTasks []p.InternalHistoryTask 38 | 39 | // splitBy splits the tasks into two parts: the first part contains tasks with key <= key, the second part contains the rest 40 | func (ts SortedTasks) splitBy(key tasks.Key) (SortedTasks, SortedTasks) { 41 | i := sort.Search(len(ts), func(i int) bool { 42 | return ts[i].Key.CompareTo(key) > 0 43 | }) 44 | return ts[:i], ts[i:] 45 | } 46 | 47 | // query returns a batch of tasks that are in the range r 48 | func (ts SortedTasks) query(r keyRange, batchSize int) SortedTasks { 49 | s := sort.Search(len(ts), func(i int) bool { 50 | return ts[i].Key.CompareTo(r.inclusiveFrom) >= 0 51 | }) 52 | e := sort.Search(len(ts), func(i int) bool { 53 | return ts[i].Key.CompareTo(r.inclusiveTo) > 0 54 | }) 55 | if e-s > batchSize { 56 | e = s + batchSize 57 | } 58 | return ts[s:e] 59 | } 60 | 61 | // merge merges two sorted task lists into one 62 | func (ts SortedTasks) merge(ts2 SortedTasks) SortedTasks { 63 | var res SortedTasks 64 | if len(ts2) == 0 { 65 | res = ts 66 | } else if len(ts) == 0 { 67 | res = ts2 68 | } else { 69 | n := len(ts) + len(ts2) 70 | res = make(SortedTasks, 0, n) 71 | i := 0 72 | j := 0 73 | for i < len(ts) && j < len(ts2) { 74 | task1 := ts[i] 75 | task2 := ts2[j] 76 | if task2.Key.CompareTo(task1.Key) <= 0 { 77 | res = append(res, task2) 78 | j++ 79 | } else { 80 | res = append(res, task1) 81 | i++ 82 | } 83 | } 84 | res = append(res, ts[i:]...) 85 | res = append(res, ts2[j:]...) 
// ListClusterMetadata pages through cluster_metadata_info ordered by cluster
// name. The page token carries the last cluster name returned; rows with a
// strictly greater name form the next page.
func (m *ClusterMetadataStore) ListClusterMetadata(
	ctx context.Context,
	request *p.InternalListClusterMetadataRequest,
) (resp *p.InternalListClusterMetadataResponse, err error) {
	defer func() {
		if err != nil {
			err = conn.ConvertToTemporalError("ListClusterMetadata", err)
		}
	}()

	var pageToken tokens.ClusterMetadataPageToken
	if err = pageToken.Deserialize(request.NextPageToken); err != nil {
		return nil, err
	}

	params := table.NewQueryParameters(
		table.ValueParam("$cluster_name", types.UTF8Value(pageToken.ClusterName)),
		table.ValueParam("$page_size", types.Int32Value(int32(request.PageSize))),
	)
	template := m.client.AddQueryPrefix(`
DECLARE $cluster_name AS utf8;
DECLARE $page_size AS int32;

SELECT data, data_encoding, version, cluster_name
FROM cluster_metadata_info
WHERE cluster_name > $cluster_name
ORDER BY cluster_name
LIMIT $page_size;
`)

	res, err := m.client.Do(ctx, template, conn.OnlineReadOnlyTxControl(), params, table.WithIdempotent())
	if err != nil {
		return nil, err
	}
	defer func() {
		// surface the Close error only if nothing else already failed
		err2 := res.Close()
		if err == nil {
			err = err2
		}
	}()

	if err = res.NextResultSetErr(ctx); err != nil {
		return
	}

	resp = &p.InternalListClusterMetadataResponse{}
	var nextPageToken tokens.ClusterMetadataPageToken
	for res.NextRow() {
		blob, clusterName, version, err := m.scanClusterMetadata(res)
		if err != nil {
			return nil, err
		}
		// remember the last name seen; it becomes the next page token
		nextPageToken.ClusterName = clusterName
		resp.ClusterMetadata = append(resp.ClusterMetadata, &p.InternalGetClusterMetadataResponse{
			ClusterMetadata: blob,
			Version:         version,
		})
	}
	// a full page implies there may be more rows; hand back a continuation token
	if len(resp.ClusterMetadata) >= request.PageSize {
		if resp.NextPageToken, err = nextPageToken.Serialize(); err != nil {
			return nil, err
		}
	}
	return resp, nil
}
// scanClusterMetadata reads one row from the cursor and returns the metadata
// blob, the cluster name, and the row version. The data_encoding column is
// scanned either as an int (mapped back to the enum's string name) or as a
// raw string, depending on the client's schema configuration.
func (m *ClusterMetadataStore) scanClusterMetadata(res result.Result) (*commonpb.DataBlob, string, int64, error) {
	var encoding string
	var encodingType conn.EncodingTypeRaw
	var encodingScanner named.Value
	if m.client.UseIntForEncoding() {
		encodingScanner = named.OptionalWithDefault("data_encoding", &encodingType)
	} else {
		encodingScanner = named.OptionalWithDefault("data_encoding", &encoding)
	}

	var clusterName string
	var data []byte
	var version int64
	if err := res.ScanNamed(
		named.OptionalWithDefault("cluster_name", &clusterName),
		named.OptionalWithDefault("data", &data),
		encodingScanner,
		named.OptionalWithDefault("version", &version),
	); err != nil {
		return nil, "", 0, fmt.Errorf("failed to scan cluster metadata: %w", err)
	}

	// translate the int encoding into its string name for NewDataBlob
	if m.client.UseIntForEncoding() {
		encoding = enumspb.EncodingType(encodingType).String()
	}
	return p.NewDataBlob(data, encoding), clusterName, version, nil
}
cluster_name = $cluster_name; 228 | `) 229 | if err = m.client.Write(ctx, template, table.NewQueryParameters( 230 | table.ValueParam("$cluster_name", types.UTF8Value(request.ClusterName)), 231 | table.ValueParam("$data", types.BytesValue(request.ClusterMetadata.Data)), 232 | table.ValueParam("$encoding", m.client.EncodingTypeValue(request.ClusterMetadata.EncodingType)), 233 | table.ValueParam("$version", types.Int64Value(request.Version+1)), 234 | table.ValueParam("$prev_version", types.Int64Value(request.Version)), 235 | )); err != nil { 236 | return false, err 237 | } 238 | } 239 | return true, nil 240 | } 241 | 242 | func (m *ClusterMetadataStore) DeleteClusterMetadata( 243 | ctx context.Context, 244 | request *p.InternalDeleteClusterMetadataRequest, 245 | ) error { 246 | template := m.client.AddQueryPrefix(` 247 | DECLARE $cluster_name AS utf8; 248 | 249 | DELETE FROM cluster_metadata_info 250 | WHERE cluster_name = $cluster_name; 251 | `) 252 | err := m.client.Write(ctx, template, table.NewQueryParameters( 253 | table.ValueParam("$cluster_name", types.UTF8Value(request.ClusterName)), 254 | )) 255 | return conn.ConvertToTemporalError("DeleteClusterMetadata", err) 256 | } 257 | 258 | func (m *ClusterMetadataStore) GetClusterMembers( 259 | ctx context.Context, 260 | request *p.GetClusterMembersRequest, 261 | ) (resp *p.GetClusterMembersResponse, err error) { 262 | defer func() { 263 | if err != nil { 264 | err = conn.ConvertToTemporalError("GetClusterMembers", err) 265 | } 266 | }() 267 | 268 | var pageToken tokens.ClusterMembersPageToken 269 | if len(request.NextPageToken) > 0 { 270 | if err = pageToken.Deserialize(request.NextPageToken); err != nil { 271 | return nil, serviceerror.NewInternal("page token is corrupted") 272 | } 273 | } 274 | 275 | params := table.NewQueryParameters( 276 | table.ValueParam("$expire_at_gt", types.TimestampValueFromTime(conn.ToYDBDateTime(time.Now()))), 277 | ) 278 | var declares []string 279 | var suffixes []string 280 | if 
request.HostIDEquals != nil { 281 | declares = append(declares, m.client.HostIDDecl()) 282 | params.Add(table.ValueParam("$host_id", m.client.HostIDValueFromUUID(request.HostIDEquals))) 283 | suffixes = append(suffixes, "AND host_id = $host_id") 284 | } 285 | if request.RPCAddressEquals != nil { 286 | declares = append(declares, "DECLARE $rpc_address AS utf8;") 287 | params.Add(table.ValueParam("$rpc_address", types.UTF8Value(request.RPCAddressEquals.String()))) 288 | suffixes = append(suffixes, "AND rpc_address = $rpc_address") 289 | } 290 | if request.RoleEquals != p.All { 291 | declares = append(declares, "DECLARE $role AS int32;") 292 | params.Add(table.ValueParam("$role", types.Int32Value(int32(request.RoleEquals)))) 293 | suffixes = append(suffixes, "AND role = $role") 294 | } 295 | if !request.SessionStartedAfter.IsZero() { 296 | declares = append(declares, "DECLARE $session_started_gt AS Timestamp;") 297 | params.Add(table.ValueParam("$session_started_gt", types.TimestampValueFromTime(conn.ToYDBDateTime(request.SessionStartedAfter)))) 298 | suffixes = append(suffixes, "AND session_start > $session_started_gt") 299 | } 300 | if request.LastHeartbeatWithin > 0 { 301 | declares = append(declares, "DECLARE $last_heartbeat_gt AS Timestamp;") 302 | params.Add(table.ValueParam("$last_heartbeat_gt", types.TimestampValueFromTime(conn.ToYDBDateTime(time.Now().Add(-request.LastHeartbeatWithin))))) 303 | suffixes = append(suffixes, "AND last_heartbeat > $last_heartbeat_gt") 304 | } 305 | if len(pageToken.LastSeenHostID) > 0 && request.HostIDEquals == nil { 306 | declares = append(declares, fmt.Sprintf("DECLARE $host_id_gt AS %s;", m.client.HostIDType())) 307 | params.Add(table.ValueParam("$host_id_gt", m.client.HostIDValueFromUUID(pageToken.LastSeenHostID))) 308 | suffixes = append(suffixes, "AND host_id > $host_id_gt") 309 | } 310 | if request.PageSize > 0 { 311 | declares = append(declares, "DECLARE $page_size AS int32;") 312 | 
params.Add(table.ValueParam("$page_size", types.Int32Value(int32(request.PageSize)))) 313 | } 314 | 315 | template := m.client.AddQueryPrefix(` 316 | DECLARE $expire_at_gt AS Timestamp; 317 | ` + strings.Join(declares, "\n") + ` 318 | SELECT host_id, rpc_address, rpc_port, role, session_start, last_heartbeat, expire_at 319 | FROM cluster_membership 320 | WHERE expire_at > $expire_at_gt ` + strings.Join(suffixes, " ") + ` 321 | ORDER BY host_id 322 | `) 323 | if request.PageSize > 0 { 324 | template += " LIMIT $page_size;" 325 | } 326 | res, err := m.client.Do(ctx, template, conn.OnlineReadOnlyTxControl(), params, table.WithIdempotent()) 327 | if err != nil { 328 | return nil, err 329 | } 330 | defer func() { 331 | err2 := res.Close() 332 | if err == nil { 333 | err = err2 334 | } 335 | }() 336 | 337 | if err = res.NextResultSetErr(ctx); err != nil { 338 | return nil, err 339 | } 340 | 341 | members := make([]*p.ClusterMember, 0, request.PageSize) 342 | for res.NextRow() { 343 | member, err := m.scanClusterMember(res) 344 | if err != nil { 345 | return nil, err 346 | } 347 | members = append(members, member) 348 | } 349 | 350 | resp = &p.GetClusterMembersResponse{ 351 | ActiveMembers: members, 352 | } 353 | if request.PageSize > 0 && len(members) == request.PageSize { 354 | nextPageToken := tokens.ClusterMembersPageToken{ 355 | LastSeenHostID: primitives.UUID(members[len(members)-1].HostID), 356 | } 357 | resp.NextPageToken, err = nextPageToken.Serialize() 358 | if err != nil { 359 | return nil, fmt.Errorf("failed to create next page token") 360 | } 361 | } 362 | return resp, nil 363 | } 364 | 365 | func (m *ClusterMetadataStore) UpsertClusterMembership( 366 | ctx context.Context, 367 | request *p.UpsertClusterMembershipRequest, 368 | ) error { 369 | template := m.client.AddQueryPrefix(m.client.HostIDDecl() + ` 370 | DECLARE $rpc_address AS utf8; 371 | DECLARE $rpc_port AS int32; 372 | DECLARE $role AS int32; 373 | DECLARE $session_start AS Timestamp; 374 | DECLARE 
$last_heartbeat AS Timestamp; 375 | DECLARE $expire_at AS Timestamp; 376 | 377 | UPSERT INTO cluster_membership (host_id, rpc_address, rpc_port, role, session_start, last_heartbeat, expire_at) 378 | VALUES ($host_id, $rpc_address, $rpc_port, $role, $session_start, $last_heartbeat, $expire_at); 379 | `) 380 | err := m.client.Write2(ctx, template, func() *table.QueryParameters { 381 | now := time.Now() 382 | return table.NewQueryParameters( 383 | table.ValueParam("$host_id", m.client.HostIDValueFromUUID(request.HostID)), 384 | table.ValueParam("$rpc_address", types.UTF8Value(request.RPCAddress.String())), 385 | table.ValueParam("$rpc_port", types.Int32Value(int32(request.RPCPort))), 386 | table.ValueParam("$role", types.Int32Value(int32(request.Role))), 387 | table.ValueParam("$session_start", types.TimestampValueFromTime(conn.ToYDBDateTime(request.SessionStart))), 388 | table.ValueParam("$last_heartbeat", types.TimestampValueFromTime(conn.ToYDBDateTime(now))), 389 | table.ValueParam("$expire_at", types.TimestampValueFromTime(conn.ToYDBDateTime(now.Add(request.RecordExpiry)))), 390 | ) 391 | }) 392 | return conn.ConvertToTemporalError("UpsertClusterMembership", err) 393 | } 394 | 395 | func (m *ClusterMetadataStore) PruneClusterMembership( 396 | ctx context.Context, 397 | request *p.PruneClusterMembershipRequest, 398 | ) error { 399 | template := m.client.AddQueryPrefix(` 400 | DECLARE $expire_at_lt AS Timestamp; 401 | 402 | DELETE FROM cluster_membership 403 | WHERE expire_at < $expire_at_lt; 404 | `) 405 | err := m.client.Write2(ctx, template, func() *table.QueryParameters { 406 | return table.NewQueryParameters( 407 | table.ValueParam("$expire_at_lt", types.TimestampValueFromTime(conn.ToYDBDateTime(time.Now()))), 408 | ) 409 | }) 410 | return conn.ConvertToTemporalError("PruneClusterMembership", err) 411 | } 412 | 413 | func (m *ClusterMetadataStore) GetName() string { 414 | return ydbPersistenceName 415 | } 416 | 417 | func (m *ClusterMetadataStore) Close() { 418 
| } 419 | 420 | func (m *ClusterMetadataStore) scanClusterMember(res result.Result) (*p.ClusterMember, error) { 421 | var hostID, rpcAddress string 422 | var hostIDBytes []byte 423 | var rpcPort, role int32 424 | var sessionStart, lastHeartbeat, expireAt time.Time 425 | 426 | var hostIDScanner named.Value 427 | if m.client.UseBytesForHostIDs() { 428 | hostIDScanner = named.OptionalWithDefault("host_id", &hostIDBytes) 429 | } else { 430 | hostIDScanner = named.OptionalWithDefault("host_id", &hostID) 431 | } 432 | 433 | if err := res.ScanNamed( 434 | hostIDScanner, 435 | named.OptionalWithDefault("rpc_address", &rpcAddress), 436 | named.OptionalWithDefault("rpc_port", &rpcPort), 437 | named.OptionalWithDefault("role", &role), 438 | named.OptionalWithDefault("session_start", &sessionStart), 439 | named.OptionalWithDefault("last_heartbeat", &lastHeartbeat), 440 | named.OptionalWithDefault("expire_at", &expireAt), 441 | ); err != nil { 442 | return nil, fmt.Errorf("failed to scan cluster member: %w", err) 443 | } 444 | 445 | if m.client.UseBytesForHostIDs() { 446 | hostID = uuid.UUID(hostIDBytes).String() 447 | } 448 | 449 | return &p.ClusterMember{ 450 | HostID: uuid.Parse(hostID), 451 | RPCAddress: net.ParseIP(rpcAddress), 452 | RPCPort: uint16(rpcPort), 453 | Role: p.ServiceType(role), 454 | SessionStart: conn.FromYDBDateTime(sessionStart), 455 | LastHeartbeat: conn.FromYDBDateTime(lastHeartbeat), 456 | RecordExpiry: conn.FromYDBDateTime(expireAt), 457 | }, nil 458 | } 459 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | type Config struct { 8 | Endpoint string `yaml:"endpoint" mapstructure:"endpoint"` 9 | Database string `yaml:"database" mapstructure:"database"` 10 | Folder string `yaml:"folder" mapstructure:"folder"` 11 | Token string `yaml:"token" 
mapstructure:"token"` 12 | UseSSL bool `yaml:"use_ssl" mapstructure:"use_ssl"` 13 | SessionPoolSizeLimit int `yaml:"pool_size_limit" mapstructure:"pool_size_limit"` 14 | PreferLocalDC bool `yaml:"prefer_local_dc" mapstructure:"prefer_local_dc"` 15 | UseOldTypes bool 16 | } 17 | 18 | func (c *Config) Validate() error { 19 | if c.Endpoint == "" { 20 | return errors.New("endpoint is required") 21 | } 22 | if c.Database == "" { 23 | return errors.New("database is required") 24 | } 25 | if c.Folder == "" { 26 | return errors.New("folder is required") 27 | } 28 | return nil 29 | } 30 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/conn/client.go: -------------------------------------------------------------------------------- 1 | package conn 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "time" 8 | 9 | ydbenv "github.com/ydb-platform/ydb-go-sdk-auth-environ" 10 | ydbmetrics "github.com/ydb-platform/ydb-go-sdk-metrics" 11 | "github.com/ydb-platform/ydb-go-sdk/v3" 12 | "github.com/ydb-platform/ydb-go-sdk/v3/balancers" 13 | "github.com/ydb-platform/ydb-go-sdk/v3/query" 14 | "github.com/ydb-platform/ydb-go-sdk/v3/sugar" 15 | "github.com/ydb-platform/ydb-go-sdk/v3/table" 16 | "github.com/ydb-platform/ydb-go-sdk/v3/table/options" 17 | "github.com/ydb-platform/ydb-go-sdk/v3/table/result" 18 | "github.com/ydb-platform/ydb-go-sdk/v3/trace" 19 | tlog "go.temporal.io/server/common/log" 20 | "go.temporal.io/server/common/metrics" 21 | "go.uber.org/atomic" 22 | "golang.org/x/xerrors" 23 | 24 | "github.com/yandex/temporal-over-ydb/persistence/pkg/ydb/config" 25 | connlog "github.com/yandex/temporal-over-ydb/persistence/pkg/ydb/conn/log" 26 | connmetrics "github.com/yandex/temporal-over-ydb/persistence/pkg/ydb/conn/metrics" 27 | ) 28 | 29 | var ( 30 | ydbCredentialEnvs = [...]string{ 31 | "YDB_SERVICE_ACCOUNT_KEY_CREDENTIALS", 32 | "YDB_SERVICE_ACCOUNT_KEY_FILE_CREDENTIALS", 33 | "YDB_ANONYMOUS_CREDENTIALS", 34 | 
"YDB_METADATA_CREDENTIALS", 35 | "YDB_ACCESS_TOKEN_CREDENTIALS", 36 | } 37 | ) 38 | 39 | type Client struct { 40 | DB *ydb.Driver 41 | Database string 42 | Folder string 43 | logger tlog.Logger 44 | logQueries atomic.Bool 45 | useOldTypes bool 46 | } 47 | 48 | func SnapshotReadOnlyTxControl(opts ...table.TxOnlineReadOnlyOption) *table.TransactionControl { 49 | return table.TxControl( 50 | table.BeginTx(table.WithOnlineReadOnly(opts...)), 51 | table.CommitTx(), // open transactions not supported for OnlineReadOnly 52 | ) 53 | } 54 | 55 | func OnlineReadOnlyTxControl(opts ...table.TxOnlineReadOnlyOption) *table.TransactionControl { 56 | return table.OnlineReadOnlyTxControl(opts...) 57 | } 58 | 59 | const detailsFull = trace.DriverRepeaterEvents | 60 | trace.DriverConnEvents | 61 | trace.DriverBalancerEvents | 62 | trace.TablePoolEvents | 63 | trace.RetryEvents | 64 | trace.DiscoveryEvents | 65 | trace.SchemeEvents 66 | 67 | const detailsTiny = trace.DriverConnEvents | 68 | trace.TablePoolEvents 69 | 70 | const detailsNone = trace.Details(0) 71 | 72 | const details = detailsNone 73 | const _ = detailsNone 74 | const _ = detailsTiny 75 | const _ = detailsFull 76 | 77 | func setupLogger(l tlog.Logger) []ydb.Option { 78 | opts := make([]ydb.Option, 0) 79 | opts = append(opts, connlog.WithTraces(l, detailsNone)) 80 | return opts 81 | } 82 | 83 | func setupMetrics(mh metrics.Handler) []ydb.Option { 84 | opts := make([]ydb.Option, 0) 85 | if mh == nil { 86 | return opts 87 | } 88 | mc := connmetrics.MakeConfig( 89 | mh, 90 | connmetrics.WithNamespace("xydb"), 91 | connmetrics.WithDetails(detailsFull), 92 | connmetrics.WithSeparator("_"), 93 | ) 94 | opts = append(opts, 95 | ydbmetrics.WithTraces(mc), 96 | ) 97 | return opts 98 | } 99 | 100 | func haveCredentialsInEnv() bool { 101 | for _, env := range ydbCredentialEnvs { 102 | if os.Getenv(env) != "" { 103 | return true 104 | } 105 | } 106 | return false 107 | } 108 | 109 | func NewClient(ctx context.Context, cfg 
config.Config, logger tlog.Logger, mh metrics.Handler, opts ...ydb.Option) ( 110 | *Client, 111 | error, 112 | ) { 113 | if cfg.Token != "" { 114 | opts = append(opts, ydb.WithAccessTokenCredentials(cfg.Token)) 115 | } else if haveCredentialsInEnv() { 116 | opts = append(opts, ydbenv.WithEnvironCredentials(ctx)) 117 | } else { 118 | logger.Info("no credentials provided for ydb client, relying on opts...") 119 | } 120 | 121 | opts = append(opts, setupLogger(logger)...) 122 | // opts = append(opts, setupMetrics(mh)...) 123 | opts = append(opts, ydb.WithSessionPoolIdleThreshold(time.Second*10)) 124 | 125 | balancerConfig := balancers.RandomChoice() 126 | if cfg.PreferLocalDC { 127 | balancerConfig = balancers.PreferNearestDCWithFallBack(balancerConfig) 128 | } 129 | opts = append(opts, ydb.WithBalancer(balancerConfig)) 130 | 131 | opts = append(opts, ydb.WithDialTimeout(10*time.Second)) 132 | sessionPoolSizeLimit := cfg.SessionPoolSizeLimit 133 | if sessionPoolSizeLimit > 0 { 134 | opts = append(opts, ydb.WithSessionPoolSizeLimit(sessionPoolSizeLimit)) 135 | } 136 | 137 | db, err := ydb.Open( 138 | ctx, 139 | sugar.DSN(cfg.Endpoint, cfg.Database, sugar.WithSecure(cfg.UseSSL)), 140 | opts..., 141 | ) 142 | if err != nil { 143 | return nil, xerrors.Errorf("connect error: %w", err) 144 | } 145 | return &Client{ 146 | DB: db, 147 | Database: cfg.Database, 148 | Folder: cfg.Folder, 149 | logger: logger, 150 | useOldTypes: cfg.UseOldTypes, 151 | }, nil 152 | } 153 | 154 | func (client *Client) Close(ctx context.Context) error { 155 | closeCtx, cancel := context.WithTimeout(ctx, time.Second*30) 156 | defer cancel() 157 | ts := time.Now() 158 | defer func() { 159 | client.logger.Info(fmt.Sprintf("ydb close duration: %s", time.Since(ts))) 160 | }() 161 | return client.DB.Close(closeCtx) 162 | } 163 | 164 | func (client *Client) GetPrefix() string { 165 | return fmt.Sprintf("%s/%s", client.Database, client.Folder) 166 | } 167 | 168 | func (client *Client) queryPrefix() string { 
169 | return fmt.Sprintf("--!syntax_v1\nPRAGMA TablePathPrefix(\"%s\");\n", client.GetPrefix()) 170 | } 171 | 172 | func (client *Client) AddQueryPrefix(query string) string { 173 | return client.queryPrefix() + query 174 | } 175 | 176 | func (client *Client) Write( 177 | ctx context.Context, 178 | query string, 179 | params *table.QueryParameters, 180 | opts ...table.Option, 181 | ) error { 182 | res, errDo := client.Do(ctx, query, table.SerializableReadWriteTxControl(table.CommitTx()), params, opts...) 183 | if errDo != nil { 184 | return errDo 185 | } 186 | return res.Close() 187 | } 188 | 189 | func (client *Client) Write2( 190 | ctx context.Context, 191 | query string, 192 | getQueryParameters func() *table.QueryParameters, 193 | opts ...table.Option, 194 | ) error { 195 | res, errDo := client.Do(ctx, query, table.SerializableReadWriteTxControl(table.CommitTx()), getQueryParameters(), opts...) 196 | if errDo != nil { 197 | return errDo 198 | } 199 | return res.Close() 200 | } 201 | 202 | func (client *Client) Query( 203 | ctx context.Context, 204 | sql string, 205 | params *table.QueryParameters, 206 | ) (rs query.ClosableResultSet, err error) { 207 | return client.DB.Query().QueryResultSet(ctx, 208 | sql, 209 | query.WithParameters(params), 210 | query.WithTxControl(query.SnapshotReadOnlyTxControl()), 211 | query.WithIdempotent(), 212 | ) 213 | } 214 | 215 | func (client *Client) Do( 216 | ctx context.Context, 217 | query string, 218 | tx *table.TransactionControl, 219 | params *table.QueryParameters, 220 | opts ...table.Option, 221 | ) (res result.Result, err error) { 222 | err = client.DB.Table().Do( 223 | ctx, 224 | func(c context.Context, s table.Session) (err error) { 225 | _, res, err = s.Execute(c, tx, query, params) 226 | return 227 | }, 228 | opts..., 229 | ) 230 | if err != nil { 231 | return nil, err 232 | } 233 | return res, nil 234 | } 235 | 236 | func (client *Client) Do2( 237 | ctx context.Context, 238 | query string, 239 | tx 
*table.TransactionControl, 240 | getQueryParameters func() *table.QueryParameters, 241 | opts ...table.Option, 242 | ) (res result.Result, err error) { 243 | err = client.DB.Table().Do( 244 | ctx, 245 | func(c context.Context, s table.Session) (err error) { 246 | _, res, err = s.Execute(c, tx, query, getQueryParameters()) 247 | return 248 | }, 249 | opts..., 250 | ) 251 | if err != nil { 252 | return nil, err 253 | } 254 | return res, nil 255 | } 256 | 257 | func (client *Client) DoSchema( 258 | ctx context.Context, 259 | query string, 260 | params ...options.ExecuteSchemeQueryOption, 261 | ) (err error) { 262 | err = client.DB.Table().Do( 263 | ctx, func(ctx context.Context, s table.Session) (err error) { 264 | err = s.ExecuteSchemeQuery( 265 | ctx, query, params..., 266 | ) 267 | return 268 | }, table.WithIdempotent(), 269 | ) 270 | 271 | if err != nil { 272 | return xerrors.Errorf("failed ydb request: %w", err) 273 | } 274 | return nil 275 | } 276 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/conn/const.go: -------------------------------------------------------------------------------- 1 | package conn 2 | 3 | import ( 4 | "github.com/ydb-platform/ydb-go-sdk/v3/table/types" 5 | "go.temporal.io/api/enums/v1" 6 | "go.temporal.io/server/common/primitives" 7 | ) 8 | 9 | type EncodingTypeRaw = int32 10 | 11 | func (client *Client) EncodingType() types.Type { 12 | if client.useOldTypes { 13 | return types.TypeUTF8 14 | } 15 | return types.TypeInt16 16 | } 17 | 18 | func (client *Client) NamespaceIDType() types.Type { 19 | if client.useOldTypes { 20 | return types.TypeUTF8 21 | } 22 | return types.TypeBytes 23 | } 24 | 25 | func (client *Client) RunIDType() types.Type { 26 | if client.useOldTypes { 27 | return types.TypeUTF8 28 | } 29 | return types.TypeBytes 30 | } 31 | 32 | func (client *Client) HistoryIDType() types.Type { 33 | if client.useOldTypes { 34 | return types.TypeUTF8 35 | } 36 | return 
types.TypeBytes 37 | } 38 | 39 | func (client *Client) HostIDType() types.Type { 40 | if client.useOldTypes { 41 | return types.TypeUTF8 42 | } 43 | return types.TypeBytes 44 | } 45 | 46 | func (client *Client) NamspaceIDDecl() string { 47 | if client.useOldTypes { 48 | return "DECLARE $namespace_id AS Utf8;\n" 49 | } 50 | return "DECLARE $namespace_id AS Bytes;\n" 51 | } 52 | 53 | func (client *Client) RunIDDecl() string { 54 | if client.useOldTypes { 55 | return "DECLARE $run_id AS Utf8;\n" 56 | } 57 | return "DECLARE $run_id AS Bytes;\n" 58 | } 59 | 60 | func (client *Client) CurrentRunIDDecl() string { 61 | if client.useOldTypes { 62 | return "DECLARE $current_run_id AS Utf8;\n" 63 | } 64 | return "DECLARE $current_run_id AS Bytes;\n" 65 | } 66 | 67 | func (client *Client) HostIDDecl() string { 68 | if client.useOldTypes { 69 | return "DECLARE $host_id AS Utf8;\n" 70 | } 71 | return "DECLARE $host_id AS Bytes;\n" 72 | } 73 | 74 | func (client *Client) NamespaceIDValue(v string) types.Value { 75 | if client.useOldTypes { 76 | return types.UTF8Value(primitives.MustValidateUUID(v)) 77 | } 78 | return types.BytesValue(primitives.MustParseUUID(v)) 79 | } 80 | 81 | func (client *Client) NamespaceIDValueFromUUID(uuid primitives.UUID) types.Value { 82 | if client.useOldTypes { 83 | return types.UTF8Value(uuid.String()) 84 | } 85 | return types.BytesValue(uuid) 86 | } 87 | 88 | func (client *Client) EmptyNamespaceIDValue() types.Value { 89 | if client.useOldTypes { 90 | return types.UTF8Value("") 91 | } 92 | return types.BytesValue(primitives.UUID{}) 93 | } 94 | 95 | func (client *Client) RunIDValue(v string) types.Value { 96 | if client.useOldTypes { 97 | return types.UTF8Value(primitives.MustValidateUUID(v)) 98 | } 99 | return types.BytesValue(primitives.MustParseUUID(v)) 100 | } 101 | 102 | func (client *Client) RunIDValueFromUUID(uuid primitives.UUID) types.Value { 103 | if client.useOldTypes { 104 | return types.UTF8Value(uuid.String()) 105 | } 106 | return 
types.BytesValue(uuid) 107 | } 108 | 109 | func (client *Client) EmptyRunIDValue() types.Value { 110 | if client.useOldTypes { 111 | return types.UTF8Value("") 112 | } 113 | return types.BytesValue(primitives.UUID{}) 114 | } 115 | 116 | func (client *Client) HistoryIDValue(v string) types.Value { 117 | if client.useOldTypes { 118 | return types.UTF8Value(primitives.MustValidateUUID(v)) 119 | } 120 | return types.BytesValue(primitives.MustParseUUID(v)) 121 | } 122 | 123 | func (client *Client) HostIDValueFromUUID(uuid []byte) types.Value { 124 | if client.useOldTypes { 125 | return types.UTF8Value(primitives.UUIDString(uuid)) 126 | } 127 | return types.BytesValue(uuid) 128 | } 129 | 130 | func (client *Client) EncodingTypeValue(v enums.EncodingType) types.Value { 131 | if client.useOldTypes { 132 | return types.UTF8Value(v.String()) 133 | } 134 | return types.Int16Value(int16(v)) 135 | } 136 | 137 | func (client *Client) NewEncodingTypeValue(v enums.EncodingType) types.Value { 138 | return types.Int16Value(int16(v)) 139 | } 140 | 141 | func (client *Client) UseIntForEncoding() bool { 142 | return !client.useOldTypes 143 | } 144 | 145 | func (client *Client) UseBytesForNamespaceIDs() bool { 146 | return !client.useOldTypes 147 | } 148 | 149 | func (client *Client) UseBytesForRunIDs() bool { 150 | return !client.useOldTypes 151 | } 152 | 153 | func (client *Client) UseBytesForHistoryIDs() bool { 154 | return !client.useOldTypes 155 | } 156 | 157 | func (client *Client) UseBytesForHostIDs() bool { 158 | return !client.useOldTypes 159 | } 160 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/conn/errors.go: -------------------------------------------------------------------------------- 1 | package conn 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "strings" 8 | 9 | "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" 10 | "github.com/ydb-platform/ydb-go-sdk/v3" 11 | enumspb "go.temporal.io/api/enums/v1" 
12 | "go.temporal.io/api/serviceerror" 13 | "go.temporal.io/server/common/persistence" 14 | ) 15 | 16 | func convertError(operation string, err error, details ...string) error { 17 | msg := fmt.Sprintf("operation %v encountered %v", operation, err.Error()) 18 | if len(details) > 0 { 19 | msg += " (" + strings.Join(details, ", ") + ")" 20 | } 21 | if err == context.DeadlineExceeded || ydb.IsTimeoutError(err) { 22 | return &persistence.TimeoutError{Msg: msg} 23 | } 24 | if ydb.IsOperationErrorNotFoundError(err) { 25 | return serviceerror.NewNotFound(msg) 26 | } 27 | if ydb.IsOperationError(err, Ydb.StatusIds_PRECONDITION_FAILED) { 28 | return &persistence.ConditionFailedError{Msg: msg} 29 | } 30 | if ydb.IsOperationErrorOverloaded(err) { 31 | return serviceerror.NewResourceExhausted(enumspb.RESOURCE_EXHAUSTED_CAUSE_SYSTEM_OVERLOADED, msg) 32 | } 33 | return serviceerror.NewUnavailable(msg) 34 | } 35 | 36 | func IsPreconditionFailedAndContains(err error, substr string) bool { 37 | rv := false 38 | if ydb.IsOperationError(err, Ydb.StatusIds_PRECONDITION_FAILED) { 39 | ydb.IterateByIssues(err, func(message string, code Ydb.StatusIds_StatusCode, severity uint32) { 40 | if strings.Contains(message, substr) { 41 | rv = true 42 | } 43 | }) 44 | } 45 | if !rv && ydb.IsOperationError(err, Ydb.StatusIds_GENERIC_ERROR) { 46 | ydb.IterateByIssues(err, func(message string, code Ydb.StatusIds_StatusCode, severity uint32) { 47 | if strings.Contains(message, substr) { 48 | rv = true 49 | } 50 | }) 51 | } 52 | return rv 53 | } 54 | 55 | func IsIntermediateDataMaterializationExceededSizeLimitError(err error) bool { 56 | return IsPreconditionFailedAndContains(err, "Intermediate data materialization exceeded size limit") 57 | } 58 | 59 | type RootCauseError struct { 60 | new func(message string) error 61 | message string 62 | // or 63 | err error 64 | } 65 | 66 | func (e *RootCauseError) Error() string { 67 | if e.err != nil { 68 | return e.err.Error() 69 | } 70 | return 
e.new(e.message).Error() 71 | } 72 | 73 | func NewRootCauseError(new func(message string) error, message string) error { 74 | return &RootCauseError{new: new, message: message} 75 | } 76 | 77 | func WrapErrorAsRootCause(err error) error { 78 | return &RootCauseError{err: err} 79 | } 80 | 81 | func ConvertToTemporalError(operation string, err error, details ...string) error { 82 | if err == nil { 83 | return nil 84 | } 85 | var rv *RootCauseError 86 | if errors.As(err, &rv) { 87 | if rv.err != nil { 88 | return rv.err 89 | } else { 90 | msg := fmt.Sprintf("operation %s failed: %s", operation, rv.message) 91 | if len(details) > 0 { 92 | msg += " (" + strings.Join(details, ", ") + ")" 93 | } 94 | return rv.new(msg) 95 | } 96 | } 97 | return convertError(operation, err, details...) 98 | } 99 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/conn/executor.go: -------------------------------------------------------------------------------- 1 | package conn 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/ydb-platform/ydb-go-sdk/v3/table" 7 | "github.com/ydb-platform/ydb-go-sdk/v3/table/result" 8 | ) 9 | 10 | /* 11 | Executor is a very primitive wrapper that provides the same interface for executing 12 | queries in either session or transaction. 13 | 14 | Allows having helper functions that could be used in both cases. 15 | */ 16 | type Executor interface { 17 | Execute( 18 | ctx context.Context, 19 | query string, 20 | params *table.QueryParameters, 21 | ) (result.Result, error) 22 | 23 | // Write is a special case of Execute that immediately closes its result. 
24 | Write( 25 | ctx context.Context, 26 | query string, 27 | params *table.QueryParameters, 28 | ) error 29 | } 30 | 31 | func NewExecutorFromSession(s table.Session, tx *table.TransactionControl) Executor { 32 | return &sessionExecutor{session: s, tx: tx} 33 | } 34 | 35 | type sessionExecutor struct { 36 | session table.Session 37 | tx *table.TransactionControl 38 | } 39 | 40 | func (s *sessionExecutor) Execute(ctx context.Context, query string, params *table.QueryParameters) (result.Result, error) { 41 | _, res, err := s.session.Execute(ctx, s.tx, query, params) 42 | return res, err 43 | } 44 | 45 | func (s *sessionExecutor) Write(ctx context.Context, query string, params *table.QueryParameters) error { 46 | res, err := s.Execute(ctx, query, params) 47 | if err != nil { 48 | return err 49 | } 50 | return res.Close() 51 | } 52 | 53 | type actorExecutor struct { 54 | a table.TransactionActor 55 | } 56 | 57 | func (a actorExecutor) Execute(ctx context.Context, query string, params *table.QueryParameters) (result.Result, error) { 58 | return a.a.Execute(ctx, query, params) 59 | } 60 | 61 | func (a actorExecutor) Write(ctx context.Context, query string, params *table.QueryParameters) error { 62 | res, err := a.Execute(ctx, query, params) 63 | if err != nil { 64 | return err 65 | } 66 | return res.Close() 67 | } 68 | 69 | func NewExecutorFromTransactionActor(a table.TransactionActor) Executor { 70 | return &actorExecutor{a: a} 71 | } 72 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/conn/log/adapter.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/ydb-platform/ydb-go-sdk/v3/log" 8 | tlog "go.temporal.io/server/common/log" 9 | "go.temporal.io/server/common/log/tag" 10 | "go.uber.org/zap" 11 | ) 12 | 13 | var _ log.Logger = adapter{} 14 | 15 | type adapter struct { 16 | l tlog.Logger 17 | } 18 | 19 
| func (a adapter) Log(ctx context.Context, msg string, fields ...log.Field) { 20 | tags := Tags(fields) 21 | tags = append(tags, tag.NewStringTag("namespace", strings.Join(log.NamesFromContext(ctx), "."))) 22 | 23 | switch log.LevelFromContext(ctx) { 24 | case log.TRACE, log.DEBUG: 25 | a.l.Debug(msg, tags...) 26 | case log.INFO: 27 | a.l.Info(msg, tags...) 28 | case log.WARN: 29 | a.l.Warn(msg, tags...) 30 | case log.ERROR: 31 | a.l.Error(msg, tags...) 32 | case log.FATAL: 33 | a.l.Fatal(msg, tags...) 34 | default: 35 | a.l.Error("[Unknown log level] "+msg, tags...) 36 | } 37 | } 38 | 39 | func fieldToField(field log.Field) tag.Tag { 40 | var f zap.Field 41 | switch field.Type() { 42 | case log.IntType: 43 | f = zap.Int(field.Key(), field.IntValue()) 44 | case log.Int64Type: 45 | f = zap.Int64(field.Key(), field.Int64Value()) 46 | case log.StringType: 47 | f = zap.String(field.Key(), field.StringValue()) 48 | case log.BoolType: 49 | f = zap.Bool(field.Key(), field.BoolValue()) 50 | case log.DurationType: 51 | f = zap.Duration(field.Key(), field.DurationValue()) 52 | case log.StringsType: 53 | f = zap.Strings(field.Key(), field.StringsValue()) 54 | case log.ErrorType: 55 | f = zap.Error(field.ErrorValue()) 56 | case log.StringerType: 57 | f = zap.Stringer(field.Key(), field.Stringer()) 58 | default: 59 | f = zap.Any(field.Key(), field.AnyValue()) 60 | } 61 | return tag.NewZapTag(f) 62 | } 63 | 64 | func Tags(fields []log.Field) []tag.Tag { 65 | tags := make([]tag.Tag, len(fields)) 66 | for i, f := range fields { 67 | tags[i] = fieldToField(f) 68 | } 69 | return tags 70 | } 71 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/conn/log/traces.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "github.com/ydb-platform/ydb-go-sdk/v3" 5 | "github.com/ydb-platform/ydb-go-sdk/v3/log" 6 | "github.com/ydb-platform/ydb-go-sdk/v3/trace" 7 | tlog 
"go.temporal.io/server/common/log" 8 | ) 9 | 10 | func WithTraces(l tlog.Logger, d trace.Detailer, opts ...log.Option) ydb.Option { 11 | a := adapter{l: l} 12 | return ydb.MergeOptions( 13 | ydb.WithTraceDriver(log.Driver(a, d, opts...)), 14 | ydb.WithTraceTable(log.Table(a, d, opts...)), 15 | ydb.WithTraceScripting(log.Scripting(a, d, opts...)), 16 | ydb.WithTraceScheme(log.Scheme(a, d, opts...)), 17 | ydb.WithTraceCoordination(log.Coordination(a, d, opts...)), 18 | ydb.WithTraceRatelimiter(log.Ratelimiter(a, d, opts...)), 19 | ydb.WithTraceDiscovery(log.Discovery(a, d, opts...)), 20 | ydb.WithTraceTopic(log.Topic(a, d, opts...)), 21 | ydb.WithTraceDatabaseSQL(log.DatabaseSQL(a, d, opts...)), 22 | ) 23 | } 24 | 25 | func WithLogger(l tlog.Logger, d trace.Detailer, opts ...log.Option) ydb.Option { 26 | return ydb.WithLogger(adapter{l: l}, d, opts...) 27 | } 28 | 29 | func Table(l tlog.Logger, d trace.Detailer, opts ...log.Option) trace.Table { 30 | return log.Table(&adapter{l: l}, d, opts...) 31 | } 32 | 33 | func Topic(l tlog.Logger, d trace.Detailer, opts ...log.Option) trace.Topic { 34 | return log.Topic(&adapter{l: l}, d, opts...) 35 | } 36 | 37 | func Driver(l tlog.Logger, d trace.Detailer, opts ...log.Option) trace.Driver { 38 | return log.Driver(&adapter{l: l}, d, opts...) 39 | } 40 | 41 | func Coordination(l tlog.Logger, d trace.Detailer, opts ...log.Option) trace.Coordination { 42 | return log.Coordination(&adapter{l: l}, d, opts...) 43 | } 44 | 45 | func Discovery(l tlog.Logger, d trace.Detailer, opts ...log.Option) trace.Discovery { 46 | return log.Discovery(&adapter{l: l}, d, opts...) 47 | } 48 | 49 | func Ratelimiter(l tlog.Logger, d trace.Detailer, opts ...log.Option) trace.Ratelimiter { 50 | return log.Ratelimiter(&adapter{l: l}, d, opts...) 51 | } 52 | 53 | func Scheme(l tlog.Logger, d trace.Detailer, opts ...log.Option) trace.Scheme { 54 | return log.Scheme(&adapter{l: l}, d, opts...) 
55 | } 56 | 57 | func Scripting(l tlog.Logger, d trace.Detailer, opts ...log.Option) trace.Scripting { 58 | return log.Scripting(&adapter{l: l}, d, opts...) 59 | } 60 | 61 | func DatabaseSQL(l tlog.Logger, d trace.Detailer, opts ...log.Option) trace.DatabaseSQL { 62 | return log.DatabaseSQL(&adapter{l: l}, d, opts...) 63 | } 64 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/conn/metrics/metrics.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "sync" 7 | "time" 8 | 9 | "github.com/prometheus/client_golang/prometheus" 10 | "github.com/ydb-platform/ydb-go-sdk-metrics/registry" 11 | "github.com/ydb-platform/ydb-go-sdk/v3/trace" 12 | "go.temporal.io/server/common/metrics" 13 | ) 14 | 15 | const ( 16 | defaultNamespace = "ydb_go_sdk" 17 | defaultSeparator = "_" 18 | ) 19 | 20 | var ( 21 | defaultTimerBuckets = prometheus.ExponentialBuckets(time.Millisecond.Seconds(), 1.25, 15) 22 | ) 23 | 24 | type Config struct { 25 | details trace.Details 26 | separator string 27 | handler metrics.Handler 28 | namespace string 29 | timerBuckets []float64 30 | 31 | m sync.Mutex 32 | counters map[metricKey]*counterVec 33 | gauges map[metricKey]*gaugeVec 34 | timers map[metricKey]*timerVec 35 | histograms map[metricKey]*histogramVec 36 | } 37 | 38 | func MakeConfig(registry metrics.Handler, opts ...option) *Config { 39 | c := &Config{ 40 | handler: registry, 41 | namespace: defaultNamespace, 42 | separator: defaultSeparator, 43 | timerBuckets: defaultTimerBuckets, 44 | } 45 | for _, o := range opts { 46 | o(c) 47 | } 48 | if c.details == 0 { 49 | c.details = trace.DetailsAll 50 | } 51 | return c 52 | } 53 | 54 | func (c *Config) CounterVec(name string, labelNames ...string) registry.CounterVec { 55 | counterKey := newCounterKey(c.namespace, name) 56 | c.m.Lock() 57 | defer c.m.Unlock() 58 | if cnt, ok := c.counters[counterKey]; ok { 
59 | return cnt 60 | } 61 | cnt := &counterVec{ 62 | c: c.handler.Counter(c.namespace + c.separator + name), 63 | } 64 | c.counters[counterKey] = cnt 65 | return cnt 66 | } 67 | 68 | func (c *Config) join(a, b string) string { 69 | if a == "" { 70 | return b 71 | } 72 | if b == "" { 73 | return "" 74 | } 75 | return strings.Join([]string{a, b}, c.separator) 76 | } 77 | 78 | func (c *Config) WithSystem(subsystem string) registry.Config { 79 | return &Config{ 80 | separator: c.separator, 81 | details: c.details, 82 | handler: c.handler, 83 | timerBuckets: c.timerBuckets, 84 | namespace: c.join(c.namespace, subsystem), 85 | counters: make(map[metricKey]*counterVec), 86 | gauges: make(map[metricKey]*gaugeVec), 87 | timers: make(map[metricKey]*timerVec), 88 | histograms: make(map[metricKey]*histogramVec), 89 | } 90 | } 91 | 92 | type metricKey struct { 93 | Namespace string 94 | Subsystem string 95 | Name string 96 | Buckets string 97 | } 98 | 99 | func newCounterKey(namespace, name string) metricKey { 100 | return metricKey{ 101 | Namespace: namespace, 102 | Name: name, 103 | } 104 | } 105 | 106 | func newGaugeKey(namespace, name string) metricKey { 107 | return metricKey{ 108 | Namespace: namespace, 109 | Name: name, 110 | } 111 | } 112 | 113 | func newHistogramKey(namespace, name string, buckets []float64) metricKey { 114 | return metricKey{ 115 | Namespace: namespace, 116 | Name: name, 117 | Buckets: fmt.Sprintf("%v", buckets), 118 | } 119 | } 120 | 121 | func newTimerKey(namespace, name string, buckets []float64) metricKey { 122 | return metricKey{ 123 | Namespace: namespace, 124 | Name: name, 125 | Buckets: fmt.Sprintf("%v", buckets), 126 | } 127 | } 128 | 129 | type counterVec struct { 130 | c metrics.CounterIface 131 | } 132 | 133 | func maybeReplaceReservedName(key string) string { 134 | if key == "name" { 135 | return "ydb_name" 136 | } else { 137 | return key 138 | } 139 | } 140 | 141 | func (c *counterVec) With(labels map[string]string) registry.Counter { 
142 | tags := make([]metrics.Tag, 0, len(labels)) 143 | for k, v := range labels { 144 | k = maybeReplaceReservedName(k) 145 | tags = append(tags, metrics.StringTag(k, v)) 146 | } 147 | return &counter{c: c.c, tags: tags} 148 | } 149 | 150 | type counter struct { 151 | c metrics.CounterIface 152 | tags []metrics.Tag 153 | } 154 | 155 | func (c counter) Inc() { 156 | c.c.Record(1, c.tags...) 157 | } 158 | 159 | type gaugeVec struct { 160 | g metrics.GaugeIface 161 | } 162 | 163 | func (c *gaugeVec) With(labels map[string]string) registry.Gauge { 164 | tags := make([]metrics.Tag, 0, len(labels)) 165 | for k, v := range labels { 166 | k = maybeReplaceReservedName(k) 167 | tags = append(tags, metrics.StringTag(k, v)) 168 | } 169 | return &gauge{c: c.g, tags: tags} 170 | } 171 | 172 | type gauge struct { 173 | c metrics.GaugeIface 174 | tags []metrics.Tag 175 | m sync.Mutex 176 | absValue float64 // current value, for Add method 177 | } 178 | 179 | func (g *gauge) Add(delta float64) { 180 | g.m.Lock() 181 | defer g.m.Unlock() 182 | g.absValue += delta 183 | g.c.Record(g.absValue) 184 | } 185 | 186 | func (g *gauge) Set(value float64) { 187 | g.m.Lock() 188 | defer g.m.Unlock() 189 | g.absValue = value 190 | g.c.Record(value) 191 | } 192 | 193 | type histogramVec struct { 194 | h metrics.HistogramIface 195 | } 196 | 197 | func (h *histogramVec) With(labels map[string]string) registry.Histogram { 198 | tags := make([]metrics.Tag, 0, len(labels)) 199 | for k, v := range labels { 200 | k = maybeReplaceReservedName(k) 201 | tags = append(tags, metrics.StringTag(k, v)) 202 | } 203 | return &histogram{h: h.h, tags: tags} 204 | } 205 | 206 | type histogram struct { 207 | h metrics.HistogramIface 208 | tags []metrics.Tag 209 | } 210 | 211 | func (h *histogram) Record(v float64) { 212 | h.h.Record(int64(v), h.tags...) 
213 | } 214 | 215 | type timerVec struct { 216 | t metrics.TimerIface 217 | } 218 | 219 | func (t *timerVec) With(labels map[string]string) registry.Timer { 220 | tags := make([]metrics.Tag, 0, len(labels)) 221 | for k, v := range labels { 222 | k = maybeReplaceReservedName(k) 223 | tags = append(tags, metrics.StringTag(k, v)) 224 | } 225 | return &timer{t: t.t, tags: tags} 226 | } 227 | 228 | type timer struct { 229 | t metrics.TimerIface 230 | tags []metrics.Tag 231 | } 232 | 233 | func (t *timer) Record(d time.Duration) { 234 | t.t.Record(d, t.tags...) 235 | } 236 | 237 | func (c *Config) GaugeVec(name string, labelNames ...string) registry.GaugeVec { 238 | gaugeKey := newGaugeKey(c.namespace, name) 239 | c.m.Lock() 240 | defer c.m.Unlock() 241 | if g, ok := c.gauges[gaugeKey]; ok { 242 | return g 243 | } 244 | g := &gaugeVec{ 245 | g: c.handler.Gauge(c.namespace + c.separator + name), 246 | } 247 | c.gauges[gaugeKey] = g 248 | return g 249 | } 250 | 251 | func (c *Config) TimerVec(name string, labelNames ...string) registry.TimerVec { 252 | timersKey := newTimerKey(c.namespace, name, c.timerBuckets) 253 | c.m.Lock() 254 | defer c.m.Unlock() 255 | if t, ok := c.timers[timersKey]; ok { 256 | return t 257 | } 258 | t := &timerVec{ 259 | t: c.handler.Timer(c.namespace + c.separator + name), 260 | } 261 | c.timers[timersKey] = t 262 | return t 263 | } 264 | 265 | func (c *Config) HistogramVec(name string, buckets []float64, labelNames ...string) registry.HistogramVec { 266 | histogramsKey := newHistogramKey(c.namespace, name, buckets) 267 | c.m.Lock() 268 | defer c.m.Unlock() 269 | if h, ok := c.histograms[histogramsKey]; ok { 270 | return h 271 | } 272 | h := &histogramVec{ 273 | h: c.handler.Histogram(c.namespace+c.separator+name, "ms"), 274 | } 275 | c.histograms[histogramsKey] = h 276 | return h 277 | } 278 | 279 | func (c *Config) Details() trace.Details { 280 | return c.details 281 | } 282 | 283 | type option func(*Config) 284 | 285 | func 
WithNamespace(namespace string) option { 286 | return func(c *Config) { 287 | c.namespace = namespace 288 | } 289 | } 290 | 291 | func WithDetails(details trace.Details) option { 292 | return func(c *Config) { 293 | c.details |= details 294 | } 295 | } 296 | 297 | func WithSeparator(separator string) option { 298 | return func(c *Config) { 299 | c.separator = separator 300 | } 301 | } 302 | 303 | func WithTimerBuckets(timerBuckets []float64) option { 304 | return func(c *Config) { 305 | c.timerBuckets = timerBuckets 306 | } 307 | } 308 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/conn/query.go: -------------------------------------------------------------------------------- 1 | package conn 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/ydb-platform/ydb-go-sdk/v3/sugar" 8 | "github.com/ydb-platform/ydb-go-sdk/v3/table" 9 | "github.com/ydb-platform/ydb-go-sdk/v3/table/types" 10 | "go.temporal.io/server/common/primitives" 11 | 12 | "github.com/yandex/temporal-over-ydb/persistence/pkg/base/executor" 13 | ) 14 | 15 | type Query struct { 16 | Values map[string]types.Value 17 | QueryTemplate string 18 | } 19 | 20 | func NewQueryPart(values map[string]types.Value, queryTemplate string) *Query { 21 | return &Query{ 22 | Values: values, 23 | QueryTemplate: queryTemplate, 24 | } 25 | } 26 | 27 | func NewSingleArgQueryPart(argName string, v types.Value, queryTemplate string) *Query { 28 | return &Query{ 29 | Values: map[string]types.Value{argName: v}, 30 | QueryTemplate: queryTemplate, 31 | } 32 | } 33 | 34 | type Stmt struct { 35 | Declaration []string 36 | Queries []string 37 | Params []table.ParameterOption 38 | } 39 | 40 | func (b *Query) ToStmt(prefix string, conditionID primitives.UUID) *Stmt { 41 | var rv Stmt 42 | rv.Queries = append(rv.Queries, fmt.Sprintf(b.QueryTemplate, prefix)) 43 | 44 | for name, p := range b.Values { 45 | name = prefix + strings.TrimPrefix(name, "$") 46 | rv.Params = 
append(rv.Params, table.ValueParam("$"+name, p)) 47 | } 48 | 49 | decl, err := sugar.GenerateDeclareSection(rv.Params) 50 | if err != nil { 51 | panic(err) 52 | } 53 | rv.Declaration = append(rv.Declaration, decl) 54 | 55 | return &rv 56 | } 57 | 58 | var _ executor.Query[*Stmt] = (*Query)(nil) 59 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/conn/ts.go: -------------------------------------------------------------------------------- 1 | package conn 2 | 3 | import ( 4 | "math" 5 | "time" 6 | ) 7 | 8 | var ( 9 | minYDBDateTime = getMinYDBDateTime() 10 | maxYDBDateTime = time.Unix(math.MaxInt32, 0).UTC() 11 | ) 12 | 13 | // ToYDBDateTime converts to time to YDB datetime 14 | func ToYDBDateTime(t time.Time) time.Time { 15 | if t.IsZero() { 16 | return minYDBDateTime 17 | } 18 | if t.After(maxYDBDateTime) { 19 | return maxYDBDateTime 20 | } 21 | return t.UTC() 22 | } 23 | 24 | // FromYDBDateTime converts YDB datetime and returns go time 25 | func FromYDBDateTime(t time.Time) time.Time { 26 | if t.Equal(minYDBDateTime) { 27 | return time.Time{}.UTC() 28 | } 29 | return t.UTC() 30 | } 31 | 32 | func getMinYDBDateTime() time.Time { 33 | t, err := time.Parse(time.RFC3339, "1000-01-01T00:00:00Z") 34 | if err != nil { 35 | return time.Unix(0, 0).UTC() 36 | } 37 | return t.UTC() 38 | } 39 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/conn/util.go: -------------------------------------------------------------------------------- 1 | package conn 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/ydb-platform/ydb-go-sdk/v3/table/result" 8 | "go.temporal.io/api/serviceerror" 9 | ) 10 | 11 | func EnsureOneRowCursor(ctx context.Context, res result.Result) error { 12 | if err := res.NextResultSetErr(ctx); err != nil { 13 | return NewRootCauseError( 14 | serviceerror.NewInternal, fmt.Sprintf("failed to get first result set: %s", err.Error())) 15 | } 16 | if 
!res.NextRow() { 17 | return NewRootCauseError( 18 | serviceerror.NewNotFound, "failed to get first row: empty result") 19 | } 20 | if res.HasNextRow() { 21 | return NewRootCauseError( 22 | serviceerror.NewInternal, 23 | fmt.Sprintf("result contains more than one row (%d)", res.CurrentResultSet().RowCount())) 24 | } 25 | return nil 26 | } 27 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/const.go: -------------------------------------------------------------------------------- 1 | package ydb 2 | 3 | const ydbPersistenceName = "ydb" 4 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/execution_store.go: -------------------------------------------------------------------------------- 1 | package ydb 2 | 3 | import ( 4 | "context" 5 | "strconv" 6 | "time" 7 | 8 | "go.temporal.io/server/common/log" 9 | "go.temporal.io/server/common/metrics" 10 | p "go.temporal.io/server/common/persistence" 11 | 12 | "github.com/yandex/temporal-over-ydb/persistence/pkg/cache" 13 | "github.com/yandex/temporal-over-ydb/persistence/pkg/ydb/conn" 14 | ) 15 | 16 | type ( 17 | ExecutionStore struct { 18 | metricsHandler metrics.Handler 19 | enableDebugMetrics bool 20 | *HistoryStore 21 | *MutableStateStore 22 | *MutableStateTaskStore 23 | } 24 | ) 25 | 26 | var _ p.ExecutionStore = (*ExecutionStore)(nil) 27 | 28 | func NewExecutionStore( 29 | client *conn.Client, 30 | logger log.Logger, 31 | metricsHandler metrics.Handler, 32 | taskCacheFactory cache.TaskCacheFactory, 33 | ) *ExecutionStore { 34 | eventsCache := cache.NewEventsCache(taskCacheFactory) 35 | return &ExecutionStore{ 36 | metricsHandler: metricsHandler, 37 | enableDebugMetrics: true, // TODO 38 | HistoryStore: NewHistoryStore(client, logger), 39 | MutableStateStore: NewMutableStateStore(client, logger, eventsCache), 40 | MutableStateTaskStore: NewMutableStateTaskStore(client, logger, eventsCache, taskCacheFactory), 41 | 
} 42 | } 43 | 44 | func (d *ExecutionStore) CreateWorkflowExecution( 45 | ctx context.Context, 46 | request *p.InternalCreateWorkflowExecutionRequest, 47 | ) (resp *p.InternalCreateWorkflowExecutionResponse, err error) { 48 | defer func() { 49 | if err != nil { 50 | err = conn.ConvertToTemporalError("CreateWorkflowExecution", err) 51 | } 52 | }() 53 | startTime := time.Now().UTC() 54 | nodeCount, treeCount, err := d.HistoryStore.AppendHistoryNodesForCreateWorkflowExecutionRequest(ctx, request) 55 | if err != nil { 56 | return nil, err 57 | } 58 | handler := d.metricsHandler.WithTags(metrics.OperationTag(metrics.PersistenceCreateWorkflowExecutionScope)) 59 | 60 | if d.enableDebugMetrics { 61 | latency := time.Since(startTime) 62 | handler.WithTags( 63 | metrics.StringTag("tree_count", strconv.Itoa(treeCount)), 64 | metrics.StringTag("node_count", strconv.Itoa(nodeCount)), 65 | ).Timer("stage1").Record(latency) 66 | } 67 | 68 | startTime = time.Now().UTC() 69 | resp, err = d.MutableStateStore.createWorkflowExecution(ctx, request) 70 | if d.enableDebugMetrics { 71 | latency := time.Since(startTime) 72 | handler.Timer("stage2").Record(latency) 73 | } 74 | return resp, err 75 | } 76 | 77 | func (d *ExecutionStore) UpdateWorkflowExecution( 78 | ctx context.Context, 79 | request *p.InternalUpdateWorkflowExecutionRequest, 80 | ) (err error) { 81 | defer func() { 82 | if err != nil { 83 | err = conn.ConvertToTemporalError("UpdateWorkflowExecution", err) 84 | } 85 | }() 86 | startTime := time.Now().UTC() 87 | nodeCount, treeCount, err := d.HistoryStore.AppendHistoryNodesForUpdateWorkflowExecutionRequest(ctx, request) 88 | if err != nil { 89 | return err 90 | } 91 | handler := d.metricsHandler.WithTags(metrics.OperationTag(metrics.PersistenceUpdateWorkflowExecutionScope)) 92 | 93 | if d.enableDebugMetrics { 94 | latency := time.Since(startTime) 95 | handler.WithTags( 96 | metrics.StringTag("tree_count", strconv.Itoa(treeCount)), 97 | metrics.StringTag("node_count", 
strconv.Itoa(nodeCount)), 98 | ).Timer("stage1").Record(latency) 99 | } 100 | 101 | startTime = time.Now().UTC() 102 | err = d.MutableStateStore.updateWorkflowExecution(ctx, request) 103 | if d.enableDebugMetrics { 104 | latency := time.Since(startTime) 105 | handler.Timer("stage2").Record(latency) 106 | } 107 | return err 108 | } 109 | 110 | func (d *ExecutionStore) ConflictResolveWorkflowExecution( 111 | ctx context.Context, 112 | request *p.InternalConflictResolveWorkflowExecutionRequest, 113 | ) (err error) { 114 | defer func() { 115 | if err != nil { 116 | err = conn.ConvertToTemporalError("ConflictResolveWorkflowExecution", err) 117 | } 118 | }() 119 | startTime := time.Now().UTC() 120 | nodeCount, treeCount, err := d.HistoryStore.AppendHistoryNodesForConflictResolveWorkflowExecutionRequest(ctx, request) 121 | if err != nil { 122 | return err 123 | } 124 | handler := d.metricsHandler.WithTags(metrics.OperationTag(metrics.PersistenceConflictResolveWorkflowExecutionScope)) 125 | 126 | if d.enableDebugMetrics { 127 | latency := time.Since(startTime) 128 | handler.WithTags( 129 | metrics.StringTag("tree_count", strconv.Itoa(treeCount)), 130 | metrics.StringTag("node_count", strconv.Itoa(nodeCount)), 131 | ).Timer("stage1").Record(latency) 132 | } 133 | 134 | startTime = time.Now().UTC() 135 | err = d.MutableStateStore.conflictResolveWorkflowExecution(ctx, request) 136 | if d.enableDebugMetrics { 137 | latency := time.Since(startTime) 138 | handler.Timer("stage2").Record(latency) 139 | } 140 | return err 141 | } 142 | 143 | func (d *ExecutionStore) GetName() string { 144 | return ydbPersistenceName 145 | } 146 | 147 | func (d *ExecutionStore) Close() { 148 | 149 | } 150 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/factory.go: -------------------------------------------------------------------------------- 1 | package ydb 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | 7 | "github.com/mitchellh/mapstructure" 
8 | "github.com/ydb-platform/ydb-go-sdk/v3" 9 | "go.temporal.io/server/common/config" 10 | "go.temporal.io/server/common/log" 11 | "go.temporal.io/server/common/log/tag" 12 | "go.temporal.io/server/common/metrics" 13 | p "go.temporal.io/server/common/persistence" 14 | "go.temporal.io/server/common/persistence/client" 15 | "go.temporal.io/server/common/resolver" 16 | 17 | "github.com/yandex/temporal-over-ydb/persistence/pkg/cache" 18 | ydbconfig "github.com/yandex/temporal-over-ydb/persistence/pkg/ydb/config" 19 | "github.com/yandex/temporal-over-ydb/persistence/pkg/ydb/conn" 20 | ) 21 | 22 | type ( 23 | // Factory vends datastore implementations backed by YDB 24 | Factory struct { 25 | sync.RWMutex 26 | clusterName string 27 | cfg ydbconfig.Config 28 | logger log.Logger 29 | Client *conn.Client 30 | metricsHandler metrics.Handler 31 | taskCacheFactory cache.TaskCacheFactory 32 | 33 | clientOptions []ydb.Option 34 | } 35 | ) 36 | 37 | func OptionsToYDBConfig(options map[string]any) (ydbconfig.Config, error) { 38 | cfg := ydbconfig.Config{} 39 | if err := mapstructure.WeakDecode(options, &cfg); err != nil { 40 | return ydbconfig.Config{}, err 41 | } 42 | 43 | // YDB-only persistence always uses old types 44 | cfg.UseOldTypes = true 45 | 46 | if err := cfg.Validate(); err != nil { 47 | return ydbconfig.Config{}, err 48 | } 49 | 50 | return cfg, nil 51 | } 52 | 53 | type ydbAbstractDataStoreFactory struct { 54 | ydbClientOptions []ydb.Option 55 | } 56 | 57 | func NewYDBAbstractDataStoreFactory(ydbClientOptions ...ydb.Option) client.AbstractDataStoreFactory { 58 | return &ydbAbstractDataStoreFactory{ 59 | ydbClientOptions: ydbClientOptions, 60 | } 61 | } 62 | 63 | func (f *ydbAbstractDataStoreFactory) NewFactory( 64 | cfg config.CustomDatastoreConfig, 65 | r resolver.ServiceResolver, 66 | clusterName string, 67 | logger log.Logger, 68 | metricsHandler metrics.Handler, 69 | ) p.DataStoreFactory { 70 | return NewFactory( 71 | cfg, 72 | resolver.NewNoopResolver(), 73 | 
clusterName, 74 | logger, 75 | metricsHandler, 76 | f.ydbClientOptions, 77 | ) 78 | } 79 | 80 | // NewFactory returns an instance of a factory object which can be used to create 81 | // data stores that are backed by YDB 82 | func NewFactory( 83 | cfg config.CustomDatastoreConfig, 84 | r resolver.ServiceResolver, 85 | clusterName string, 86 | logger log.Logger, 87 | metricsHandler metrics.Handler, 88 | ydbClientOptions []ydb.Option, 89 | ) *Factory { 90 | ydbCfg, err := OptionsToYDBConfig(cfg.Options) 91 | if err != nil { 92 | logger.Fatal("unable to initialize custom datastore config for YDB", tag.Error(err)) 93 | } 94 | return NewFactoryFromYDBConfig(clusterName, ydbCfg, r, logger, metricsHandler, ydbClientOptions) 95 | } 96 | 97 | func NewFactoryFromYDBConfig( 98 | clusterName string, 99 | ydbCfg ydbconfig.Config, 100 | r resolver.ServiceResolver, 101 | logger log.Logger, 102 | metricsHandler metrics.Handler, 103 | ydbClientOptions []ydb.Option, 104 | ) *Factory { 105 | ydbCfg.Endpoint = r.Resolve(ydbCfg.Endpoint)[0] 106 | ydbClient, err := conn.NewClient(context.Background(), ydbCfg, logger, metricsHandler, ydbClientOptions...) 
107 | if err != nil { 108 | logger.Fatal("unable to initialize YDB session", tag.Error(err)) 109 | } 110 | taskCacheFactory := cache.NewNoopTaskCacheFactory() 111 | // if v := os.Getenv("TEMPORAL_YDBPGX_CACHE_CAPACITY"); v != "" { 112 | // if cacheCapacity, err := strconv.Atoi(v); err == nil { 113 | // taskCacheFactory = cache.NewTaskCacheFactory(logger, metricsHandler, cacheCapacity) 114 | // } else { 115 | // logger.Warn("unable to parse TEMPORAL_YDBPGX_CACHE_CAPACITY", tag.Error(err)) 116 | // } 117 | // } 118 | // if cfg.ShardQueueCache.Capacity > 0 { 119 | // taskCacheFactory = cache.NewTaskCacheFactory(logger, metricsHandler, cfg.ShardQueueCache.Capacity) 120 | // } 121 | return &Factory{ 122 | clusterName: clusterName, 123 | cfg: ydbCfg, 124 | logger: logger, 125 | Client: ydbClient, 126 | metricsHandler: metricsHandler, 127 | taskCacheFactory: taskCacheFactory, 128 | } 129 | } 130 | 131 | // NewTaskStore returns a new task store 132 | func (f *Factory) NewTaskStore() (p.TaskStore, error) { 133 | return NewMatchingTaskStore(f.Client, f.logger), nil 134 | } 135 | 136 | // NewMirroringTaskStore returns a new task store 137 | func (f *Factory) NewMirroringTaskStore() (p.TaskStore, error) { 138 | return NewMirroringMatchingTaskStore(f.Client, f.logger), nil 139 | } 140 | 141 | // NewShardStore returns a new shard store 142 | func (f *Factory) NewShardStore() (p.ShardStore, error) { 143 | return NewShardStore(f.clusterName, f.Client, f.logger), nil 144 | } 145 | 146 | // NewMetadataStore returns a metadata store 147 | func (f *Factory) NewMetadataStore() (p.MetadataStore, error) { 148 | return NewMetadataStore(f.clusterName, f.Client, f.logger) 149 | } 150 | 151 | // NewMirroringMetadataStore returns a metadata store 152 | func (f *Factory) NewMirroringMetadataStore() (*MirroringMetadataStore, error) { 153 | return NewMirroringMetadataStore(f.Client) 154 | } 155 | 156 | // NewClusterMetadataStore returns a metadata store 157 | func (f *Factory) 
NewClusterMetadataStore() (p.ClusterMetadataStore, error) { 158 | return NewClusterMetadataStore(f.Client, f.logger) 159 | } 160 | 161 | // NewMirroringClusterMetadataStore returns a new metadata store 162 | func (f *Factory) NewMirroringClusterMetadataStore() (*MirroringClusterMetadataStore, error) { 163 | return NewMirroringClusterMetadataStore(f.Client) 164 | } 165 | 166 | // NewExecutionStore returns a new ExecutionStore. 167 | func (f *Factory) NewExecutionStore() (p.ExecutionStore, error) { 168 | return NewExecutionStore(f.Client, f.logger, f.metricsHandler, f.taskCacheFactory), nil 169 | } 170 | 171 | // NewQueue returns a new queue backed by YDB 172 | func (f *Factory) NewQueue(queueType p.QueueType) (p.Queue, error) { 173 | return NewQueueStore(queueType, f.Client, f.logger) 174 | } 175 | 176 | func (f *Factory) NewQueueV2() (p.QueueV2, error) { 177 | return NewQueueStoreV2(f.Client, f.logger) 178 | } 179 | 180 | func (f *Factory) NewNexusEndpointStore() (p.NexusEndpointStore, error) { 181 | return NewNexusEndpointStore(f.Client, f.logger) 182 | } 183 | 184 | // Close closes the factory 185 | func (f *Factory) Close() { 186 | f.Lock() 187 | defer f.Unlock() 188 | _ = f.Client.Close(context.TODO()) 189 | } 190 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/mirroring_cluster_metadata_store.go: -------------------------------------------------------------------------------- 1 | package ydb 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/ydb-platform/ydb-go-sdk/v3/table" 8 | "github.com/ydb-platform/ydb-go-sdk/v3/table/types" 9 | p "go.temporal.io/server/common/persistence" 10 | 11 | "github.com/yandex/temporal-over-ydb/persistence/pkg/ydb/conn" 12 | ) 13 | 14 | type ( 15 | MirroringClusterMetadataStore struct { 16 | client *conn.Client 17 | } 18 | ) 19 | 20 | // NewMirroringClusterMetadataStore is used to create an instance of MirroringClusterMetadataStore implementation 21 | func 
NewMirroringClusterMetadataStore( 22 | client *conn.Client, 23 | ) (*MirroringClusterMetadataStore, error) { 24 | return &MirroringClusterMetadataStore{ 25 | client: client, 26 | }, nil 27 | } 28 | 29 | func (m *MirroringClusterMetadataStore) SaveClusterMetadata( 30 | ctx context.Context, 31 | request *p.InternalSaveClusterMetadataRequest, 32 | ) (rv bool, err error) { 33 | defer func() { 34 | if err != nil { 35 | err = conn.ConvertToTemporalError("SaveClusterMetadata", err) 36 | } 37 | }() 38 | 39 | template := m.client.AddQueryPrefix(` 40 | DECLARE $cluster_name AS utf8; 41 | DECLARE $data AS string; 42 | DECLARE $data_encoding AS ` + m.client.EncodingType().String() + `; 43 | DECLARE $version AS int64; 44 | 45 | UPSERT INTO cluster_metadata_info (cluster_name, data, data_encoding, version) 46 | VALUES ($cluster_name, $data, $data_encoding, $version); 47 | `) 48 | if err = m.client.Write(ctx, template, table.NewQueryParameters( 49 | table.ValueParam("$cluster_name", types.UTF8Value(request.ClusterName)), 50 | table.ValueParam("$data", types.BytesValue(request.ClusterMetadata.Data)), 51 | table.ValueParam("$data_encoding", m.client.EncodingTypeValue(request.ClusterMetadata.EncodingType)), 52 | table.ValueParam("$version", types.Int64Value(request.Version)), 53 | )); err != nil { 54 | return false, err 55 | } 56 | return true, nil 57 | } 58 | 59 | func (m *MirroringClusterMetadataStore) UpsertClusterMembership( 60 | ctx context.Context, 61 | request *p.UpsertClusterMembershipRequest, 62 | ) error { 63 | template := m.client.AddQueryPrefix(m.client.HostIDDecl() + ` 64 | DECLARE $rpc_address AS utf8; 65 | DECLARE $rpc_port AS int32; 66 | DECLARE $role AS int32; 67 | DECLARE $session_start AS Timestamp; 68 | DECLARE $last_heartbeat AS Timestamp; 69 | DECLARE $expire_at AS Timestamp; 70 | 71 | UPSERT INTO cluster_membership (host_id, rpc_address, rpc_port, role, session_start, last_heartbeat, expire_at) 72 | VALUES ($host_id, $rpc_address, $rpc_port, $role, 
$session_start, $last_heartbeat, $expire_at); 73 | `) 74 | err := m.client.Write2(ctx, template, func() *table.QueryParameters { 75 | now := time.Now() 76 | return table.NewQueryParameters( 77 | table.ValueParam("$host_id", m.client.HostIDValueFromUUID(request.HostID)), 78 | table.ValueParam("$rpc_address", types.UTF8Value(request.RPCAddress.String())), 79 | table.ValueParam("$rpc_port", types.Int32Value(int32(request.RPCPort))), 80 | table.ValueParam("$role", types.Int32Value(int32(request.Role))), 81 | table.ValueParam("$session_start", types.TimestampValueFromTime(conn.ToYDBDateTime(request.SessionStart))), 82 | table.ValueParam("$last_heartbeat", types.TimestampValueFromTime(conn.ToYDBDateTime(now))), 83 | table.ValueParam("$expire_at", types.TimestampValueFromTime(conn.ToYDBDateTime(now.Add(request.RecordExpiry)))), 84 | ) 85 | }) 86 | return conn.ConvertToTemporalError("UpsertClusterMembership", err) 87 | } 88 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/mirroring_metadata_store.go: -------------------------------------------------------------------------------- 1 | package ydb 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/ydb-platform/ydb-go-sdk/v3/table" 7 | "github.com/ydb-platform/ydb-go-sdk/v3/table/types" 8 | p "go.temporal.io/server/common/persistence" 9 | 10 | "github.com/yandex/temporal-over-ydb/persistence/pkg/ydb/conn" 11 | ) 12 | 13 | type ( 14 | MirroringMetadataStore struct { 15 | client *conn.Client 16 | } 17 | ) 18 | 19 | // NewMirroringMetadataStore is used to create an instance of the Namespace MirroringMetadataStore implementation 20 | func NewMirroringMetadataStore( 21 | client *conn.Client, 22 | ) (*MirroringMetadataStore, error) { 23 | return &MirroringMetadataStore{ 24 | client: client, 25 | }, nil 26 | } 27 | 28 | func (m *MirroringMetadataStore) UpsertNamespace( 29 | ctx context.Context, 30 | request *p.InternalUpdateNamespaceRequest, 31 | ) error { 32 | err := 
m.client.DB.Table().DoTx(ctx, func(ctx context.Context, tx table.TransactionActor) error { 33 | e := conn.NewExecutorFromTransactionActor(tx) 34 | 35 | template := m.client.AddQueryPrefix(m.client.NamspaceIDDecl() + ` 36 | DECLARE $detail AS string; 37 | DECLARE $detail_encoding AS ` + m.client.EncodingType().String() + `; 38 | DECLARE $notification_version AS int64; 39 | DECLARE $is_global_namespace AS bool; 40 | DECLARE $name AS utf8; 41 | DECLARE $metadata_record_name AS utf8; 42 | 43 | UPSERT INTO namespaces_by_id (id, name) 44 | VALUES ($namespace_id, $name); 45 | 46 | UPSERT INTO namespaces (id, name, detail, detail_encoding, is_global_namespace, notification_version) 47 | VALUES ($namespace_id, $name, $detail, $detail_encoding, $is_global_namespace, $notification_version); 48 | `) 49 | params := table.NewQueryParameters( 50 | table.ValueParam("$namespace_id", m.client.NamespaceIDValue(request.Id)), 51 | table.ValueParam("$name", types.UTF8Value(request.Name)), 52 | table.ValueParam("$detail", types.BytesValue(request.Namespace.Data)), 53 | table.ValueParam("$detail_encoding", m.client.EncodingTypeValue(request.Namespace.EncodingType)), 54 | table.ValueParam("$notification_version", types.Int64Value(request.NotificationVersion)), 55 | table.ValueParam("$is_global_namespace", types.BoolValue(request.IsGlobal)), 56 | table.ValueParam("$metadata_record_name", types.UTF8Value(namespaceMetadataRecordName)), 57 | ) 58 | return e.Write(ctx, template, params) 59 | }) 60 | return conn.ConvertToTemporalError("UpsertNamespace", err) 61 | } 62 | 63 | func (m *MirroringMetadataStore) SetMetadata(ctx context.Context, notificationVersion int64) error { 64 | err := m.client.DB.Table().DoTx(ctx, func(ctx context.Context, tx table.TransactionActor) error { 65 | e := conn.NewExecutorFromTransactionActor(tx) 66 | 67 | template := m.client.AddQueryPrefix(` 68 | DECLARE $notification_version AS int64; 69 | DECLARE $metadata_record_name AS utf8; 70 | 71 | UPSERT INTO namespaces 
(name, notification_version) 72 | VALUES ($metadata_record_name, $notification_version); 73 | `) 74 | params := table.NewQueryParameters( 75 | table.ValueParam("$notification_version", types.Int64Value(notificationVersion)), 76 | table.ValueParam("$metadata_record_name", types.UTF8Value(namespaceMetadataRecordName)), 77 | ) 78 | return e.Write(ctx, template, params) 79 | }) 80 | if err != nil { 81 | return conn.ConvertToTemporalError("SetMetadata", err) 82 | } 83 | return nil 84 | } 85 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/mutable_state_store.go: -------------------------------------------------------------------------------- 1 | package ydb 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | 8 | "github.com/ydb-platform/ydb-go-sdk/v3/query" 9 | "github.com/ydb-platform/ydb-go-sdk/v3/table" 10 | "github.com/ydb-platform/ydb-go-sdk/v3/table/types" 11 | commonpb "go.temporal.io/api/common/v1" 12 | enumspb "go.temporal.io/api/enums/v1" 13 | "go.temporal.io/api/serviceerror" 14 | "go.temporal.io/server/common/log" 15 | p "go.temporal.io/server/common/persistence" 16 | "go.temporal.io/server/common/persistence/serialization" 17 | "go.temporal.io/server/common/primitives" 18 | 19 | "github.com/yandex/temporal-over-ydb/persistence/pkg/base/executor" 20 | "github.com/yandex/temporal-over-ydb/persistence/pkg/base/mss" 21 | baserows "github.com/yandex/temporal-over-ydb/persistence/pkg/base/rows" 22 | "github.com/yandex/temporal-over-ydb/persistence/pkg/ydb/conn" 23 | "github.com/yandex/temporal-over-ydb/persistence/pkg/ydb/rows" 24 | ) 25 | 26 | type ( 27 | MutableStateStore struct { 28 | client *conn.Client 29 | logger log.Logger 30 | baseStore *mss.BaseMutableStateStore 31 | tf executor.TransactionFactory 32 | } 33 | ) 34 | 35 | func NewMutableStateStore( 36 | client *conn.Client, 37 | logger log.Logger, 38 | cache executor.EventsCache, 39 | ) *MutableStateStore { 40 | tf := 
rows.NewTransactionFactory(client) 41 | return &MutableStateStore{ 42 | client: client, 43 | logger: logger, 44 | baseStore: mss.NewBaseMutableStateStore(cache), 45 | tf: tf, 46 | } 47 | } 48 | 49 | func (d *MutableStateStore) createWorkflowExecution( 50 | ctx context.Context, 51 | request *p.InternalCreateWorkflowExecutionRequest, 52 | ) (resp *p.InternalCreateWorkflowExecutionResponse, err error) { 53 | return d.baseStore.CreateWorkflowExecution(ctx, request, d.tf) 54 | } 55 | 56 | func (d *MutableStateStore) updateWorkflowExecution( 57 | ctx context.Context, 58 | request *p.InternalUpdateWorkflowExecutionRequest, 59 | ) (err error) { 60 | return d.baseStore.UpdateWorkflowExecution(ctx, request, d.tf) 61 | } 62 | 63 | func (d *MutableStateStore) conflictResolveWorkflowExecution( 64 | ctx context.Context, 65 | request *p.InternalConflictResolveWorkflowExecutionRequest, 66 | ) (err error) { 67 | return d.baseStore.ConflictResolveWorkflowExecution(ctx, request, d.tf) 68 | } 69 | 70 | func (d *MutableStateStore) SetWorkflowExecution(ctx context.Context, request *p.InternalSetWorkflowExecutionRequest) (err error) { 71 | err = d.baseStore.SetWorkflowExecution(ctx, request, d.tf) 72 | if err != nil { 73 | err = conn.ConvertToTemporalError("SetWorkflowExecution", err) 74 | } 75 | return 76 | } 77 | 78 | func (d *MutableStateStore) GetWorkflowExecution( 79 | ctx context.Context, 80 | request *p.GetWorkflowExecutionRequest, 81 | ) (resp *p.InternalGetWorkflowExecutionResponse, err error) { 82 | defer func() { 83 | if err != nil { 84 | details := fmt.Sprintf("shard_id: %v, namespace_id: %v, workflow_id: %v, run_id: %v", 85 | request.ShardID, request.NamespaceID, request.WorkflowID, request.RunID) 86 | err = conn.ConvertToTemporalError("GetWorkflowExecution", err, details) 87 | } 88 | }() 89 | 90 | // With a single SELECT we read the execution row and all its events 91 | template := d.client.AddQueryPrefix(d.client.NamspaceIDDecl() + d.client.RunIDDecl() + ` 92 | DECLARE 
$shard_id AS uint32; 93 | DECLARE $workflow_id AS Utf8; 94 | 95 | SELECT 96 | execution, execution_encoding, execution_state, execution_state_encoding, 97 | next_event_id, checksum, checksum_encoding, db_record_version, 98 | event_type, event_id, event_name, data, data_encoding 99 | FROM executions 100 | WHERE shard_id = $shard_id 101 | AND namespace_id = $namespace_id 102 | AND workflow_id = $workflow_id 103 | AND run_id = $run_id 104 | AND task_id IS NULL 105 | AND task_category_id IS NULL 106 | AND task_visibility_ts IS NULL; 107 | `) 108 | res, err := d.client.Query(ctx, template, table.NewQueryParameters( 109 | table.ValueParam("$shard_id", types.Uint32Value(rows.ToShardIDColumnValue(request.ShardID))), 110 | table.ValueParam("$namespace_id", d.client.NamespaceIDValue(request.NamespaceID)), 111 | table.ValueParam("$workflow_id", types.UTF8Value(request.WorkflowID)), 112 | table.ValueParam("$run_id", d.client.RunIDValue(request.RunID)), 113 | )) 114 | if err != nil { 115 | return 116 | } 117 | 118 | defer func() { 119 | err2 := res.Close(ctx) 120 | if err == nil { 121 | err = err2 122 | } 123 | }() 124 | 125 | state, dbRecordVersion, err := d.scanMutableState(ctx, res) 126 | if err != nil { 127 | return nil, err 128 | } 129 | 130 | return &p.InternalGetWorkflowExecutionResponse{ 131 | State: state, 132 | DBRecordVersion: dbRecordVersion, 133 | }, nil 134 | } 135 | 136 | func (d *MutableStateStore) DeleteWorkflowExecution( 137 | ctx context.Context, 138 | request *p.DeleteWorkflowExecutionRequest, 139 | ) error { 140 | template := d.client.AddQueryPrefix(d.client.NamspaceIDDecl() + d.client.RunIDDecl() + ` 141 | DECLARE $shard_id AS uint32; 142 | DECLARE $workflow_id AS Utf8; 143 | 144 | DELETE FROM executions 145 | WHERE shard_id = $shard_id 146 | AND namespace_id = $namespace_id 147 | AND workflow_id = $workflow_id 148 | AND run_id = $run_id 149 | AND task_id IS NULL 150 | AND task_category_id IS NULL 151 | AND task_visibility_ts IS NULL 152 | ; 153 | `) 154 | 
err := d.client.Write(ctx, template, table.NewQueryParameters( 155 | table.ValueParam("$shard_id", types.Uint32Value(rows.ToShardIDColumnValue(request.ShardID))), 156 | table.ValueParam("$namespace_id", d.client.NamespaceIDValue(request.NamespaceID)), 157 | table.ValueParam("$workflow_id", types.UTF8Value(request.WorkflowID)), 158 | table.ValueParam("$run_id", d.client.RunIDValue(request.RunID)), 159 | )) 160 | 161 | if err != nil { 162 | return conn.ConvertToTemporalError("DeleteWorkflowExecution", err) 163 | } 164 | return nil 165 | } 166 | 167 | func (d *MutableStateStore) DeleteCurrentWorkflowExecution( 168 | ctx context.Context, 169 | request *p.DeleteCurrentWorkflowExecutionRequest, 170 | ) error { 171 | template := d.client.AddQueryPrefix(d.client.NamspaceIDDecl() + d.client.RunIDDecl() + d.client.CurrentRunIDDecl() + ` 172 | DECLARE $shard_id AS uint32; 173 | DECLARE $workflow_id AS Utf8; 174 | 175 | DELETE FROM executions 176 | WHERE shard_id = $shard_id 177 | AND namespace_id = $namespace_id 178 | AND workflow_id = $workflow_id 179 | AND run_id = $run_id 180 | AND task_id IS NULL 181 | AND task_category_id IS NULL 182 | AND task_visibility_ts IS NULL 183 | AND event_type IS NULL 184 | AND event_id IS NULL 185 | AND event_name IS NULL 186 | AND current_run_id = $current_run_id; 187 | `) 188 | err := d.client.Write(ctx, template, table.NewQueryParameters( 189 | table.ValueParam("$shard_id", types.Uint32Value(rows.ToShardIDColumnValue(request.ShardID))), 190 | table.ValueParam("$namespace_id", d.client.NamespaceIDValue(request.NamespaceID)), 191 | table.ValueParam("$workflow_id", types.UTF8Value(request.WorkflowID)), 192 | table.ValueParam("$run_id", d.client.EmptyRunIDValue()), 193 | table.ValueParam("$current_run_id", d.client.RunIDValue(request.RunID)), 194 | )) 195 | if err != nil { 196 | return conn.ConvertToTemporalError("DeleteCurrentWorkflowExecution", err) 197 | } 198 | return nil 199 | } 200 | 201 | func (d *MutableStateStore) GetCurrentExecution( 
202 | ctx context.Context, 203 | request *p.GetCurrentExecutionRequest, 204 | ) (resp *p.InternalGetCurrentExecutionResponse, err error) { 205 | defer func() { 206 | if err != nil { 207 | details := fmt.Sprintf("shard_id: %v, namespace_id: %v, workflow_id: %v", request.ShardID, request.NamespaceID, request.WorkflowID) 208 | err = conn.ConvertToTemporalError("GetCurrentExecution", err, details) 209 | } 210 | }() 211 | query := d.client.AddQueryPrefix(d.client.NamspaceIDDecl() + d.client.RunIDDecl() + ` 212 | DECLARE $shard_id AS uint32; 213 | DECLARE $workflow_id AS Utf8; 214 | 215 | SELECT current_run_id, execution_state, execution_state_encoding 216 | FROM executions 217 | WHERE shard_id = $shard_id 218 | AND namespace_id = $namespace_id 219 | AND workflow_id = $workflow_id 220 | AND run_id = $run_id 221 | AND task_id IS NULL 222 | AND task_category_id IS NULL 223 | AND task_visibility_ts IS NULL 224 | AND event_type IS NULL 225 | AND event_id IS NULL 226 | AND event_name IS NULL 227 | LIMIT 1; 228 | `) 229 | res, err := d.client.Do(ctx, query, table.OnlineReadOnlyTxControl(), table.NewQueryParameters( 230 | table.ValueParam("$shard_id", types.Uint32Value(rows.ToShardIDColumnValue(request.ShardID))), 231 | table.ValueParam("$namespace_id", d.client.NamespaceIDValue(request.NamespaceID)), 232 | table.ValueParam("$workflow_id", types.UTF8Value(request.WorkflowID)), 233 | table.ValueParam("$run_id", d.client.EmptyRunIDValue()), 234 | ), table.WithIdempotent()) 235 | if err != nil { 236 | return nil, err 237 | } 238 | defer func() { 239 | err2 := res.Close() 240 | if err == nil { 241 | err = err2 242 | } 243 | }() 244 | 245 | if err = conn.EnsureOneRowCursor(ctx, res); err != nil { 246 | return nil, err 247 | } 248 | var data []byte 249 | var encoding string 250 | var encodingType conn.EncodingTypeRaw 251 | var encodingPtr interface{} 252 | if d.client.UseIntForEncoding() { 253 | encodingPtr = &encodingType 254 | } else { 255 | encodingPtr = &encoding 256 | } 257 | 
258 | var currentRunID string 259 | var currentRunIDBytes []byte 260 | var currentRunIDPtr interface{} 261 | if d.client.UseBytesForRunIDs() { 262 | currentRunIDPtr = ¤tRunIDBytes 263 | } else { 264 | currentRunIDPtr = ¤tRunID 265 | } 266 | 267 | if err = res.ScanWithDefaults(currentRunIDPtr, &data, encodingPtr); err != nil { 268 | return nil, fmt.Errorf("failed to scan current workflow execution row: %w", err) 269 | } 270 | if d.client.UseBytesForRunIDs() { 271 | currentRunID = primitives.UUIDString(currentRunIDBytes) 272 | } 273 | if d.client.UseIntForEncoding() { 274 | encoding = enumspb.EncodingType(encodingType).String() 275 | } 276 | blob := p.NewDataBlob(data, encoding) 277 | executionState, err := serialization.WorkflowExecutionStateFromBlob(blob.Data, blob.EncodingType.String()) 278 | if err != nil { 279 | return nil, err 280 | } 281 | 282 | return &p.InternalGetCurrentExecutionResponse{ 283 | RunID: currentRunID, 284 | ExecutionState: executionState, 285 | }, nil 286 | } 287 | 288 | func (d *MutableStateStore) ListConcreteExecutions( 289 | _ context.Context, 290 | _ *p.ListConcreteExecutionsRequest, 291 | ) (*p.InternalListConcreteExecutionsResponse, error) { 292 | return nil, serviceerror.NewUnimplemented("ListConcreteExecutions is not implemented") 293 | } 294 | 295 | func (d *MutableStateStore) scanMutableState(ctx context.Context, res query.ResultSet) (*p.InternalWorkflowMutableState, int64, error) { 296 | var resultDBRecordVersion int64 297 | 298 | state := &p.InternalWorkflowMutableState{ 299 | ActivityInfos: make(map[int64]*commonpb.DataBlob), 300 | TimerInfos: make(map[string]*commonpb.DataBlob), 301 | ChildExecutionInfos: make(map[int64]*commonpb.DataBlob), 302 | RequestCancelInfos: make(map[int64]*commonpb.DataBlob), 303 | SignalInfos: make(map[int64]*commonpb.DataBlob), 304 | } 305 | 306 | for row, err := range res.Rows(ctx) { 307 | if err != nil { 308 | return nil, 0, fmt.Errorf("failed to get row: %w", err) 309 | } 310 | 311 | var 
// scanMutableState folds the rows returned by GetWorkflowExecution's single
// SELECT into one InternalWorkflowMutableState: exactly one row (all event_*
// columns NULL) carries the execution/state/checksum blobs and the DB record
// version; every other row is an event item dispatched by event_type.
// Returns the assembled state and the db_record_version of the execution row.
func (d *MutableStateStore) scanMutableState(ctx context.Context, res query.ResultSet) (*p.InternalWorkflowMutableState, int64, error) {
	var resultDBRecordVersion int64

	state := &p.InternalWorkflowMutableState{
		ActivityInfos:       make(map[int64]*commonpb.DataBlob),
		TimerInfos:          make(map[string]*commonpb.DataBlob),
		ChildExecutionInfos: make(map[int64]*commonpb.DataBlob),
		RequestCancelInfos:  make(map[int64]*commonpb.DataBlob),
		SignalInfos:         make(map[int64]*commonpb.DataBlob),
	}

	for row, err := range res.Rows(ctx) {
		if err != nil {
			return nil, 0, fmt.Errorf("failed to get row: %w", err)
		}

		// Each *_encoding column is either an int enum or a string depending
		// on the client's schema flavor; scan into the matching destination
		// and normalize to a string afterwards.
		var executionData []byte
		var executionEncoding string
		var encodingType conn.EncodingTypeRaw
		var encodingPtr any
		if d.client.UseIntForEncoding() {
			encodingPtr = &encodingType
		} else {
			encodingPtr = &executionEncoding
		}
		var nextEventID int64
		var stateData []byte
		var stateEncoding string
		var stateEncodingType conn.EncodingTypeRaw
		var stateEncodingPtr any
		if d.client.UseIntForEncoding() {
			stateEncodingPtr = &stateEncodingType
		} else {
			stateEncodingPtr = &stateEncoding
		}
		var checksumData []byte
		var checksumEncoding string
		var checksumEncodingType conn.EncodingTypeRaw
		var checksumEncodingPtr any
		if d.client.UseIntForEncoding() {
			checksumEncodingPtr = &checksumEncodingType
		} else {
			checksumEncodingPtr = &checksumEncoding
		}
		var eventType int32
		var eventID int64
		var eventName string
		var eventData []byte
		var eventEncoding string
		var eventEncodingType conn.EncodingTypeRaw
		var eventEncodingPtr any
		if d.client.UseIntForEncoding() {
			eventEncodingPtr = &eventEncodingType
		} else {
			eventEncodingPtr = &eventEncoding
		}
		var dbRecordVersion int64
		if err := row.ScanNamed(
			query.Named("next_event_id", &nextEventID),
			query.Named("execution", &executionData),
			query.Named("execution_encoding", encodingPtr),
			query.Named("db_record_version", &dbRecordVersion),
			query.Named("execution_state", &stateData),
			query.Named("execution_state_encoding", stateEncodingPtr),
			query.Named("checksum", &checksumData),
			query.Named("checksum_encoding", checksumEncodingPtr),
			query.Named("event_type", &eventType),
			query.Named("event_id", &eventID),
			query.Named("event_name", &eventName),
			query.Named("data_encoding", eventEncodingPtr),
			query.Named("data", &eventData),
		); err != nil {
			return nil, 0, fmt.Errorf("failed to scan execution: %w", err)
		}

		if d.client.UseIntForEncoding() {
			executionEncoding = enumspb.EncodingType(encodingType).String()
			stateEncoding = enumspb.EncodingType(stateEncodingType).String()
			checksumEncoding = enumspb.EncodingType(checksumEncodingType).String()
			eventEncoding = enumspb.EncodingType(eventEncodingType).String()
		}

		// Event rows carry a positive event_id or a non-empty event_name;
		// the execution row has neither.
		if eventID > 0 || len(eventName) > 0 {
			switch eventType {
			case baserows.ItemTypeActivity:
				state.ActivityInfos[eventID] = p.NewDataBlob(eventData, eventEncoding)
			case baserows.ItemTypeTimer:
				state.TimerInfos[eventName] = p.NewDataBlob(eventData, eventEncoding)
			case baserows.ItemTypeChildExecution:
				state.ChildExecutionInfos[eventID] = p.NewDataBlob(eventData, eventEncoding)
			case baserows.ItemTypeRequestCancel:
				state.RequestCancelInfos[eventID] = p.NewDataBlob(eventData, eventEncoding)
			case baserows.ItemTypeSignal:
				state.SignalInfos[eventID] = p.NewDataBlob(eventData, eventEncoding)
			case baserows.ItemTypeSignalRequested:
				state.SignalRequestedIDs = append(state.SignalRequestedIDs, eventName)
			case baserows.ItemTypeBufferedEvent:
				state.BufferedEvents = append(state.BufferedEvents, p.NewDataBlob(eventData, eventEncoding))
			default:
				return nil, 0, fmt.Errorf("unknown event type: %d", eventType)
			}
		} else {
			if state.ExecutionInfo != nil {
				return nil, 0, errors.New("got multiple executions rows")
			}
			state.ExecutionInfo = p.NewDataBlob(executionData, executionEncoding)
			state.ExecutionState = p.NewDataBlob(stateData, stateEncoding)
			state.Checksum = p.NewDataBlob(checksumData, checksumEncoding)
			state.NextEventID = nextEventID
			resultDBRecordVersion = dbRecordVersion
		}
	}

	if state.ExecutionInfo == nil {
		// TODO: return Unavailable instead of NotFound if we have seen at least one row
		return nil, 0, conn.NewRootCauseError(serviceerror.NewNotFound, "workflow execution not found")
	}

	return state, resultDBRecordVersion, nil
}

// ----- persistence/pkg/ydb/nexus_store.go -----

package ydb

import (
	"context"
	"fmt"

	"github.com/ydb-platform/ydb-go-sdk/v3/table"
	"github.com/ydb-platform/ydb-go-sdk/v3/table/result/named"
	"github.com/ydb-platform/ydb-go-sdk/v3/table/types"
	enumspb "go.temporal.io/api/enums/v1"
	"go.temporal.io/api/serviceerror"
	"go.temporal.io/server/common/log"
	p "go.temporal.io/server/common/persistence"

	"github.com/yandex/temporal-over-ydb/persistence/pkg/base/tokens"
	"github.com/yandex/temporal-over-ydb/persistence/pkg/ydb/conn"
)

type (
	// NexusEndpointStore implements Temporal's NexusEndpointStore API on the
	// single YDB `nexus_endpoints` table.
	NexusEndpointStore struct {
		client *conn.Client
		logger log.Logger
	}
)

// tableVersionEndpointID is the reserved row id (under rowTypePartitionStatus)
// holding the monotonically increasing table version.
const tableVersionEndpointID = `00000000-0000-0000-0000-000000000000`

const (
	rowTypePartitionStatus = iota
	rowTypeNexusEndpoint
)

// NewNexusEndpointStore returns a YDB-backed NexusEndpointStore.
func NewNexusEndpointStore(
	client *conn.Client,
	logger log.Logger,
) (p.NexusEndpointStore, error) {
	return &NexusEndpointStore{
		client: client,
		logger: logger,
	}, nil
}

// Close is a no-op; the underlying client is owned by the caller.
func (s *NexusEndpointStore) Close() {
}

// GetName reports the persistence implementation name.
func (s *NexusEndpointStore) GetName() string {
	return ydbPersistenceName
}
// CreateOrUpdateNexusEndpoint atomically writes one endpoint row and bumps the
// table-version row in a single YQL statement. Optimistic concurrency is
// enforced with Ensure(): a version of 0 means "must not exist yet" (INSERT);
// any other version must match the stored one before the UPDATE applies.
func (s *NexusEndpointStore) CreateOrUpdateNexusEndpoint(ctx context.Context, request *p.InternalCreateOrUpdateNexusEndpointRequest) error {
	// The DISCARD SELECT keeps both *_expected_version DECLAREs referenced
	// even when the INSERT branches below do not use them.
	declare := `
DECLARE $partition_status_type AS Int32;
DECLARE $table_version_id AS String;
DECLARE $table_expected_version AS Int64;
DECLARE $table_new_version AS Int64;

DECLARE $endpoint_type AS Int32;
DECLARE $endpoint_id AS String;
DECLARE $endpoint_expected_version AS Int64;
DECLARE $endpoint_new_version AS Int64;
DECLARE $endpoint_data AS String;
DECLARE $endpoint_data_encoding AS Int16;

DISCARD SELECT $table_expected_version, $endpoint_expected_version;
`

	var endpointTemplate string
	if request.Endpoint.Version == 0 {
		// First write of this endpoint: a plain INSERT fails on key conflict.
		endpointTemplate = `
INSERT INTO nexus_endpoints(type, id, data, data_encoding, version)
VALUES($endpoint_type, $endpoint_id, $endpoint_data, $endpoint_data_encoding, $endpoint_new_version)
;
`
	} else {
		// Update path: fail the transaction if the stored version moved.
		endpointTemplate = `
DISCARD SELECT Ensure(version, version == $endpoint_expected_version, "ENDPOINT_VERSION_MISMATCH")
FROM nexus_endpoints
WHERE type = $endpoint_type AND id = $endpoint_id
;

UPDATE nexus_endpoints
SET data = $endpoint_data, data_encoding = $endpoint_data_encoding, version = $endpoint_new_version
WHERE type = $endpoint_type AND id = $endpoint_id
;
`
	}

	var versionTemplate string
	if request.LastKnownTableVersion == 0 {
		// First endpoint ever: create the table-version row.
		versionTemplate = `
INSERT INTO nexus_endpoints(type, id, version)
VALUES ($partition_status_type, $table_version_id, $table_new_version)
;
`
	} else {
		versionTemplate = `
DISCARD SELECT Ensure(version, version == $table_expected_version, "TABLE_VERSION_MISMATCH")
FROM nexus_endpoints
WHERE type = $partition_status_type AND id = $table_version_id
;

UPDATE nexus_endpoints
SET version = $table_new_version
WHERE type = $partition_status_type AND id = $table_version_id
;
`
	}

	template := s.client.AddQueryPrefix(declare + endpointTemplate + versionTemplate)
	params := table.NewQueryParameters(
		// table
		table.ValueParam("$partition_status_type", types.Int32Value(rowTypePartitionStatus)),
		table.ValueParam("$table_version_id", types.StringValueFromString(tableVersionEndpointID)),
		table.ValueParam("$table_expected_version", types.Int64Value(request.LastKnownTableVersion)),
		table.ValueParam("$table_new_version", types.Int64Value(request.LastKnownTableVersion+1)),
		// endpoint
		table.ValueParam("$endpoint_type", types.Int32Value(rowTypeNexusEndpoint)),
		table.ValueParam("$endpoint_id", types.BytesValue([]byte(request.Endpoint.ID))),
		table.ValueParam("$endpoint_expected_version", types.Int64Value(request.Endpoint.Version)),
		table.ValueParam("$endpoint_new_version", types.Int64Value(request.Endpoint.Version+1)),
		table.ValueParam("$endpoint_data", types.BytesValue(request.Endpoint.Data.Data)),
		table.ValueParam("$endpoint_data_encoding", s.client.NewEncodingTypeValue(request.Endpoint.Data.EncodingType)),
	)

	err := s.client.Write(ctx, template, params, table.WithIdempotent())
	if err != nil {
		// NOTE(review): an INSERT key conflict ("Conflict with existing key")
		// is mapped to the *endpoint* version conflict even though it could
		// also originate from the table-version INSERT branch — confirm this
		// mapping is intended.
		if conn.IsPreconditionFailedAndContains(err, "ENDPOINT_VERSION_MISMATCH") || conn.IsPreconditionFailedAndContains(err, "Conflict with existing key") {
			return p.ErrNexusEndpointVersionConflict
		} else if conn.IsPreconditionFailedAndContains(err, "TABLE_VERSION_MISMATCH") {
			return p.ErrNexusTableVersionConflict
		}
		return conn.ConvertToTemporalError("CreateOrUpdateNexusEndpoint", err)
	}

	return nil
}
$partition_status_type AND id = $table_version_id 151 | ; 152 | 153 | DISCARD SELECT Ensure(0, Count(*) > 0, "ENDPOINT_DOES_NOT_EXIST") 154 | FROM nexus_endpoints 155 | WHERE type = $endpoint_type AND id = $endpoint_id 156 | ; 157 | 158 | UPDATE nexus_endpoints 159 | SET version = $table_new_version 160 | WHERE type = $partition_status_type AND id = $table_version_id 161 | ; 162 | 163 | DELETE FROM nexus_endpoints 164 | WHERE type = $endpoint_type AND id = $endpoint_id 165 | ; 166 | `) 167 | 168 | params := table.NewQueryParameters( 169 | // table 170 | table.ValueParam("$partition_status_type", types.Int32Value(rowTypePartitionStatus)), 171 | table.ValueParam("$table_version_id", types.StringValueFromString(tableVersionEndpointID)), 172 | table.ValueParam("$table_expected_version", types.Int64Value(request.LastKnownTableVersion)), 173 | table.ValueParam("$table_new_version", types.Int64Value(request.LastKnownTableVersion+1)), 174 | // endpoint 175 | table.ValueParam("$endpoint_type", types.Int32Value(rowTypeNexusEndpoint)), 176 | table.ValueParam("$endpoint_id", types.BytesValue([]byte(request.ID))), 177 | ) 178 | 179 | err := s.client.Write(ctx, template, params, table.WithIdempotent()) 180 | if err != nil { 181 | if conn.IsPreconditionFailedAndContains(err, "ENDPOINT_DOES_NOT_EXIST") { 182 | return serviceerror.NewNotFound(fmt.Sprintf("nexus endpoint not found for ID: %v", request.ID)) 183 | } else if conn.IsPreconditionFailedAndContains(err, "TABLE_VERSION_MISMATCH") { 184 | return p.ErrNexusTableVersionConflict 185 | } 186 | return conn.ConvertToTemporalError("DeleteNexusEndpoint", err) 187 | } 188 | 189 | return nil 190 | } 191 | 192 | func (s *NexusEndpointStore) GetNexusEndpoint(ctx context.Context, request *p.GetNexusEndpointRequest) (resp *p.InternalNexusEndpoint, err error) { 193 | defer func() { 194 | if err != nil { 195 | err = conn.ConvertToTemporalError("GetNexusEndpoint", err) 196 | } 197 | }() 198 | 199 | template := s.client.AddQueryPrefix(` 200 | 
DECLARE $endpoint_type AS Int32; 201 | DECLARE $endpoint_id AS String; 202 | 203 | SELECT data, data_encoding, version 204 | FROM nexus_endpoints 205 | WHERE type = $endpoint_type AND id = $endpoint_id 206 | LIMIT 1 207 | `) 208 | 209 | params := table.NewQueryParameters( 210 | table.ValueParam("$endpoint_type", types.Int32Value(rowTypeNexusEndpoint)), 211 | table.ValueParam("$endpoint_id", types.BytesValue([]byte(request.ID))), 212 | ) 213 | 214 | res, err := s.client.Do(ctx, template, conn.OnlineReadOnlyTxControl(), params, table.WithIdempotent()) 215 | if err != nil { 216 | return nil, conn.ConvertToTemporalError("GetNexusEndpoint", err) 217 | } 218 | 219 | defer func() { 220 | err2 := res.Close() 221 | if err == nil { 222 | err = err2 223 | } 224 | }() 225 | 226 | if err = res.NextResultSetErr(ctx); err != nil { 227 | return nil, err 228 | } 229 | 230 | if !res.NextRow() { 231 | return nil, conn.WrapErrorAsRootCause( 232 | serviceerror.NewNotFound(fmt.Sprintf("Nexus incoming service with ID `%v` not found", request.ID)), 233 | ) 234 | } 235 | 236 | resp = &p.InternalNexusEndpoint{ 237 | ID: request.ID, 238 | } 239 | var data []byte 240 | var encodingType conn.EncodingTypeRaw 241 | 242 | err = res.ScanNamed( 243 | named.OptionalWithDefault("data", &data), 244 | named.OptionalWithDefault("data_encoding", &encodingType), 245 | named.OptionalWithDefault("version", &resp.Version), 246 | ) 247 | if err != nil { 248 | return nil, err 249 | } 250 | 251 | resp.Data = p.NewDataBlob(data, enumspb.EncodingType(encodingType).String()) 252 | return resp, nil 253 | } 254 | 255 | func (s *NexusEndpointStore) ListNexusEndpoints(ctx context.Context, request *p.ListNexusEndpointsRequest) (resp *p.InternalListNexusEndpointsResponse, err error) { 256 | defer func() { 257 | if err != nil { 258 | err = conn.ConvertToTemporalError("ListNexusEndpoints", err) 259 | } 260 | }() 261 | 262 | var token tokens.NexusEndpointsPageToken 263 | token.Deserialize(request.NextPageToken) 264 | 265 | 
template := s.client.AddQueryPrefix(` 266 | DECLARE $partition_status_type AS Int32; 267 | DECLARE $table_version_id AS String; 268 | 269 | DECLARE $endpoint_type AS Int32; 270 | DECLARE $endpoint_last_seen_id AS String; 271 | DECLARE $page_size AS Uint64; 272 | 273 | SELECT type, id, data, data_encoding, version 274 | FROM nexus_endpoints 275 | WHERE type = $partition_status_type AND id = $table_version_id 276 | ; 277 | 278 | SELECT type, id, data, data_encoding, version 279 | FROM nexus_endpoints 280 | WHERE type = $endpoint_type AND id > $endpoint_last_seen_id 281 | ORDER BY id 282 | LIMIT $page_size 283 | ; 284 | `) 285 | 286 | params := table.NewQueryParameters( 287 | // table 288 | table.ValueParam("$partition_status_type", types.Int32Value(rowTypePartitionStatus)), 289 | table.ValueParam("$table_version_id", types.StringValueFromString(tableVersionEndpointID)), 290 | // endpoint 291 | table.ValueParam("$endpoint_type", types.Int32Value(rowTypeNexusEndpoint)), 292 | table.ValueParam("$endpoint_last_seen_id", types.StringValueFromString(token.LastSeenEndpointID)), 293 | table.ValueParam("$page_size", types.Uint64Value(uint64(request.PageSize))), 294 | ) 295 | 296 | res, err := s.client.Do(ctx, template, conn.OnlineReadOnlyTxControl(), params, table.WithIdempotent()) 297 | if err != nil { 298 | return nil, conn.ConvertToTemporalError("ListNexusEndpoints", err) 299 | } 300 | 301 | defer func() { 302 | err2 := res.Close() 303 | if err == nil { 304 | err = err2 305 | } 306 | }() 307 | 308 | if err = res.NextResultSetErr(ctx); err != nil { 309 | return nil, err 310 | } 311 | 312 | if !res.NextRow() { 313 | return &p.InternalListNexusEndpointsResponse{}, nil 314 | } 315 | 316 | var tableVersion int64 317 | err = res.ScanNamed( 318 | named.OptionalWithDefault("version", &tableVersion), 319 | ) 320 | if err != nil { 321 | return nil, err 322 | } 323 | if request.LastKnownTableVersion != 0 && tableVersion != request.LastKnownTableVersion { 324 | return nil, 
fmt.Errorf("%w. provided table version: %v current table version: %v", 325 | p.ErrNexusTableVersionConflict, 326 | request.LastKnownTableVersion, 327 | tableVersion) 328 | } 329 | 330 | if err = res.NextResultSetErr(ctx); err != nil { 331 | return nil, err 332 | } 333 | 334 | var endpoints []p.InternalNexusEndpoint 335 | for res.NextRow() { 336 | var endpoint p.InternalNexusEndpoint 337 | var data []byte 338 | var encodingType conn.EncodingTypeRaw 339 | err = res.ScanNamed( 340 | named.OptionalWithDefault("id", &endpoint.ID), 341 | named.OptionalWithDefault("version", &endpoint.Version), 342 | named.OptionalWithDefault("data", &data), 343 | named.OptionalWithDefault("data_encoding", &encodingType), 344 | ) 345 | if err != nil { 346 | return nil, err 347 | } 348 | endpoint.Data = p.NewDataBlob(data, enumspb.EncodingType(encodingType).String()) 349 | endpoints = append(endpoints, endpoint) 350 | } 351 | 352 | var nextPageToken []byte 353 | if len(endpoints) == request.PageSize { 354 | token := tokens.NexusEndpointsPageToken{ 355 | LastSeenEndpointID: endpoints[len(endpoints)-1].ID, 356 | } 357 | nextPageToken = token.Serialize() 358 | } 359 | 360 | return &p.InternalListNexusEndpointsResponse{ 361 | TableVersion: tableVersion, 362 | NextPageToken: nextPageToken, 363 | Endpoints: endpoints, 364 | }, nil 365 | } 366 | 367 | var _ p.NexusEndpointStore = (*NexusEndpointStore)(nil) 368 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/rows/factory.go: -------------------------------------------------------------------------------- 1 | package rows 2 | 3 | import ( 4 | "github.com/yandex/temporal-over-ydb/persistence/pkg/base/executor" 5 | "github.com/yandex/temporal-over-ydb/persistence/pkg/ydb/conn" 6 | ) 7 | 8 | type transactionFactoryImpl struct { 9 | client *conn.Client 10 | } 11 | 12 | func NewTransactionFactory(client *conn.Client) executor.TransactionFactory { 13 | return &transactionFactoryImpl{ 14 | client: 
client, 15 | } 16 | } 17 | 18 | func (e *transactionFactoryImpl) NewTransaction(shardID int32) executor.Transaction { 19 | return &transactionImpl{ 20 | client: e.client, 21 | shardID: shardID, 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/rows/helper.go: -------------------------------------------------------------------------------- 1 | package rows 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/ydb-platform/ydb-go-sdk/v3/table/types" 7 | commonpb "go.temporal.io/api/common/v1" 8 | "go.temporal.io/server/common/primitives" 9 | ) 10 | 11 | const SlowDeleteBatchSize = 10000 12 | 13 | var NumHistoryShards = 1024 14 | 15 | func ToShardIDColumnValue(shardID int32) uint32 { 16 | // Uniformly spread shard across (0, math.MaxUint32) interval 17 | step := math.MaxUint32/NumHistoryShards - 1 18 | return uint32(shardID * int32(step)) 19 | } 20 | 21 | func structNullValue(name string, t types.Type) types.StructValueOption { 22 | return types.StructFieldValue(name, types.NullValue(t)) 23 | } 24 | 25 | func (f *transactionImpl) getNullStructFieldValues() map[string]types.StructValueOption { 26 | return map[string]types.StructValueOption{ 27 | "namespace_id": structNullValue("namespace_id", f.client.NamespaceIDType()), 28 | "workflow_id": structNullValue("workflow_id", types.TypeUTF8), 29 | "run_id": structNullValue("run_id", f.client.RunIDType()), 30 | "task_category_id": structNullValue("task_category_id", types.TypeInt32), 31 | "task_id": structNullValue("task_id", types.TypeInt64), 32 | "event_type": structNullValue("event_type", types.TypeInt32), 33 | "event_id": structNullValue("event_id", types.TypeInt64), 34 | "event_name": structNullValue("event_name", types.TypeUTF8), 35 | "task_visibility_ts": structNullValue("task_visibility_ts", types.TypeTimestamp), 36 | "data": structNullValue("data", types.TypeBytes), 37 | "data_encoding": structNullValue("data_encoding", f.client.EncodingType()), 38 | 
"execution": structNullValue("execution", types.TypeBytes), 39 | "execution_encoding": structNullValue("execution_encoding", f.client.EncodingType()), 40 | "execution_state": structNullValue("execution_state", types.TypeBytes), 41 | "execution_state_encoding": structNullValue("execution_state_encoding", f.client.EncodingType()), 42 | "checksum": structNullValue("checksum", types.TypeBytes), 43 | "checksum_encoding": structNullValue("checksum_encoding", f.client.EncodingType()), 44 | "next_event_id": structNullValue("next_event_id", types.TypeInt64), 45 | "db_record_version": structNullValue("db_record_version", types.TypeInt64), 46 | "current_run_id": structNullValue("current_run_id", f.client.RunIDType()), 47 | "last_write_version": structNullValue("last_write_version", types.TypeInt64), 48 | "state": structNullValue("state", types.TypeInt32), 49 | 50 | "range_id": structNullValue("range_id", types.TypeInt64), 51 | "shard": structNullValue("shard", types.TypeBytes), 52 | "shard_encoding": structNullValue("shard_encoding", f.client.EncodingType()), 53 | } 54 | } 55 | 56 | func (f *transactionImpl) createExecutionsTableRow(shardID int32, fields map[string]types.Value) types.Value { 57 | rv := make([]types.StructValueOption, 0, len(fields)+1) 58 | rv = append(rv, types.StructFieldValue("shard_id", types.Uint32Value(ToShardIDColumnValue(shardID)))) 59 | for k, nullValue := range f.getNullStructFieldValues() { 60 | if value, ok := fields[k]; ok { 61 | rv = append(rv, types.StructFieldValue(k, types.OptionalValue(value))) 62 | } else { 63 | rv = append(rv, nullValue) 64 | } 65 | } 66 | return types.StructValue(rv...) 67 | } 68 | 69 | func createStructValue(fields map[string]types.Value) types.Value { 70 | rv := make([]types.StructValueOption, 0, len(fields)) 71 | for k, v := range fields { 72 | rv = append(rv, types.StructFieldValue(k, v)) 73 | } 74 | return types.StructValue(rv...) 
75 | } 76 | 77 | func (f *transactionImpl) createWorkflowExecutionRow( 78 | namespaceID primitives.UUID, 79 | workflowID string, 80 | runID primitives.UUID, 81 | executionInfoBlob *commonpb.DataBlob, 82 | executionStateBlob *commonpb.DataBlob, 83 | nextEventID int64, 84 | dbRecordVersion int64, 85 | checksumBlob *commonpb.DataBlob, 86 | ) types.Value { 87 | return f.createExecutionsTableRow(f.shardID, map[string]types.Value{ 88 | "namespace_id": f.client.NamespaceIDValueFromUUID(namespaceID), 89 | "workflow_id": types.UTF8Value(workflowID), 90 | "run_id": f.client.RunIDValueFromUUID(runID), 91 | "execution": types.BytesValue(executionInfoBlob.Data), 92 | "execution_encoding": f.client.EncodingTypeValue(executionInfoBlob.EncodingType), 93 | "execution_state": types.BytesValue(executionStateBlob.Data), 94 | "execution_state_encoding": f.client.EncodingTypeValue(executionStateBlob.EncodingType), 95 | "checksum": types.BytesValue(checksumBlob.Data), 96 | "checksum_encoding": f.client.EncodingTypeValue(checksumBlob.EncodingType), 97 | "next_event_id": types.Int64Value(nextEventID), 98 | "db_record_version": types.Int64Value(dbRecordVersion), 99 | }) 100 | } 101 | -------------------------------------------------------------------------------- /persistence/pkg/ydb/shard_store.go: -------------------------------------------------------------------------------- 1 | package ydb 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | 8 | ydb "github.com/ydb-platform/ydb-go-sdk/v3" 9 | "github.com/ydb-platform/ydb-go-sdk/v3/table" 10 | "github.com/ydb-platform/ydb-go-sdk/v3/table/result/named" 11 | "github.com/ydb-platform/ydb-go-sdk/v3/table/types" 12 | commonpb "go.temporal.io/api/common/v1" 13 | enumspb "go.temporal.io/api/enums/v1" 14 | "go.temporal.io/server/common/log" 15 | p "go.temporal.io/server/common/persistence" 16 | 17 | "github.com/yandex/temporal-over-ydb/persistence/pkg/base/executor" 18 | "github.com/yandex/temporal-over-ydb/persistence/pkg/ydb/conn" 19 | 
"github.com/yandex/temporal-over-ydb/persistence/pkg/ydb/rows" 20 | ) 21 | 22 | type ( 23 | ShardStore struct { 24 | clusterName string 25 | client *conn.Client 26 | logger log.Logger 27 | tf executor.TransactionFactory 28 | } 29 | ) 30 | 31 | func NewShardStore( 32 | clusterName string, 33 | client *conn.Client, 34 | logger log.Logger, 35 | ) *ShardStore { 36 | tf := rows.NewTransactionFactory(client) 37 | return &ShardStore{ 38 | clusterName: clusterName, 39 | client: client, 40 | logger: logger, 41 | tf: tf, 42 | } 43 | } 44 | 45 | func (d *ShardStore) GetOrCreateShard( 46 | ctx context.Context, 47 | request *p.InternalGetOrCreateShardRequest, 48 | ) (resp *p.InternalGetOrCreateShardResponse, err error) { 49 | shardID := request.ShardID 50 | 51 | defer func() { 52 | if err != nil { 53 | details := fmt.Sprintf("shard_id: %v", shardID) 54 | err = conn.ConvertToTemporalError("GetOrCreateShard", err, details) 55 | } 56 | }() 57 | 58 | shardInfo, err := d.getShard(ctx, shardID) 59 | if err != nil { 60 | return nil, err 61 | } 62 | if shardInfo != nil { 63 | resp = &p.InternalGetOrCreateShardResponse{ 64 | ShardInfo: shardInfo, 65 | } 66 | return resp, nil 67 | } 68 | if request.CreateShardInfo == nil { 69 | return nil, errors.New("shard not found and CreateShardInfo is nil") 70 | } 71 | 72 | // shard was not found and we should create it 73 | rangeID, shardInfo, err := request.CreateShardInfo() 74 | if err != nil { 75 | return nil, err 76 | } 77 | template := d.client.AddQueryPrefix(` 78 | DECLARE $shard_id AS uint32; 79 | DECLARE $range_id AS int64; 80 | DECLARE $shard AS string; 81 | DECLARE $shard_encoding AS ` + d.client.EncodingType().String() + `; 82 | 83 | INSERT INTO executions (shard_id, namespace_id, workflow_id, run_id, task_id, task_category_id, task_visibility_ts, event_type, event_id, event_name, range_id, shard, shard_encoding) 84 | VALUES ($shard_id, "", "", "", NULL, NULL, NULL, NULL, NULL, NULL, $range_id, $shard, $shard_encoding); 85 | `) 86 | err 
= d.client.Write(ctx, template, table.NewQueryParameters( 87 | table.ValueParam("$shard_id", types.Uint32Value(rows.ToShardIDColumnValue(request.ShardID))), 88 | table.ValueParam("$shard", types.BytesValue(shardInfo.Data)), 89 | table.ValueParam("$shard_encoding", d.client.EncodingTypeValue(shardInfo.EncodingType)), 90 | table.ValueParam("$range_id", types.Int64Value(rangeID)), 91 | )) 92 | 93 | if ydb.IsOperationErrorAlreadyExistsError(err) { 94 | shardInfo, err = d.getShard(ctx, shardID) 95 | if err != nil { 96 | return nil, err 97 | } 98 | if shardInfo == nil { 99 | return nil, errors.New("couldn't get shard that already exists") 100 | } 101 | } else if err != nil { 102 | return 103 | } 104 | 105 | resp = &p.InternalGetOrCreateShardResponse{ 106 | ShardInfo: shardInfo, 107 | } 108 | return resp, nil 109 | } 110 | 111 | func (d *ShardStore) UpdateShard(ctx context.Context, request *p.InternalUpdateShardRequest) (err error) { 112 | defer func() { 113 | if err != nil { 114 | err = conn.ConvertToTemporalError("UpdateShard", err) 115 | } 116 | }() 117 | 118 | transaction := d.tf.NewTransaction(request.ShardID) 119 | transaction.AssertShard(true, request.PreviousRangeID) 120 | transaction.UpsertShard(request.RangeID, request.ShardInfo) 121 | return transaction.Execute(ctx) 122 | } 123 | 124 | func (d *ShardStore) AssertShardOwnership(ctx context.Context, request *p.AssertShardOwnershipRequest) error { 125 | transaction := d.tf.NewTransaction(request.ShardID) 126 | transaction.AssertShard(false, request.RangeID) 127 | return transaction.Execute(ctx) 128 | } 129 | 130 | func (d *ShardStore) GetName() string { 131 | return ydbPersistenceName 132 | } 133 | 134 | func (d *ShardStore) GetClusterName() string { 135 | return d.clusterName 136 | } 137 | 138 | func (d *ShardStore) Close() { 139 | } 140 | 141 | func (d *ShardStore) getShard(ctx context.Context, shardID int32) (rv *commonpb.DataBlob, err error) { 142 | template := d.client.AddQueryPrefix(` 143 | DECLARE $shard_id 
AS uint32; 144 | 145 | SELECT shard, shard_encoding 146 | FROM executions 147 | WHERE shard_id = $shard_id 148 | AND namespace_id = "" 149 | AND workflow_id = "" 150 | AND run_id = "" 151 | AND task_id IS NULL 152 | AND task_category_id IS NULL 153 | AND task_visibility_ts IS NULL 154 | AND event_type IS NULL 155 | AND event_id IS NULL 156 | AND event_name IS NULL; 157 | `) 158 | res, err := d.client.Do(ctx, template, conn.OnlineReadOnlyTxControl(), table.NewQueryParameters( 159 | table.ValueParam("$shard_id", types.Uint32Value(rows.ToShardIDColumnValue(shardID))), 160 | ), table.WithIdempotent()) 161 | if err != nil { 162 | return nil, err 163 | } 164 | defer func() { 165 | err2 := res.Close() 166 | if err == nil { 167 | err = err2 168 | } 169 | }() 170 | if err = res.NextResultSetErr(ctx); err != nil { 171 | return nil, err 172 | } 173 | if !res.NextRow() { 174 | return nil, nil 175 | } 176 | var data []byte 177 | var encoding string 178 | var encodingType conn.EncodingTypeRaw 179 | var encodingScanner named.Value 180 | if d.client.UseIntForEncoding() { 181 | encodingScanner = named.OptionalWithDefault("shard_encoding", &encodingType) 182 | } else { 183 | encodingScanner = named.OptionalWithDefault("shard_encoding", &encoding) 184 | } 185 | if err = res.ScanNamed( 186 | named.OptionalWithDefault("shard", &data), 187 | encodingScanner, 188 | ); err != nil { 189 | return nil, fmt.Errorf("failed to scan shard: %w", err) 190 | } 191 | if d.client.UseIntForEncoding() { 192 | encoding = enumspb.EncodingType(encodingType).String() 193 | } 194 | return p.NewDataBlob(data, encoding), nil 195 | } 196 | -------------------------------------------------------------------------------- /schema/temporal/0001_initial.sql: -------------------------------------------------------------------------------- 1 | -- +goose Up 2 | CREATE TABLE IF NOT EXISTS tasks_and_task_queues 3 | ( 4 | namespace_id Utf8 NOT NULL, 5 | task_queue_name Utf8 NOT NULL, 6 | task_queue_type Int32 NOT NULL, 7 
| task_id Int64, 8 | expire_at Timestamp, 9 | 10 | range_id Int64, 11 | task String, 12 | task_encoding Utf8, 13 | task_queue String, 14 | task_queue_encoding Utf8, 15 | PRIMARY KEY (namespace_id, task_queue_name, task_queue_type, task_id) 16 | ) WITH ( 17 | TTL = Interval ("PT0S") ON expire_at, 18 | AUTO_PARTITIONING_BY_SIZE = ENABLED, 19 | AUTO_PARTITIONING_BY_LOAD = ENABLED 20 | ); 21 | 22 | -- 23 | -- This table stores the following rows: 24 | -- PK (shard_id, "", "", "", NULL, ...) - shard 25 | -- PK (shard_id, namespace_id, workflow_id, "", ...) - current 26 | -- PK (shard_id, namespace_id, workflow_id, run_id, NULL, ...) - workflow 27 | -- PK (shard_id, namespace_id, workflow_id, run_id, NULL, NULL, NULL, event_type, event_id, event_name, ...) - event 28 | -- PK (shard_id, NULL, NULL, NULL, task_category_id, task_visibility_ts, task_id, NULL, ...) - task 29 | 30 | CREATE TABLE IF NOT EXISTS executions 31 | ( 32 | shard_id Uint32 NOT NULL, 33 | namespace_id Utf8, 34 | workflow_id Utf8, 35 | run_id Utf8, 36 | 37 | task_category_id Int32, 38 | task_visibility_ts Timestamp, 39 | task_id Int64, 40 | 41 | event_type Int32, 42 | event_id Int64, 43 | event_name Utf8, 44 | 45 | execution String, -- workflow 46 | execution_encoding Utf8, -- workflow 47 | execution_state String, -- workflow and current 48 | execution_state_encoding Utf8, -- workflow and current 49 | checksum String, -- workflow 50 | checksum_encoding Utf8, -- workflow 51 | db_record_version Int64, -- workflow 52 | next_event_id Int64, -- workflow 53 | 54 | current_run_id Utf8, -- current 55 | last_write_version Int64, -- current 56 | state Int32, -- current 57 | 58 | range_id Int64, -- shard 59 | shard String, -- shard 60 | shard_encoding Utf8, -- shard 61 | 62 | data String, -- event and task 63 | data_encoding Utf8, -- event and task 64 | 65 | PRIMARY KEY (shard_id, namespace_id, workflow_id, run_id, task_category_id, task_visibility_ts, task_id, event_type, 66 | event_id, event_name) 67 | ) WITH ( 68 
| AUTO_PARTITIONING_BY_SIZE = ENABLED, 69 | AUTO_PARTITIONING_BY_LOAD = ENABLED, 70 | AUTO_PARTITIONING_MIN_PARTITIONS_COUNT = 8, 71 | UNIFORM_PARTITIONS = 8, 72 | AUTO_PARTITIONING_PARTITION_SIZE_MB = 4096, 73 | KEY_BLOOM_FILTER = ENABLED 74 | ); 75 | 76 | CREATE TABLE IF NOT EXISTS history_node 77 | ( 78 | tree_id Utf8 NOT NULL, -- run_id if no reset, otherwise run_id of first run 79 | branch_id Utf8 NOT NULL, -- changes in case of reset workflow. Conflict resolution can also change branch id. 80 | node_id Int64 NOT NULL, -- == first eventID in a batch of events 81 | txn_id Int64 NOT NULL, -- in case of multiple transactions on same node, we utilize highest transaction ID. Unique. 82 | prev_txn_id Int64, -- point to the previous node: event chaining 83 | data String, -- batch of workflow execution history events as a blob 84 | data_encoding Utf8, -- protocol used for history serialization 85 | PRIMARY KEY (tree_id, branch_id, node_id, txn_id) 86 | ) WITH ( 87 | AUTO_PARTITIONING_BY_SIZE = ENABLED, 88 | AUTO_PARTITIONING_BY_LOAD = ENABLED 89 | ); 90 | 91 | CREATE TABLE IF NOT EXISTS history_tree 92 | ( 93 | tree_id Utf8 NOT NULL, 94 | branch_id Utf8 NOT NULL, 95 | branch String, 96 | branch_encoding Utf8, 97 | PRIMARY KEY (tree_id, branch_id) 98 | ) WITH ( 99 | AUTO_PARTITIONING_BY_SIZE = ENABLED, 100 | AUTO_PARTITIONING_BY_LOAD = ENABLED 101 | ); 102 | 103 | CREATE TABLE IF NOT EXISTS replication_tasks 104 | ( 105 | shard_id Uint32, 106 | source_cluster_name Utf8, 107 | task_id Int64, 108 | data String, 109 | data_encoding Utf8, 110 | PRIMARY KEY (shard_id, source_cluster_name, task_id) 111 | ) WITH ( 112 | AUTO_PARTITIONING_BY_SIZE = ENABLED, 113 | AUTO_PARTITIONING_BY_LOAD = ENABLED 114 | ); 115 | 116 | CREATE TABLE IF NOT EXISTS cluster_metadata_info 117 | ( 118 | cluster_name Utf8, 119 | data String, 120 | data_encoding Utf8, 121 | version Int64, 122 | PRIMARY KEY (cluster_name) 123 | ) WITH ( 124 | AUTO_PARTITIONING_BY_SIZE = ENABLED, 125 | 
AUTO_PARTITIONING_BY_LOAD = ENABLED 126 | ); 127 | 128 | CREATE TABLE IF NOT EXISTS queue 129 | ( 130 | queue_type Int32, 131 | message_id Int64, 132 | message_payload String, 133 | message_encoding Utf8, 134 | PRIMARY KEY (queue_type, message_id) 135 | ) WITH ( 136 | AUTO_PARTITIONING_BY_SIZE = ENABLED, 137 | AUTO_PARTITIONING_BY_LOAD = ENABLED 138 | ); 139 | 140 | CREATE TABLE IF NOT EXISTS queue_metadata 141 | ( 142 | queue_type Int32, 143 | data String, 144 | data_encoding Utf8, 145 | version Int64, 146 | PRIMARY KEY (queue_type) 147 | ) WITH ( 148 | AUTO_PARTITIONING_BY_SIZE = ENABLED, 149 | AUTO_PARTITIONING_BY_LOAD = ENABLED 150 | ); 151 | 152 | 153 | -- this table is only used for storage of mapping of namespace uuid to namespace name 154 | CREATE TABLE IF NOT EXISTS namespaces_by_id 155 | ( 156 | id Utf8, 157 | name Utf8, 158 | PRIMARY KEY (id) 159 | ); 160 | 161 | CREATE TABLE IF NOT EXISTS namespaces 162 | ( 163 | name Utf8, 164 | id Utf8, 165 | detail String, 166 | detail_encoding Utf8, 167 | is_global_namespace Bool, 168 | notification_version Int64, 169 | PRIMARY KEY (name) 170 | ); 171 | 172 | CREATE TABLE IF NOT EXISTS cluster_membership 173 | ( 174 | host_id Utf8, 175 | rpc_address Utf8, 176 | rpc_port Int32, 177 | role Int32, 178 | session_start Timestamp, 179 | last_heartbeat Timestamp, 180 | expire_at Timestamp, 181 | PRIMARY KEY (role, host_id) 182 | ) WITH ( 183 | TTL = Interval ("PT0S") ON expire_at, 184 | 185 | AUTO_PARTITIONING_BY_SIZE = ENABLED, 186 | AUTO_PARTITIONING_BY_LOAD = ENABLED 187 | ); 188 | 189 | CREATE TABLE IF NOT EXISTS task_queue_user_data 190 | ( 191 | namespace_id Utf8, 192 | task_queue_name Utf8, 193 | data String, 194 | data_encoding Utf8, 195 | version Int64, 196 | PRIMARY KEY (namespace_id, task_queue_name) 197 | ); 198 | 199 | CREATE TABLE IF NOT EXISTS build_id_to_task_queue 200 | ( 201 | namespace_id Utf8, 202 | build_id Utf8, 203 | task_queue_name Utf8, 204 | PRIMARY KEY (namespace_id, build_id, task_queue_name) 
205 | ); 206 | -------------------------------------------------------------------------------- /schema/temporal/0002_nexus.sql: -------------------------------------------------------------------------------- 1 | -- +goose Up 2 | 3 | -- +goose StatementBegin 4 | CREATE TABLE IF NOT EXISTS nexus_endpoints ( 5 | type Int32, -- enum RowType { PartitionStatus, NexusEndpoint } 6 | id String, 7 | data String, 8 | data_encoding Int16, 9 | -- When type=PartitionStatus contains the partition version. 10 | -- Partition version is used to guarantee latest versions when listing all endpoints. 11 | -- When type=NexusEndpoint contains the endpoint version used for optimistic concurrency. 12 | version Int64, 13 | PRIMARY KEY (type, id) 14 | ) WITH ( 15 | AUTO_PARTITIONING_BY_SIZE = ENABLED, 16 | AUTO_PARTITIONING_BY_LOAD = ENABLED 17 | ); 18 | -- +goose StatementEnd 19 | -------------------------------------------------------------------------------- /schema/temporal/0003_queue_v2.sql: -------------------------------------------------------------------------------- 1 | -- +goose Up 2 | 3 | -- +goose StatementBegin 4 | CREATE TABLE IF NOT EXISTS queue_v2 5 | ( 6 | queue_type Int32, 7 | queue_name Utf8, 8 | metadata_payload String, 9 | metadata_encoding Int16, 10 | version Int64, 11 | PRIMARY KEY (queue_type, queue_name) 12 | ) WITH ( 13 | AUTO_PARTITIONING_BY_SIZE = ENABLED, 14 | AUTO_PARTITIONING_BY_LOAD = ENABLED 15 | ); 16 | -- +goose StatementEnd 17 | 18 | -- +goose StatementBegin 19 | CREATE TABLE IF NOT EXISTS queue_v2_message 20 | ( 21 | queue_type Int32, 22 | queue_name Utf8, 23 | queue_partition Int32, 24 | 25 | message_id Int64, 26 | message_payload String, 27 | message_encoding Int16, 28 | PRIMARY KEY (queue_type, queue_name, queue_partition, message_id) 29 | ) WITH ( 30 | AUTO_PARTITIONING_BY_SIZE = ENABLED, 31 | AUTO_PARTITIONING_BY_LOAD = ENABLED 32 | ); 33 | -- +goose StatementEnd 34 | -------------------------------------------------------------------------------- 
/temporal-over-ydb/cmd/migrator/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | ) 6 | 7 | func main() { 8 | if err := rootCmd.Execute(); err != nil { 9 | log.Fatal(err) 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /temporal-over-ydb/cmd/migrator/ping.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/spf13/cobra" 7 | ) 8 | 9 | func init() { 10 | rootCmd.AddCommand(pingCmd) 11 | } 12 | 13 | var ( 14 | pingCmd = &cobra.Command{ 15 | Use: "ping", 16 | Short: "Check if the database is available", 17 | RunE: func(cmd *cobra.Command, args []string) error { 18 | if err := db.PingContext(cmd.Context()); err != nil { 19 | return err 20 | } 21 | log.Println("Database is reachable.") 22 | return nil 23 | }, 24 | } 25 | ) 26 | -------------------------------------------------------------------------------- /temporal-over-ydb/cmd/migrator/root.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | 7 | "github.com/spf13/cobra" 8 | ydbenv "github.com/ydb-platform/ydb-go-sdk-auth-environ" 9 | "github.com/ydb-platform/ydb-go-sdk/v3" 10 | "github.com/ydb-platform/ydb-go-sdk/v3/sugar" 11 | ) 12 | 13 | func init() { 14 | rootCmd.PersistentFlags().StringVar(&host, "host", "", "host") 15 | if err := rootCmd.MarkPersistentFlagRequired("host"); err != nil { 16 | panic(err) 17 | } 18 | 19 | rootCmd.PersistentFlags().StringVar(&port, "port", "", "port") 20 | if err := rootCmd.MarkPersistentFlagRequired("port"); err != nil { 21 | panic(err) 22 | } 23 | 24 | rootCmd.PersistentFlags().StringVar(&dbName, "db", "", "database name") 25 | if err := rootCmd.MarkPersistentFlagRequired("db"); err != nil { 26 | panic(err) 27 | } 28 | 29 | 
rootCmd.PersistentFlags().StringVar(&prefix, "prefix", "", "table path prefix") 30 | 31 | rootCmd.PersistentFlags().BoolVar(&withSecure, "secure", false, "Use secure connection (default false)") 32 | } 33 | 34 | var ( 35 | host string 36 | port string 37 | dbName string 38 | withSecure bool 39 | prefix string 40 | 41 | db *sql.DB 42 | 43 | rootCmd = &cobra.Command{ 44 | Use: "ydb-migrator", 45 | Short: "A CLI tool for managing YDB migrations", 46 | PersistentPreRunE: func(cmd *cobra.Command, args []string) error { 47 | endpoint := host + ":" + port 48 | driver, err := ydb.Open(cmd.Context(), 49 | sugar.DSN(endpoint, dbName, sugar.WithSecure(withSecure)), 50 | ydbenv.WithEnvironCredentials(context.Background()), 51 | ) 52 | if err != nil { 53 | return err 54 | } 55 | 56 | tablePrefix := dbName + "/" + prefix 57 | connector, err := ydb.Connector(driver, 58 | ydb.WithTablePathPrefix(tablePrefix), 59 | ydb.WithDefaultQueryMode(ydb.ScriptingQueryMode), 60 | ydb.WithFakeTx(ydb.ScriptingQueryMode), 61 | ydb.WithNumericArgs(), 62 | ydb.WithAutoDeclare(), 63 | ) 64 | if err != nil { 65 | return err 66 | } 67 | 68 | db = sql.OpenDB(connector) 69 | 70 | return nil 71 | }, 72 | } 73 | ) 74 | -------------------------------------------------------------------------------- /temporal-over-ydb/cmd/migrator/up.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/pressly/goose/v3" 7 | "github.com/spf13/cobra" 8 | ) 9 | 10 | func init() { 11 | rootCmd.AddCommand(upCmd) 12 | 13 | upCmd.Flags().StringVar(&versionTable, "version-table", "", "Version table name") 14 | if err := upCmd.MarkFlagRequired("version-table"); err != nil { 15 | panic(err) 16 | } 17 | 18 | upCmd.Flags().StringVar(&schemaDir, "schema-dir", "", "Schema dir") 19 | if err := upCmd.MarkFlagRequired("schema-dir"); err != nil { 20 | panic(err) 21 | } 22 | } 23 | 24 | var ( 25 | versionTable string 26 | schemaDir string 27 | 28 
| upCmd = &cobra.Command{ 29 | Use: "up", 30 | Short: "Apply all up migrations", 31 | RunE: func(cmd *cobra.Command, args []string) error { 32 | goose.SetTableName(versionTable) 33 | 34 | if err := goose.SetDialect("ydb"); err != nil { 35 | return err 36 | } 37 | 38 | if err := goose.Up(db, schemaDir); err != nil { 39 | return err 40 | } 41 | 42 | log.Println("Migrations applied successfully.") 43 | return nil 44 | }, 45 | } 46 | ) 47 | -------------------------------------------------------------------------------- /temporal-over-ydb/cmd/server/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path" 7 | "strings" 8 | 9 | "github.com/urfave/cli/v2" 10 | "go.temporal.io/server/common/authorization" 11 | "go.temporal.io/server/common/build" 12 | "go.temporal.io/server/common/config" 13 | "go.temporal.io/server/common/debug" 14 | "go.temporal.io/server/common/dynamicconfig" 15 | "go.temporal.io/server/common/headers" 16 | "go.temporal.io/server/common/log" 17 | "go.temporal.io/server/common/log/tag" 18 | "go.temporal.io/server/temporal" 19 | 20 | "github.com/yandex/temporal-over-ydb/persistence/pkg/ydb" 21 | ydbrows "github.com/yandex/temporal-over-ydb/persistence/pkg/ydb/rows" 22 | ) 23 | 24 | func main() { 25 | app := buildCLI() 26 | _ = app.Run(os.Args) 27 | } 28 | 29 | func buildCLI() *cli.App { 30 | app := cli.NewApp() 31 | app.Name = "temporal" 32 | app.Usage = "Temporal over YDB server" 33 | app.Version = headers.ServerVersion 34 | app.ArgsUsage = " " 35 | app.Flags = []cli.Flag{ 36 | &cli.StringFlag{ 37 | Name: "root", 38 | Aliases: []string{"r"}, 39 | Value: ".", 40 | Usage: "root directory of execution environment", 41 | EnvVars: []string{config.EnvKeyRoot}, 42 | }, 43 | &cli.StringFlag{ 44 | Name: "config", 45 | Aliases: []string{"c"}, 46 | Value: "config", 47 | Usage: "config dir path relative to root", 48 | EnvVars: []string{config.EnvKeyConfigDir}, 49 | }, 
50 | &cli.StringFlag{ 51 | Name: "env", 52 | Aliases: []string{"e"}, 53 | Value: "development", 54 | Usage: "runtime environment", 55 | EnvVars: []string{config.EnvKeyEnvironment}, 56 | }, 57 | &cli.StringFlag{ 58 | Name: "zone", 59 | Aliases: []string{"az"}, 60 | Usage: "availability zone", 61 | EnvVars: []string{config.EnvKeyAvailabilityZone, config.EnvKeyAvailabilityZoneTypo}, 62 | }, 63 | } 64 | 65 | app.Commands = []*cli.Command{ 66 | { 67 | Name: "start", 68 | Usage: "Start Temporal server", 69 | ArgsUsage: " ", 70 | Flags: []cli.Flag{ 71 | &cli.StringSliceFlag{ 72 | Name: "service", 73 | Aliases: []string{"svc"}, 74 | Value: cli.NewStringSlice(temporal.Services...), 75 | Usage: "service(s) to start", 76 | }, &cli.BoolFlag{Name: "allow-no-auth", Usage: "allow the server to start without an authorizer"}, 77 | }, 78 | Action: func(c *cli.Context) error { 79 | env := c.String("env") 80 | zone := c.String("zone") 81 | configDir := path.Join(c.String("root"), c.String("config")) 82 | services := c.StringSlice("service") 83 | allowNoAuth := c.Bool("allow-no-auth") 84 | 85 | // For backward compatibility to support old flag format (i.e. `--services=frontend,history,matching`). 86 | if c.IsSet("services") { 87 | fmt.Println("WARNING: --services flag is deprecated. 
Specify multiple --service flags instead.") 88 | services = strings.Split(c.String("services"), ",") 89 | } 90 | 91 | cfg, err := config.LoadConfig(env, configDir, zone) 92 | if err != nil { 93 | return cli.Exit(fmt.Sprintf("Unable to load configuration: %v.", err), 1) 94 | } 95 | 96 | // XXX: propagate the configured history shard count to the rows package via package-level state. 97 | ydbrows.NumHistoryShards = int(cfg.Persistence.NumHistoryShards) 98 | 99 | logger := log.NewZapLogger(log.BuildZapLogger(cfg.Log)) 100 | logger.Info("Build info.", 101 | tag.NewTimeTag("git-time", build.InfoData.GitTime), 102 | tag.NewStringTag("git-revision", build.InfoData.GitRevision), 103 | tag.NewBoolTag("git-modified", build.InfoData.GitModified), 104 | tag.NewStringTag("go-arch", build.InfoData.GoArch), 105 | tag.NewStringTag("go-os", build.InfoData.GoOs), 106 | tag.NewStringTag("go-version", build.InfoData.GoVersion), 107 | tag.NewBoolTag("cgo-enabled", build.InfoData.CgoEnabled), 108 | tag.NewStringTag("server-version", headers.ServerVersion), 109 | tag.NewBoolTag("debug-mode", debug.Enabled), 110 | ) 111 | 112 | var dynamicConfigClient dynamicconfig.Client 113 | if cfg.DynamicConfigClient != nil { 114 | dynamicConfigClient, err = dynamicconfig.NewFileBasedClient(cfg.DynamicConfigClient, logger, temporal.InterruptCh()) 115 | if err != nil { 116 | return cli.Exit(fmt.Sprintf("Unable to create dynamic config client. Error: %v", err), 1) 117 | } 118 | } else { 119 | dynamicConfigClient = dynamicconfig.NewNoopClient() 120 | logger.Info("Dynamic config client is not configured. Using noop client.") 121 | } 122 | 123 | authorizer, err := authorization.GetAuthorizerFromConfig( 124 | &cfg.Global.Authorization, 125 | ) 126 | if err != nil { 127 | return cli.Exit(fmt.Sprintf("Unable to instantiate authorizer. Error: %v", err), 1) 128 | } 129 | if authorization.IsNoopAuthorizer(authorizer) && !allowNoAuth { 130 | logger.Warn( 131 | "Not using any authorizer and flag `--allow-no-auth` not detected. 
" + 132 | "Future versions will require using the flag `--allow-no-auth` " + 133 | "if you do not want to set an authorizer.", 134 | ) 135 | } 136 | 137 | claimMapper, err := authorization.GetClaimMapperFromConfig(&cfg.Global.Authorization, logger) 138 | if err != nil { 139 | return cli.Exit(fmt.Sprintf("Unable to instantiate claim mapper: %v.", err), 1) 140 | } 141 | 142 | s, err := temporal.NewServer( 143 | temporal.ForServices(services), 144 | temporal.WithConfig(cfg), 145 | temporal.WithDynamicConfigClient(dynamicConfigClient), 146 | temporal.WithLogger(logger), 147 | temporal.InterruptOn(temporal.InterruptCh()), 148 | temporal.WithAuthorizer(authorizer), 149 | temporal.WithClaimMapper(func(cfg *config.Config) authorization.ClaimMapper { 150 | return claimMapper 151 | }), 152 | temporal.WithCustomDataStoreFactory(ydb.NewYDBAbstractDataStoreFactory()), 153 | ) 154 | if err != nil { 155 | return cli.Exit(fmt.Sprintf("Unable to create server. Error: %v.", err), 1) 156 | } 157 | 158 | err = s.Start() 159 | if err != nil { 160 | return cli.Exit(fmt.Sprintf("Unable to start server. Error: %v", err), 1) 161 | } 162 | return cli.Exit("All services are stopped.", 0) 163 | }, 164 | }, 165 | } 166 | return app 167 | } 168 | --------------------------------------------------------------------------------