├── .gitignore
├── committer.png
├── io
├── gateway
│ └── grpc
│ │ ├── client
│ │ ├── common.go
│ │ ├── client_api.go
│ │ └── internal_client.go
│ │ ├── server
│ │ ├── dto.go
│ │ ├── interceptors.go
│ │ └── server.go
│ │ └── proto
│ │ ├── schema.proto
│ │ └── schema_grpc.pb.go
└── store
│ ├── store_test.go
│ └── store.go
├── examples
└── client
│ └── client.go
├── core
├── cohort
│ ├── commitalgo
│ │ ├── hooks
│ │ │ ├── hooks.go
│ │ │ ├── registry.go
│ │ │ ├── registry_test.go
│ │ │ └── examples.go
│ │ ├── fsm.go
│ │ ├── commitalgo.go
│ │ └── committer_test.go
│ ├── cohort.go
│ └── cohort_test.go
├── dto
│ └── dto.go
├── walrecord
│ └── walrecord.go
└── coordinator
│ ├── coordinator_test.go
│ └── coordinator.go
├── .github
└── workflows
│ └── tests.yml
├── Makefile
├── go.mod
├── mocks
├── mock_state_store.go
├── mock_commitalgo_state_store.go
├── mock_wal.go
├── mock_cohort.go
└── mock_committer.go
├── config
└── config.go
├── main.go
├── README.md
├── toxiproxy_test.go
├── go.sum
├── main_test.go
├── LICENSE
└── chaos_test.go
/.gitignore:
--------------------------------------------------------------------------------
1 | /.idea/
--------------------------------------------------------------------------------
/committer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vadiminshakov/committer/HEAD/committer.png
--------------------------------------------------------------------------------
/io/gateway/grpc/client/common.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "context"
5 | "time"
6 |
7 | "github.com/pkg/errors"
8 | "google.golang.org/grpc"
9 | "google.golang.org/grpc/backoff"
10 | "google.golang.org/grpc/credentials/insecure"
11 | )
12 |
13 | func createConnection(addr string) (*grpc.ClientConn, error) {
14 | connParams := grpc.ConnectParams{
15 | Backoff: backoff.Config{
16 | BaseDelay: 100 * time.Millisecond,
17 | MaxDelay: 10 * time.Second,
18 | },
19 | MinConnectTimeout: 200 * time.Millisecond,
20 | }
21 |
22 | ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
23 | conn, err := grpc.DialContext(ctx, addr, grpc.WithConnectParams(connParams), grpc.WithTransportCredentials(insecure.NewCredentials()))
24 | if err != nil {
25 | return nil, errors.Wrap(err, "failed to connect")
26 | }
27 |
28 | return conn, nil
29 | }
30 |
--------------------------------------------------------------------------------
/io/gateway/grpc/server/dto.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "github.com/vadiminshakov/committer/core/dto"
5 | "github.com/vadiminshakov/committer/io/gateway/grpc/proto"
6 | )
7 |
8 | func proposeRequestPbToEntity(request *proto.ProposeRequest) *dto.ProposeRequest {
9 | if request == nil {
10 | return nil
11 | }
12 |
13 | return &dto.ProposeRequest{
14 | Key: request.Key,
15 | Value: request.Value,
16 | Height: request.Index,
17 | }
18 | }
19 |
20 | func commitRequestPbToEntity(request *proto.CommitRequest) *dto.CommitRequest {
21 | if request == nil {
22 | return nil
23 | }
24 |
25 | return &dto.CommitRequest{
26 | Height: request.Index,
27 | IsRollback: request.IsRollback,
28 | }
29 | }
30 |
31 | func cohortResponseToProto(e *dto.CohortResponse) *proto.Response {
32 | if e == nil {
33 | return nil
34 | }
35 | return &proto.Response{
36 | Type: proto.Type(e.ResponseType),
37 | Index: e.Height,
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
/examples/client/client.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "github.com/vadiminshakov/committer/io/gateway/grpc/client"
7 | pb "github.com/vadiminshakov/committer/io/gateway/grpc/proto"
8 | "strconv"
9 | )
10 |
11 | const coordinatorAddr = "0.0.0.0:3000"
12 |
13 | func main() {
14 | key, value := "somekey", "somevalue"
15 |
16 | // create a client for interaction with coordinator
17 | cli, err := client.NewClientAPI(coordinatorAddr)
18 | if err != nil {
19 | panic(err)
20 | }
21 |
22 | for i := 0; i < 5; i++ {
23 | // put a key-value pair
24 | resp, err := cli.Put(context.Background(), key+strconv.Itoa(i), []byte(value+strconv.Itoa(i)))
25 | if err != nil {
26 | panic(err)
27 | }
28 | if resp.Type != pb.Type_ACK {
29 | panic("msg is not acknowledged")
30 | }
31 |
32 | // read committed key
33 | v, err := cli.Get(context.Background(), key+strconv.Itoa(i))
34 | if err != nil {
35 | panic(err)
36 | }
37 | fmt.Printf("got value for key '%s': %s\n", key+strconv.Itoa(i), string(v.Value)+strconv.Itoa(i))
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
/core/cohort/commitalgo/hooks/hooks.go:
--------------------------------------------------------------------------------
1 | // Package hooks provides an extensible hook system for commit algorithms.
2 | //
3 | // Hooks allow custom validation, metrics collection, and business logic
4 | // to be executed during propose and commit phases without modifying core logic.
5 | package hooks
6 |
7 | import (
8 | log "github.com/sirupsen/logrus"
9 | "github.com/vadiminshakov/committer/core/dto"
10 | )
11 |
12 | // DefaultHook provides the default logging behavior
13 | type DefaultHook struct{}
14 |
15 | // NewDefaultHook creates a new default hook instance
16 | func NewDefaultHook() *DefaultHook {
17 | return &DefaultHook{}
18 | }
19 |
20 | // OnPropose implements the Hook interface for propose operations
21 | func (h *DefaultHook) OnPropose(req *dto.ProposeRequest) bool {
22 | log.Infof("propose hook on height %d is OK", req.Height)
23 | return true
24 | }
25 |
26 | // OnCommit implements the Hook interface for commit operations
27 | func (h *DefaultHook) OnCommit(req *dto.CommitRequest) bool {
28 | log.Infof("commit hook on height %d is OK", req.Height)
29 | return true
30 | }
31 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | # This is a basic workflow to help you get started with Actions
2 |
3 | name: tests
4 |
5 | # Controls when the action will run. Triggers the workflow on push or pull request
6 | # events but only for the master branch
7 | on:
8 | push:
9 | branches: [ master ]
10 | pull_request:
11 | branches: [ master ]
12 |
13 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel
14 | jobs:
15 | # This workflow contains a single job called "run-tests"
16 | run-tests:
17 | # The type of runner that the job will run on
18 | runs-on: ubuntu-latest
19 |
20 | # Steps represent a sequence of tasks that will be executed as part of the job
21 | steps:
22 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
23 | - uses: actions/checkout@v2
24 |
25 | # Set up Go
26 | - name: Set up Go
27 | uses: actions/setup-go@v4
28 | with:
29 | go-version: '1.25'
30 |
31 | # Runs a single command using the runner's shell
32 | - name: Run tests
33 | run: make prepare && make tests
34 |
--------------------------------------------------------------------------------
/io/gateway/grpc/server/interceptors.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "context"
5 | "net"
6 |
7 | "google.golang.org/grpc"
8 | codes "google.golang.org/grpc/codes"
9 | "google.golang.org/grpc/peer"
10 | status "google.golang.org/grpc/status"
11 | )
12 |
13 | // WhiteListChecker intercepts RPC and checks that the caller is whitelisted.
14 | func WhiteListChecker(ctx context.Context,
15 | req interface{},
16 | info *grpc.UnaryServerInfo,
17 | handler grpc.UnaryHandler) (interface{}, error) {
18 | peerinfo, ok := peer.FromContext(ctx)
19 | if !ok {
20 | return nil, status.Errorf(codes.Internal, "failed to retrieve peer info")
21 | }
22 |
23 | host, _, err := net.SplitHostPort(peerinfo.Addr.String())
24 | if err != nil {
25 | return nil, status.Errorf(codes.Internal, err.Error())
26 | }
27 |
28 | serv := info.Server.(*Server)
29 | if !includes(serv.Config.Whitelist, host) {
30 | return nil, status.Errorf(codes.PermissionDenied, "host %s is not in whitelist", host)
31 | }
32 |
33 | // calls the handler
34 | h, err := handler(ctx, req)
35 |
36 | return h, err
37 | }
38 |
// includes reports whether value is an element of arr (exact string match).
// A nil or empty arr never matches.
func includes(arr []string, value string) bool {
	for _, candidate := range arr {
		if candidate == value {
			return true
		}
	}
	return false
}
--------------------------------------------------------------------------------
/core/cohort/commitalgo/hooks/registry.go:
--------------------------------------------------------------------------------
1 | package hooks
2 |
3 | import (
4 | "github.com/vadiminshakov/committer/core/dto"
5 | )
6 |
7 | // Hook defines the interface for commit algorithm hooks.
8 | type Hook interface {
9 | OnPropose(req *dto.ProposeRequest) bool
10 | OnCommit(req *dto.CommitRequest) bool
11 | }
12 |
13 | // Registry manages a collection of hooks.
14 | type Registry struct {
15 | hooks []Hook
16 | }
17 |
18 | // NewRegistry creates a new hook registry.
19 | func NewRegistry() *Registry {
20 | return &Registry{
21 | hooks: make([]Hook, 0),
22 | }
23 | }
24 |
25 | // Register adds a new hook to the registry.
26 | func (r *Registry) Register(hook Hook) {
27 | r.hooks = append(r.hooks, hook)
28 | }
29 |
30 | // ExecutePropose runs all registered propose hooks.
31 | // Returns false if any hook returns false.
32 | func (r *Registry) ExecutePropose(req *dto.ProposeRequest) bool {
33 | for _, hook := range r.hooks {
34 | if !hook.OnPropose(req) {
35 | return false
36 | }
37 | }
38 | return true
39 | }
40 |
41 | // ExecuteCommit runs all registered commit hooks.
42 | // Returns false if any hook returns false.
43 | func (r *Registry) ExecuteCommit(req *dto.CommitRequest) bool {
44 | for _, hook := range r.hooks {
45 | if !hook.OnCommit(req) {
46 | return false
47 | }
48 | }
49 | return true
50 | }
51 |
52 | // Count returns the number of registered hooks
53 | func (r *Registry) Count() int {
54 | return len(r.hooks)
55 | }
56 |
--------------------------------------------------------------------------------
/io/gateway/grpc/proto/schema.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 | package schema;
3 | import "google/protobuf/empty.proto";
4 | option go_package = "github.com/vadiminshakov/committer/proto";
5 |
6 | service InternalCommitAPI {
7 | rpc Propose(ProposeRequest) returns (Response);
8 | rpc Precommit(PrecommitRequest) returns (Response);
9 | rpc Commit(CommitRequest) returns (Response);
10 | rpc Abort(AbortRequest) returns (Response);
11 | }
12 |
13 | service ClientAPI {
14 | rpc Put(Entry) returns (Response);
15 | rpc Get(Msg) returns (Value);
16 | rpc NodeInfo(google.protobuf.Empty) returns (Info);
17 | }
18 |
19 | message ProposeRequest {
20 | string Key = 1;
21 | bytes Value = 2;
22 | CommitType CommitType = 3;
23 | uint64 index = 4;
24 | }
25 |
26 | enum CommitType {
27 | TWO_PHASE_COMMIT = 0;
28 | THREE_PHASE_COMMIT = 1;
29 | }
30 |
31 | enum Type {
32 | ACK = 0;
33 | NACK = 1;
34 | }
35 |
36 | message Response {
37 | Type type = 1;
38 | uint64 index = 2;
39 | }
40 |
41 | message PrecommitRequest {
42 | uint64 index = 1;
43 | }
44 |
45 | message CommitRequest {
46 | uint64 index = 1;
47 | bool isRollback = 2;
48 | }
49 |
50 | message Entry {
51 | string key = 1;
52 | bytes value = 2;
53 | }
54 |
55 | message Msg {
56 | string key = 1;
57 | }
58 |
59 | message Value {
60 | bytes value = 1;
61 | }
62 |
63 | message Info {
64 | uint64 height = 1;
65 | }
66 |
67 | message AbortRequest {
68 | uint64 height = 1;
69 | string reason = 2;
70 | }
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | prepare:
2 | @rm -rf ./badger
3 | @mkdir ./badger
4 | @mkdir ./badger/coordinator
5 | @mkdir ./badger/cohort
6 |
7 | run-example-coordinator:
8 | @rm -rf ./badger/coordinator
9 | @go run . -role=coordinator -nodeaddr=localhost:3000 -cohorts=localhost:3001 -committype=three-phase -timeout=1000 -dbpath=./badger/coordinator -whitelist=127.0.0.1
10 |
11 | run-example-cohort:
12 | @rm -rf ./badger/cohort
13 | @go run . -role=cohort -coordinator=localhost:3000 -nodeaddr=localhost:3001 -committype=three-phase -timeout=1000 -dbpath=./badger/cohort -whitelist=127.0.0.1
14 |
15 | run-example-client:
16 | @go run ./examples/client
17 |
18 | tests:
19 | @go test ./...
20 |
21 | start-toxiproxy:
22 | @echo "Starting Toxiproxy server..."
23 | @pkill toxiproxy-server || true
24 | @toxiproxy-server > /dev/null 2>&1 &
25 | @echo "Toxiproxy server started in background"
26 |
27 | stop-toxiproxy:
28 | @echo "Stopping Toxiproxy server..."
29 | @pkill toxiproxy-server || true
30 |
31 | test-chaos: start-toxiproxy
32 | @echo "Waiting for Toxiproxy to start..."
33 | @sleep 2
34 | @echo "Running chaos tests..."
35 | @go test -v -tags=chaos -run "TestChaos" ./...
36 | @$(MAKE) stop-toxiproxy
37 |
38 | proto-gen:
39 | @echo "Generating proto files..."
40 | @protoc --go_out=io/gateway/grpc/proto --go_opt=paths=source_relative \
41 | --go-grpc_out=io/gateway/grpc/proto --go-grpc_opt=paths=source_relative \
42 | --proto_path=io/gateway/grpc/proto io/gateway/grpc/proto/schema.proto
43 | @echo "Proto files generated successfully"
44 |
45 | generate: proto-gen
46 | @echo "Generating mocks..."
47 | @go generate ./...
48 | @echo "All files generated successfully"
--------------------------------------------------------------------------------
/io/gateway/grpc/client/client_api.go:
--------------------------------------------------------------------------------
1 | // Package client provides gRPC client implementations for communicating with committer nodes.
2 | //
3 | // This package contains both internal client for node-to-node communication
4 | // and external client API for application integration.
5 | package client
6 |
7 | import (
8 | "context"
9 |
10 | "github.com/vadiminshakov/committer/io/gateway/grpc/proto"
11 | "google.golang.org/protobuf/types/known/emptypb"
12 | )
13 |
14 | // ClientAPIClient provides access to the client API.
15 | type ClientAPIClient struct {
16 | Connection proto.ClientAPIClient
17 | }
18 |
19 | // NewClientAPI creates an instance of the client API client.
20 | // The addr parameter should be the network address of the coordinator (host + port).
21 | func NewClientAPI(addr string) (*ClientAPIClient, error) {
22 | conn, err := createConnection(addr)
23 | if err != nil {
24 | return nil, err
25 | }
26 | return &ClientAPIClient{Connection: proto.NewClientAPIClient(conn)}, nil
27 | }
28 |
29 | // Put sends a put request to the client API.
30 | func (client *ClientAPIClient) Put(ctx context.Context, key string, value []byte) (*proto.Response, error) {
31 | return client.Connection.Put(ctx, &proto.Entry{Key: key, Value: value})
32 | }
33 |
34 | // NodeInfo gets the current height of the node.
35 | func (client *ClientAPIClient) NodeInfo(ctx context.Context) (*proto.Info, error) {
36 | return client.Connection.NodeInfo(ctx, &emptypb.Empty{})
37 | }
38 |
39 | // Get gets the value by the specified key.
40 | func (client *ClientAPIClient) Get(ctx context.Context, key string) (*proto.Value, error) {
41 | return client.Connection.Get(ctx, &proto.Msg{Key: key})
42 | }
43 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/vadiminshakov/committer
2 |
3 | go 1.23.0
4 |
5 | require (
6 | github.com/dgraph-io/badger/v4 v4.8.0
7 | github.com/pkg/errors v0.9.1
8 | github.com/sirupsen/logrus v1.2.0
9 | github.com/stretchr/testify v1.10.0
10 | github.com/vadiminshakov/gowal v0.0.4
11 | go.uber.org/mock v0.6.0
12 | google.golang.org/grpc v1.50.0
13 | google.golang.org/protobuf v1.36.6
14 | )
15 |
16 | require (
17 | github.com/cespare/xxhash/v2 v2.3.0 // indirect
18 | github.com/davecgh/go-spew v1.1.1 // indirect
19 | github.com/dgraph-io/ristretto/v2 v2.2.0 // indirect
20 | github.com/dustin/go-humanize v1.0.1 // indirect
21 | github.com/go-logr/logr v1.4.3 // indirect
22 | github.com/go-logr/stdr v1.2.2 // indirect
23 | github.com/golang/protobuf v1.5.2 // indirect
24 | github.com/google/flatbuffers v25.2.10+incompatible // indirect
25 | github.com/klauspost/compress v1.18.0 // indirect
26 | github.com/konsorten/go-windows-terminal-sequences v1.0.1 // indirect
27 | github.com/pmezard/go-difflib v1.0.0 // indirect
28 | github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
29 | github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
30 | go.opentelemetry.io/auto/sdk v1.1.0 // indirect
31 | go.opentelemetry.io/otel v1.37.0 // indirect
32 | go.opentelemetry.io/otel/metric v1.37.0 // indirect
33 | go.opentelemetry.io/otel/trace v1.37.0 // indirect
34 | golang.org/x/crypto v0.39.0 // indirect
35 | golang.org/x/net v0.41.0 // indirect
36 | golang.org/x/sys v0.34.0 // indirect
37 | golang.org/x/term v0.32.0 // indirect
38 | golang.org/x/text v0.26.0 // indirect
39 | google.golang.org/genproto v0.0.0-20221010155953-15ba04fc1c0e // indirect
40 | gopkg.in/yaml.v3 v3.0.1 // indirect
41 | )
42 |
--------------------------------------------------------------------------------
/core/dto/dto.go:
--------------------------------------------------------------------------------
1 | // Package dto defines the data transfer objects used for communication in
2 | // distributed atomic commit protocols between coordinators and cohorts.
3 | // These structures represent the messages exchanged during the two-phase
4 | // and three-phase commit protocols.
5 | package dto
6 |
// ProposeRequest represents a proposal for a new transaction.
//
// When it arrives over gRPC, Height is filled from the protobuf
// ProposeRequest "index" field.
type ProposeRequest struct {
	Key    string // Key to be stored
	Value  []byte // Value to be stored
	Height uint64 // Transaction height/sequence number
}

// CommitRequest represents a commit or rollback request.
//
// When it arrives over gRPC, Height is filled from the protobuf CommitRequest
// "index" field.
type CommitRequest struct {
	Height     uint64 // Transaction height/sequence number
	IsRollback bool   // Whether this is a rollback request
}
19 |
// ResponseType classifies a cohort's reply as ACK or NACK.
//
// The numeric values match the protobuf Type enum (ACK = 0, NACK = 1). The
// gRPC layer converts between the two with a plain type conversion, so these
// values must not change.
type ResponseType int32

const (
	// ResponseTypeAck indicates successful acknowledgment.
	ResponseTypeAck ResponseType = 0
	// ResponseTypeNack indicates negative acknowledgment (rejection).
	ResponseTypeNack ResponseType = 1
)
29 |
// CohortResponse represents a response from a cohort node.
//
// ResponseType is embedded, so the ACK/NACK value is read as
// resp.ResponseType.
type CohortResponse struct {
	ResponseType
	Height uint64 // Current height of the cohort
}

// BroadcastRequest represents a request to be broadcast to all cohorts.
type BroadcastRequest struct {
	Key   string // Key to be stored
	Value []byte // Value to be stored
}

// BroadcastResponse represents a response to a broadcast request.
type BroadcastResponse struct {
	Type   ResponseType // Response type (ACK/NACK)
	Height uint64       // Height of the committed transaction
}

// AbortRequest represents a request to abort a transaction.
//
// The internal gRPC client copies both fields unchanged into the protobuf
// AbortRequest.
type AbortRequest struct {
	Height uint64 // Transaction height to abort
	Reason string // Reason for the abort
}
53 |
--------------------------------------------------------------------------------
/core/cohort/commitalgo/fsm.go:
--------------------------------------------------------------------------------
1 | package commitalgo
2 |
3 | import (
4 | "errors"
5 | "sync"
6 | )
7 |
// mode names the commit protocol a stateMachine enforces.
type mode string

// Protocol names accepted by newStateMachine. They are untyped string
// constants, so they compare directly against values of type mode.
const (
	twophase   = "two-phase"
	threephase = "three-phase"
)

// Stage names used as the states of a stateMachine.
const (
	proposeStage   = "propose"
	precommitStage = "precommit"
	commitStage    = "commit"
)

// stateMachine tracks the current protocol stage and rejects transitions that
// are missing from its transition table. Every method takes mu, so the type
// is safe for concurrent use.
type stateMachine struct {
	mu           sync.RWMutex
	currentState string
	mode         mode
	transitions  map[string]map[string]struct{}
}

// twoPhaseTransitions lists the permitted next stages for each stage under
// two-phase commit: propose -> {propose, commit}, commit -> {propose}.
var twoPhaseTransitions = map[string]map[string]struct{}{
	proposeStage: {proposeStage: {}, commitStage: {}},
	commitStage:  {proposeStage: {}},
}

// threePhaseTransitions lists the permitted next stages for each stage under
// three-phase commit:
//
//	propose -> {propose, precommit}
//	precommit -> {commit}
//	commit -> {propose}
var threePhaseTransitions = map[string]map[string]struct{}{
	proposeStage:   {proposeStage: {}, precommitStage: {}},
	precommitStage: {commitStage: {}},
	commitStage:    {proposeStage: {}},
}

// newStateMachine returns a machine positioned at proposeStage.
//
// It uses the three-phase table when mode is threephase. Every other mode,
// including unknown ones, gets the two-phase table.
func newStateMachine(mode mode) *stateMachine {
	sm := &stateMachine{
		currentState: proposeStage,
		mode:         mode,
		transitions:  twoPhaseTransitions,
	}
	if mode == threephase {
		sm.transitions = threePhaseTransitions
	}
	return sm
}

// Transition moves the machine to nextState if the table allows it from the
// current state. Otherwise it returns an error and leaves the state unchanged.
func (sm *stateMachine) Transition(nextState string) error {
	sm.mu.Lock()
	defer sm.mu.Unlock()

	// Indexing a missing key yields a nil map, and a lookup in a nil map
	// reports "not found", so unknown current states are rejected too.
	allowed := sm.transitions[sm.currentState]
	if _, permitted := allowed[nextState]; !permitted {
		return errors.New("invalid state transition")
	}

	sm.currentState = nextState
	return nil
}

// getCurrentState returns the stage the machine is currently in.
func (sm *stateMachine) getCurrentState() string {
	sm.mu.RLock()
	state := sm.currentState
	sm.mu.RUnlock()
	return state
}

// GetMode returns the protocol mode the machine was created with.
func (sm *stateMachine) GetMode() mode {
	sm.mu.RLock()
	m := sm.mode
	sm.mu.RUnlock()
	return m
}
88 |
--------------------------------------------------------------------------------
/io/gateway/grpc/client/internal_client.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "context"
5 |
6 | log "github.com/sirupsen/logrus"
7 | "github.com/vadiminshakov/committer/core/dto"
8 | "github.com/vadiminshakov/committer/io/gateway/grpc/proto"
9 | )
10 |
11 | // InternalCommitClient provides access to the internal API of the atomic commit protocol
12 | type InternalCommitClient struct {
13 | Connection proto.InternalCommitAPIClient
14 | }
15 |
16 | // NewInternalClient creates an instance of the internal API client.
17 | // 'addr' - the network address of the node (host + port).
18 | func NewInternalClient(addr string) (*InternalCommitClient, error) {
19 | conn, err := createConnection(addr)
20 | if err != nil {
21 | return nil, err
22 | }
23 | return &InternalCommitClient{Connection: proto.NewInternalCommitAPIClient(conn)}, nil
24 | }
25 |
26 | // Propose sends a propose request to the internal commit API
27 | func (client *InternalCommitClient) Propose(ctx context.Context, req *proto.ProposeRequest) (*proto.Response, error) {
28 | return client.Connection.Propose(ctx, req)
29 | }
30 |
31 | // Precommit sends a precommit request to the internal commit API
32 | func (client *InternalCommitClient) Precommit(ctx context.Context, req *proto.PrecommitRequest) (*proto.Response, error) {
33 | return client.Connection.Precommit(ctx, req)
34 | }
35 |
36 | // Commit sends a commit request to the internal commit API
37 | func (client *InternalCommitClient) Commit(ctx context.Context, req *proto.CommitRequest) (*proto.Response, error) {
38 | return client.Connection.Commit(ctx, req)
39 | }
40 |
41 | // Abort sends an abort request to the internal commit API
42 | func (client *InternalCommitClient) Abort(ctx context.Context, req *dto.AbortRequest) (*proto.Response, error) {
43 | protoReq := &proto.AbortRequest{
44 | Height: req.Height,
45 | Reason: req.Reason,
46 | }
47 |
48 | log.Infof("Sending abort request for height %d with reason: %s", req.Height, req.Reason)
49 | return client.Connection.Abort(ctx, protoReq)
50 | }
51 |
--------------------------------------------------------------------------------
/mocks/mock_state_store.go:
--------------------------------------------------------------------------------
1 | // Code generated by MockGen. DO NOT EDIT.
2 | // Source: github.com/vadiminshakov/committer/core/coordinator (interfaces: StateStore)
3 | //
4 | // Generated by this command:
5 | //
6 | // mockgen -destination=../../mocks/mock_state_store.go -package=mocks . StateStore
7 | //
8 |
9 | // Package mocks is a generated GoMock package.
10 | package mocks
11 |
12 | import (
13 | reflect "reflect"
14 |
15 | gomock "go.uber.org/mock/gomock"
16 | )
17 |
18 | // MockStateStore is a mock of StateStore interface.
19 | type MockStateStore struct {
20 | ctrl *gomock.Controller
21 | recorder *MockStateStoreMockRecorder
22 | isgomock struct{}
23 | }
24 |
25 | // MockStateStoreMockRecorder is the mock recorder for MockStateStore.
26 | type MockStateStoreMockRecorder struct {
27 | mock *MockStateStore
28 | }
29 |
30 | // NewMockStateStore creates a new mock instance.
31 | func NewMockStateStore(ctrl *gomock.Controller) *MockStateStore {
32 | mock := &MockStateStore{ctrl: ctrl}
33 | mock.recorder = &MockStateStoreMockRecorder{mock}
34 | return mock
35 | }
36 |
37 | // EXPECT returns an object that allows the caller to indicate expected use.
38 | func (m *MockStateStore) EXPECT() *MockStateStoreMockRecorder {
39 | return m.recorder
40 | }
41 |
42 | // Close mocks base method.
43 | func (m *MockStateStore) Close() error {
44 | m.ctrl.T.Helper()
45 | ret := m.ctrl.Call(m, "Close")
46 | ret0, _ := ret[0].(error)
47 | return ret0
48 | }
49 |
50 | // Close indicates an expected call of Close.
51 | func (mr *MockStateStoreMockRecorder) Close() *gomock.Call {
52 | mr.mock.ctrl.T.Helper()
53 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStateStore)(nil).Close))
54 | }
55 |
56 | // Put mocks base method.
57 | func (m *MockStateStore) Put(key string, value []byte) error {
58 | m.ctrl.T.Helper()
59 | ret := m.ctrl.Call(m, "Put", key, value)
60 | ret0, _ := ret[0].(error)
61 | return ret0
62 | }
63 |
64 | // Put indicates an expected call of Put.
65 | func (mr *MockStateStoreMockRecorder) Put(key, value any) *gomock.Call {
66 | mr.mock.ctrl.T.Helper()
67 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockStateStore)(nil).Put), key, value)
68 | }
69 |
--------------------------------------------------------------------------------
/mocks/mock_commitalgo_state_store.go:
--------------------------------------------------------------------------------
1 | // Code generated by MockGen. DO NOT EDIT.
2 | // Source: github.com/vadiminshakov/committer/core/cohort/commitalgo (interfaces: StateStore)
3 | //
4 | // Generated by this command:
5 | //
6 | // mockgen -destination=../../../mocks/mock_commitalgo_state_store.go -package=mocks -mock_names=StateStore=MockCommitalgoStateStore . StateStore
7 | //
8 |
9 | // Package mocks is a generated GoMock package.
10 | package mocks
11 |
12 | import (
13 | reflect "reflect"
14 |
15 | gomock "go.uber.org/mock/gomock"
16 | )
17 |
18 | // MockCommitalgoStateStore is a mock of StateStore interface.
19 | type MockCommitalgoStateStore struct {
20 | ctrl *gomock.Controller
21 | recorder *MockCommitalgoStateStoreMockRecorder
22 | isgomock struct{}
23 | }
24 |
25 | // MockCommitalgoStateStoreMockRecorder is the mock recorder for MockCommitalgoStateStore.
26 | type MockCommitalgoStateStoreMockRecorder struct {
27 | mock *MockCommitalgoStateStore
28 | }
29 |
30 | // NewMockCommitalgoStateStore creates a new mock instance.
31 | func NewMockCommitalgoStateStore(ctrl *gomock.Controller) *MockCommitalgoStateStore {
32 | mock := &MockCommitalgoStateStore{ctrl: ctrl}
33 | mock.recorder = &MockCommitalgoStateStoreMockRecorder{mock}
34 | return mock
35 | }
36 |
37 | // EXPECT returns an object that allows the caller to indicate expected use.
38 | func (m *MockCommitalgoStateStore) EXPECT() *MockCommitalgoStateStoreMockRecorder {
39 | return m.recorder
40 | }
41 |
42 | // Close mocks base method.
43 | func (m *MockCommitalgoStateStore) Close() error {
44 | m.ctrl.T.Helper()
45 | ret := m.ctrl.Call(m, "Close")
46 | ret0, _ := ret[0].(error)
47 | return ret0
48 | }
49 |
50 | // Close indicates an expected call of Close.
51 | func (mr *MockCommitalgoStateStoreMockRecorder) Close() *gomock.Call {
52 | mr.mock.ctrl.T.Helper()
53 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCommitalgoStateStore)(nil).Close))
54 | }
55 |
56 | // Put mocks base method.
57 | func (m *MockCommitalgoStateStore) Put(key string, value []byte) error {
58 | m.ctrl.T.Helper()
59 | ret := m.ctrl.Call(m, "Put", key, value)
60 | ret0, _ := ret[0].(error)
61 | return ret0
62 | }
63 |
64 | // Put indicates an expected call of Put.
65 | func (mr *MockCommitalgoStateStoreMockRecorder) Put(key, value any) *gomock.Call {
66 | mr.mock.ctrl.T.Helper()
67 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockCommitalgoStateStore)(nil).Put), key, value)
68 | }
69 |
--------------------------------------------------------------------------------
/core/walrecord/walrecord.go:
--------------------------------------------------------------------------------
1 | package walrecord
2 |
3 | import (
4 | "encoding/binary"
5 | "fmt"
6 | )
7 |
const (
	// KeyPrepared is the reserved system key for the PREPARED phase.
	KeyPrepared = "__tx:prepared"
	// KeyPrecommit is the reserved system key for the PRECOMMIT phase.
	KeyPrecommit = "__tx:precommit"
	// KeyCommit is the reserved system key for the COMMIT phase.
	KeyCommit = "__tx:commit"
	// KeyAbort is the reserved system key for the ABORT phase.
	KeyAbort = "__tx:abort"

	// Stride is the number of consecutive WAL indexes reserved per height.
	// Height h owns indexes h*Stride through h*Stride+3, one per phase.
	Stride = 4
)

// WalTx is the key/value payload stored in a WAL record.
type WalTx struct {
	Key   string
	Value []byte
}

// PreparedSlot returns the WAL index of the PREPARED record for height
// (offset 0 within the height's stride).
func PreparedSlot(height uint64) uint64 {
	return Stride * height
}

// PrecommitSlot returns the WAL index of the PRECOMMIT record for height
// (offset 1).
func PrecommitSlot(height uint64) uint64 {
	return Stride*height + 1
}

// CommitSlot returns the WAL index of the COMMIT record for height (offset 2).
func CommitSlot(height uint64) uint64 {
	return Stride*height + 2
}

// AbortSlot returns the WAL index of the ABORT record for height (offset 3).
func AbortSlot(height uint64) uint64 {
	return Stride*height + 3
}
44 |
45 | // Encode serializes a WalTx into bytes.
46 | // Format: [KeyLen(4 bytes)] [KeyBytes] [ValueLen(4 bytes)] [ValueBytes]
47 | func Encode(tx WalTx) ([]byte, error) {
48 | keyLen := uint32(len(tx.Key))
49 | valLen := uint32(len(tx.Value))
50 |
51 | // 4 bytes for key len + key bytes + 4 bytes for val len + val bytes
52 | buf := make([]byte, 4+keyLen+4+valLen)
53 |
54 | binary.BigEndian.PutUint32(buf[0:4], keyLen)
55 | copy(buf[4:4+keyLen], tx.Key)
56 |
57 | binary.BigEndian.PutUint32(buf[4+keyLen:4+keyLen+4], valLen)
58 | copy(buf[4+keyLen+4:], tx.Value)
59 |
60 | return buf, nil
61 | }
62 |
63 | // Decode deserializes bytes into a WalTx.
64 | func Decode(data []byte) (WalTx, error) {
65 | if len(data) < 4 {
66 | return WalTx{}, fmt.Errorf("data too short for key length")
67 | }
68 |
69 | keyLen := binary.BigEndian.Uint32(data[0:4])
70 | if uint32(len(data)) < 4+keyLen+4 {
71 | return WalTx{}, fmt.Errorf("data too short for key and value length")
72 | }
73 |
74 | key := string(data[4 : 4+keyLen])
75 |
76 | valLen := binary.BigEndian.Uint32(data[4+keyLen : 4+keyLen+4])
77 | if uint32(len(data)) < 4+keyLen+4+valLen {
78 | return WalTx{}, fmt.Errorf("data too short for value body")
79 | }
80 |
81 | value := make([]byte, valLen)
82 | copy(value, data[4+keyLen+4:4+keyLen+4+valLen])
83 |
84 | return WalTx{
85 | Key: key,
86 | Value: value,
87 | }, nil
88 | }
89 |
--------------------------------------------------------------------------------
/core/cohort/cohort.go:
--------------------------------------------------------------------------------
1 | // Package cohort implements the cohort role in distributed atomic commit protocols.
2 | // A cohort is a participant in the atomic commit protocol that receives
3 | // proposals from the coordinator and responds with either commit or abort.
4 | //
5 | // Cohorts participate in 2PC and 3PC transactions by responding to coordinator
6 | // requests and maintaining local transaction state.
7 | package cohort
8 |
9 | import (
10 | "context"
11 | "errors"
12 |
13 | "github.com/vadiminshakov/committer/core/cohort/commitalgo/hooks"
14 | "github.com/vadiminshakov/committer/core/dto"
15 | )
16 |
// Mode names the atomic commit protocol a cohort takes part in.
type Mode string

const (
	// THREE_PHASE selects the three-phase commit (3PC) protocol.
	THREE_PHASE Mode = "three-phase"
)
22 |
23 | // Committer defines the interface for commit algorithms.
24 | //
25 | //go:generate mockgen -destination=../../../mocks/mock_committer.go -package=mocks . Committer
26 | type Committer interface {
27 | Height() uint64
28 | Propose(ctx context.Context, req *dto.ProposeRequest) (*dto.CohortResponse, error)
29 | Precommit(ctx context.Context, index uint64) (*dto.CohortResponse, error)
30 | Commit(ctx context.Context, req *dto.CommitRequest) (*dto.CohortResponse, error)
31 | Abort(ctx context.Context, req *dto.AbortRequest) (*dto.CohortResponse, error)
32 | RegisterHook(hook hooks.Hook)
33 | }
34 |
35 | // CohortImpl implements the cohort node functionality.
36 | type CohortImpl struct {
37 | committer Committer
38 | commitType Mode
39 | }
40 |
41 | // NewCohort creates a new cohort instance.
42 | func NewCohort(
43 | committer Committer,
44 | commitType Mode) *CohortImpl {
45 | return &CohortImpl{
46 | committer: committer,
47 | commitType: commitType,
48 | }
49 | }
50 |
51 | func (c *CohortImpl) Height() uint64 {
52 | return c.committer.Height()
53 | }
54 |
55 | func (c *CohortImpl) Propose(ctx context.Context, req *dto.ProposeRequest) (*dto.CohortResponse, error) {
56 | return c.committer.Propose(ctx, req)
57 | }
58 |
59 | func (s *CohortImpl) Precommit(ctx context.Context, index uint64) (*dto.CohortResponse, error) {
60 | if s.commitType != THREE_PHASE {
61 | return nil, errors.New("precommit is allowed for 3PC mode only")
62 | }
63 |
64 | return s.committer.Precommit(ctx, index)
65 | }
66 |
67 | func (c *CohortImpl) Commit(ctx context.Context, in *dto.CommitRequest) (resp *dto.CohortResponse, err error) {
68 | return c.committer.Commit(ctx, in)
69 | }
70 |
71 | func (c *CohortImpl) Abort(ctx context.Context, req *dto.AbortRequest) (*dto.CohortResponse, error) {
72 | return c.committer.Abort(ctx, req)
73 | }
74 |
--------------------------------------------------------------------------------
/mocks/mock_wal.go:
--------------------------------------------------------------------------------
1 | // Code generated by MockGen. DO NOT EDIT.
2 | // Source: github.com/vadiminshakov/committer/core/coordinator (interfaces: wal)
3 | //
4 | // Generated by this command:
5 | //
6 | // mockgen -destination=../../mocks/mock_wal.go -package=mocks . wal
7 | //
8 |
9 | // Package mocks is a generated GoMock package.
10 | package mocks
11 |
12 | import (
13 | reflect "reflect"
14 |
15 | gomock "go.uber.org/mock/gomock"
16 | )
17 |
// Mockwal is a mock of wal interface.
// It is produced by mockgen (see the file header); regenerate it instead of
// hand-editing so it stays in sync with the wal interface.
type Mockwal struct {
	ctrl     *gomock.Controller
	recorder *MockwalMockRecorder
	isgomock struct{}
}

// MockwalMockRecorder is the mock recorder for Mockwal.
// Its methods record expectations with the controller; they do not perform calls.
type MockwalMockRecorder struct {
	mock *Mockwal
}

// NewMockwal creates a new mock instance.
// Every mocked call is dispatched to ctrl via ctrl.Call.
func NewMockwal(ctrl *gomock.Controller) *Mockwal {
	mock := &Mockwal{ctrl: ctrl}
	mock.recorder = &MockwalMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *Mockwal) EXPECT() *MockwalMockRecorder {
	return m.recorder
}
41 |
// Close mocks base method.
// The comma-ok type assertion turns a nil or mistyped return value into nil.
func (m *Mockwal) Close() error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Close")
	ret0, _ := ret[0].(error)
	return ret0
}

// Close indicates an expected call of Close.
func (mr *MockwalMockRecorder) Close() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*Mockwal)(nil).Close))
}
55 |
// Get mocks base method.
// Each result uses a comma-ok type assertion, so a nil or mistyped return
// value becomes the zero value of that result type.
func (m *Mockwal) Get(index uint64) (string, []byte, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Get", index)
	ret0, _ := ret[0].(string)
	ret1, _ := ret[1].([]byte)
	ret2, _ := ret[2].(error)
	return ret0, ret1, ret2
}

// Get indicates an expected call of Get.
// index is typed any so either a concrete value or a gomock matcher can be passed.
func (mr *MockwalMockRecorder) Get(index any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*Mockwal)(nil).Get), index)
}
71 |
// Write mocks base method.
// The comma-ok type assertion turns a nil or mistyped return value into nil.
func (m *Mockwal) Write(index uint64, key string, value []byte) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Write", index, key, value)
	ret0, _ := ret[0].(error)
	return ret0
}

// Write indicates an expected call of Write.
// Arguments are typed any so either concrete values or gomock matchers can be passed.
func (mr *MockwalMockRecorder) Write(index, key, value any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*Mockwal)(nil).Write), index, key, value)
}
85 |
// WriteTombstone mocks base method.
// The comma-ok type assertion turns a nil or mistyped return value into nil.
func (m *Mockwal) WriteTombstone(index uint64) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "WriteTombstone", index)
	ret0, _ := ret[0].(error)
	return ret0
}

// WriteTombstone indicates an expected call of WriteTombstone.
// index is typed any so either a concrete value or a gomock matcher can be passed.
func (mr *MockwalMockRecorder) WriteTombstone(index any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTombstone", reflect.TypeOf((*Mockwal)(nil).WriteTombstone), index)
}
99 |
--------------------------------------------------------------------------------
/core/cohort/commitalgo/hooks/registry_test.go:
--------------------------------------------------------------------------------
1 | package hooks
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/require"
7 | "github.com/vadiminshakov/committer/core/dto"
8 | )
9 |
10 | // TestHook is a simple test hook implementation
11 | type TestHook struct {
12 | proposeResult bool
13 | commitResult bool
14 | proposeCalled bool
15 | commitCalled bool
16 | }
17 |
18 | func (t *TestHook) OnPropose(req *dto.ProposeRequest) bool {
19 | t.proposeCalled = true
20 | return t.proposeResult
21 | }
22 |
23 | func (t *TestHook) OnCommit(req *dto.CommitRequest) bool {
24 | t.commitCalled = true
25 | return t.commitResult
26 | }
27 |
28 | func TestRegistry_Register(t *testing.T) {
29 | registry := NewRegistry()
30 | hook := &TestHook{}
31 |
32 | require.Equal(t, 0, registry.Count(), "Expected 0 hooks initially")
33 |
34 | registry.Register(hook)
35 |
36 | require.Equal(t, 1, registry.Count(), "Expected 1 hook after registration")
37 | }
38 |
39 | func TestRegistry_ExecutePropose(t *testing.T) {
40 | registry := NewRegistry()
41 |
42 | hook1 := &TestHook{proposeResult: true}
43 | registry.Register(hook1)
44 |
45 | req := &dto.ProposeRequest{Height: 1, Key: "test", Value: []byte("value")}
46 | result := registry.ExecutePropose(req)
47 |
48 | require.True(t, result, "Expected propose to succeed")
49 | require.True(t, hook1.proposeCalled, "Expected hook to be called")
50 |
51 | hook2 := &TestHook{proposeResult: false}
52 | registry.Register(hook2)
53 |
54 | result = registry.ExecutePropose(req)
55 |
56 | require.False(t, result, "Expected propose to fail when hook returns false")
57 | }
58 |
59 | func TestRegistry_ExecuteCommit(t *testing.T) {
60 | registry := NewRegistry()
61 |
62 | hook1 := &TestHook{commitResult: true}
63 | registry.Register(hook1)
64 |
65 | req := &dto.CommitRequest{Height: 1}
66 | result := registry.ExecuteCommit(req)
67 |
68 | require.True(t, result, "Expected commit to succeed")
69 | require.True(t, hook1.commitCalled, "Expected hook to be called")
70 |
71 | hook2 := &TestHook{commitResult: false}
72 | registry.Register(hook2)
73 |
74 | result = registry.ExecuteCommit(req)
75 |
76 | require.False(t, result, "Expected commit to fail when hook returns false")
77 | }
78 |
79 | func TestRegistry_MultipleHooks(t *testing.T) {
80 | registry := NewRegistry()
81 |
82 | hook1 := &TestHook{proposeResult: true, commitResult: true}
83 | hook2 := &TestHook{proposeResult: true, commitResult: true}
84 | hook3 := &TestHook{proposeResult: false, commitResult: true} // this one fails
85 |
86 | registry.Register(hook1)
87 | registry.Register(hook2)
88 | registry.Register(hook3)
89 |
90 | proposeReq := &dto.ProposeRequest{Height: 1, Key: "test", Value: []byte("value")}
91 | commitReq := &dto.CommitRequest{Height: 1}
92 |
93 | // propose should fail because hook3 returns false
94 | require.False(t, registry.ExecutePropose(proposeReq), "Expected propose to fail")
95 |
96 | // commit should succeed because all hooks return true for commit
97 | require.True(t, registry.ExecuteCommit(commitReq), "Expected commit to succeed")
98 |
99 | // check that all hooks were called
100 | require.True(t, hook1.proposeCalled, "Expected hook1 to be called for propose")
101 | require.True(t, hook2.proposeCalled, "Expected hook2 to be called for propose")
102 | require.True(t, hook3.proposeCalled, "Expected hook3 to be called for propose")
103 |
104 | require.True(t, hook1.commitCalled, "Expected hook1 to be called for commit")
105 | require.True(t, hook2.commitCalled, "Expected hook2 to be called for commit")
106 | require.True(t, hook3.commitCalled, "Expected hook3 to be called for commit")
107 | }
108 |
--------------------------------------------------------------------------------
/config/config.go:
--------------------------------------------------------------------------------
1 | // Package config provides configuration management for the committer application.
2 | //
3 | // This package handles command-line flag parsing and configuration validation
4 | // for both coordinator and cohort nodes in the distributed consensus system.
5 | package config
6 |
7 | import (
8 | "flag"
9 | "log"
10 | "strings"
11 | )
12 |
// Write-ahead log location and naming defaults.
const (
	// DefaultWalDir is the directory WAL segments are written to by default.
	DefaultWalDir string = "wal"
	// DefaultWalSegmentPrefix is the filename prefix of WAL segment files.
	DefaultWalSegmentPrefix string = "msgs_"
)

// Write-ahead log sizing and durability defaults.
const (
	// DefaultWalSegmentThreshold is how many entries a WAL segment holds by default.
	DefaultWalSegmentThreshold int = 10000
	// DefaultWalMaxSegments caps how many WAL segments are retained by default.
	DefaultWalMaxSegments int = 100
	// DefaultWalIsInSyncDiskMode turns on synchronous disk writes for the WAL.
	DefaultWalIsInSyncDiskMode bool = true
)
25 |
// Config is the resolved runtime configuration of a committer node, for
// either the coordinator or the cohort role.
type Config struct {
	Role        string   // "coordinator" or "cohort"
	Nodeaddr    string   // address this node listens on
	Coordinator string   // coordinator address (used by cohorts)
	CommitType  string   // "two-phase" or "three-phase"
	DBPath      string   // database directory on disk
	Cohorts     []string // cohort addresses (used by coordinators)
	Whitelist   []string // node addresses allowed to connect
	Timeout     uint64   // 3PC acknowledgement timeout, in milliseconds
}
37 |
38 | // Get creates configuration from yaml configuration file (if '-config=' flag specified) or command-line arguments.
39 | func Get() *Config {
40 | // command-line flags
41 | role := flag.String("role", "cohort", "role (coordinator or cohort)")
42 | nodeaddr := flag.String("nodeaddr", "localhost:3050", "node address")
43 | coordinator := flag.String("coordinator", "", "coordinator address")
44 | committype := flag.String("committype", "two-phase", "two-phase or three-phase commit mode")
45 | timeout := flag.Uint64("timeout", 1000, "ms, timeout after which the message is considered unacknowledged (only for three-phase mode, because two-phase is blocking by design)")
46 | dbpath := flag.String("dbpath", "./badger", "database path on filesystem")
47 | cohorts := flag.String("cohorts", "", "cohort addresses")
48 | whitelist := flag.String("whitelist", "127.0.0.1", "allowed hosts")
49 | flag.Parse()
50 |
51 | cohortsArray := filterEmpty(strings.Split(*cohorts, ","))
52 | if *role != "coordinator" {
53 | if !includes(cohortsArray, *nodeaddr) {
54 | cohortsArray = append(cohortsArray, *nodeaddr)
55 | }
56 | }
57 | whitelistArray := filterEmpty(strings.Split(*whitelist, ","))
58 |
59 | if *role == "coordinator" && len(cohortsArray) == 0 {
60 | log.Fatalf("coordinator role requires at least one cohort address")
61 | }
62 |
63 | return &Config{
64 | Role: *role,
65 | Nodeaddr: *nodeaddr,
66 | Coordinator: *coordinator,
67 | CommitType: *committype,
68 | DBPath: *dbpath,
69 | Cohorts: cohortsArray,
70 | Whitelist: whitelistArray,
71 | Timeout: *timeout,
72 | }
73 |
74 | }
75 |
// includes reports whether value is an element of arr.
func includes(arr []string, value string) bool {
	for _, item := range arr {
		if item == value {
			return true
		}
	}
	return false
}
85 |
// filterEmpty returns a new slice holding the whitespace-trimmed elements of
// values, skipping any that are empty after trimming. The result is never nil.
func filterEmpty(values []string) []string {
	kept := make([]string, 0, len(values))
	for i := range values {
		trimmed := strings.TrimSpace(values[i])
		if trimmed == "" {
			continue
		}
		kept = append(kept, trimmed)
	}
	return kept
}
96 |
--------------------------------------------------------------------------------
/core/cohort/commitalgo/hooks/examples.go:
--------------------------------------------------------------------------------
1 | package hooks
2 |
3 | import (
4 | "fmt"
5 | "time"
6 |
7 | log "github.com/sirupsen/logrus"
8 | "github.com/vadiminshakov/committer/core/dto"
9 | )
10 |
11 | // MetricsHook collects metrics about propose and commit operations.
12 | type MetricsHook struct {
13 | proposeCount uint64
14 | commitCount uint64
15 | startTime time.Time
16 | }
17 |
18 | // NewMetricsHook creates a new metrics hook.
19 | func NewMetricsHook() *MetricsHook {
20 | return &MetricsHook{
21 | startTime: time.Now(),
22 | }
23 | }
24 |
25 | // OnPropose increments propose counter and logs metrics.
26 | func (m *MetricsHook) OnPropose(req *dto.ProposeRequest) bool {
27 | m.proposeCount++
28 | log.WithFields(log.Fields{
29 | "height": req.Height,
30 | "propose_count": m.proposeCount,
31 | "uptime": time.Since(m.startTime),
32 | }).Info("Metrics: propose operation")
33 | return true
34 | }
35 |
36 | // OnCommit increments commit counter and logs metrics
37 | func (m *MetricsHook) OnCommit(req *dto.CommitRequest) bool {
38 | m.commitCount++
39 | log.WithFields(log.Fields{
40 | "height": req.Height,
41 | "commit_count": m.commitCount,
42 | "uptime": time.Since(m.startTime),
43 | }).Info("Metrics: commit operation")
44 | return true
45 | }
46 |
47 | // GetStats returns current statistics
48 | func (m *MetricsHook) GetStats() (uint64, uint64, time.Duration) {
49 | return m.proposeCount, m.commitCount, time.Since(m.startTime)
50 | }
51 |
52 | // ValidationHook validates requests before processing
53 | type ValidationHook struct {
54 | maxKeyLength int
55 | maxValueLength int
56 | }
57 |
58 | // NewValidationHook creates a new validation hook
59 | func NewValidationHook(maxKeyLength, maxValueLength int) *ValidationHook {
60 | return &ValidationHook{
61 | maxKeyLength: maxKeyLength,
62 | maxValueLength: maxValueLength,
63 | }
64 | }
65 |
66 | // OnPropose validates the propose request
67 | func (v *ValidationHook) OnPropose(req *dto.ProposeRequest) bool {
68 | if len(req.Key) > v.maxKeyLength {
69 | log.Errorf("Key too long: %d > %d", len(req.Key), v.maxKeyLength)
70 | return false
71 | }
72 |
73 | if len(req.Value) > v.maxValueLength {
74 | log.Errorf("Value too long: %d > %d", len(req.Value), v.maxValueLength)
75 | return false
76 | }
77 |
78 | log.Debugf("Validation passed for key: %s", req.Key)
79 | return true
80 | }
81 |
82 | // OnCommit validates the commit request
83 | func (v *ValidationHook) OnCommit(req *dto.CommitRequest) bool {
84 | if req.Height == 0 {
85 | log.Error("Invalid height: cannot be zero")
86 | return false
87 | }
88 |
89 | log.Debugf("Commit validation passed for height: %d", req.Height)
90 | return true
91 | }
92 |
93 | // AuditHook logs all operations for audit purposes
94 | type AuditHook struct {
95 | logFile string
96 | }
97 |
98 | // NewAuditHook creates a new audit hook
99 | func NewAuditHook(logFile string) *AuditHook {
100 | return &AuditHook{
101 | logFile: logFile,
102 | }
103 | }
104 |
105 | // OnPropose logs propose operations
106 | func (a *AuditHook) OnPropose(req *dto.ProposeRequest) bool {
107 | auditMsg := fmt.Sprintf("[AUDIT] PROPOSE - Height: %d, Key: %s, Value: %s, Time: %s",
108 | req.Height, req.Key, string(req.Value), time.Now().Format(time.RFC3339))
109 |
110 | log.WithField("audit", true).Info(auditMsg)
111 | // Here you could also write to a file if needed
112 | return true
113 | }
114 |
115 | // OnCommit logs commit operations
116 | func (a *AuditHook) OnCommit(req *dto.CommitRequest) bool {
117 | auditMsg := fmt.Sprintf("[AUDIT] COMMIT - Height: %d, Time: %s",
118 | req.Height, time.Now().Format(time.RFC3339))
119 |
120 | log.WithField("audit", true).Info(auditMsg)
121 | // Here you could also write to a file if needed
122 | return true
123 | }
124 |
--------------------------------------------------------------------------------
/io/store/store_test.go:
--------------------------------------------------------------------------------
1 | package store
2 |
3 | import (
4 | "os"
5 | "path/filepath"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | "github.com/stretchr/testify/require"
10 | "github.com/vadiminshakov/committer/core/walrecord"
11 | "github.com/vadiminshakov/gowal"
12 | )
13 |
14 | func TestStore_Recovery_Commit(t *testing.T) {
15 | walDir := filepath.Join(os.TempDir(), "wal_commit")
16 | dbDir := filepath.Join(os.TempDir(), "db_commit")
17 | defer os.RemoveAll(walDir)
18 | defer os.RemoveAll(dbDir)
19 |
20 | // 1. setup WAL
21 | w, err := gowal.NewWAL(gowal.Config{
22 | Dir: walDir,
23 | Prefix: "wal_",
24 | SegmentThreshold: 1024 * 1024,
25 | MaxSegments: 10,
26 | })
27 | require.NoError(t, err)
28 |
29 | // 2. write prepared (should require commit to apply)
30 | height := uint64(10)
31 | prepIdx := walrecord.PreparedSlot(height)
32 | tx := walrecord.WalTx{Key: "key1", Value: []byte("value1")}
33 | encoded, _ := walrecord.Encode(tx)
34 |
35 | err = w.Write(prepIdx, walrecord.KeyPrepared, encoded)
36 | require.NoError(t, err)
37 |
38 | // 3. write commit
39 | commitIdx := walrecord.CommitSlot(height)
40 | err = w.Write(commitIdx, walrecord.KeyCommit, encoded)
41 | require.NoError(t, err)
42 |
43 | w.Close()
44 |
45 | w2, err := gowal.NewWAL(gowal.Config{Dir: walDir, Prefix: "wal_", SegmentThreshold: 1024 * 1024, MaxSegments: 10})
46 | require.NoError(t, err)
47 | defer w2.Close()
48 |
49 | // 4. recover
50 | s, state, err := New(w2, dbDir)
51 | require.NoError(t, err)
52 | defer s.Close()
53 |
54 | // 5. verify
55 | assert.Equal(t, height+1, state.Height)
56 |
57 | val, err := s.Get("key1")
58 | assert.NoError(t, err)
59 | assert.Equal(t, []byte("value1"), val)
60 | }
61 |
62 | func TestStore_Recovery_PreparedOnly(t *testing.T) {
63 | walDir := filepath.Join(os.TempDir(), "wal_prepared")
64 | dbDir := filepath.Join(os.TempDir(), "db_prepared")
65 | defer os.RemoveAll(walDir)
66 | defer os.RemoveAll(dbDir)
67 |
68 | w, err := gowal.NewWAL(gowal.Config{
69 | Dir: walDir,
70 | Prefix: "wal_",
71 | SegmentThreshold: 1024 * 1024,
72 | MaxSegments: 10,
73 | })
74 | require.NoError(t, err)
75 |
76 | height := uint64(15)
77 | prepIdx := walrecord.PreparedSlot(height)
78 | tx := walrecord.WalTx{Key: "key2", Value: []byte("value2")}
79 | encoded, _ := walrecord.Encode(tx)
80 |
81 | err = w.Write(prepIdx, walrecord.KeyPrepared, encoded)
82 | require.NoError(t, err)
83 | w.Close()
84 |
85 | w2, err := gowal.NewWAL(gowal.Config{Dir: walDir, Prefix: "wal_", SegmentThreshold: 1024 * 1024, MaxSegments: 10})
86 | require.NoError(t, err)
87 | defer w2.Close()
88 |
89 | s, state, err := New(w2, dbDir)
90 | require.NoError(t, err)
91 | defer s.Close()
92 |
93 | // should NOT be applied to DB
94 | _, err = s.Get("key2")
95 | assert.Equal(t, ErrNotFound, err)
96 |
97 | // height should resume at 15 (incomplete)
98 | assert.Equal(t, height, state.Height)
99 | }
100 |
101 | func TestStore_Recovery_Abort(t *testing.T) {
102 | walDir := filepath.Join(os.TempDir(), "wal_abort")
103 | dbDir := filepath.Join(os.TempDir(), "db_abort")
104 | defer os.RemoveAll(walDir)
105 | defer os.RemoveAll(dbDir)
106 |
107 | w, err := gowal.NewWAL(gowal.Config{
108 | Dir: walDir,
109 | Prefix: "wal_",
110 | SegmentThreshold: 1024 * 1024,
111 | MaxSegments: 10,
112 | })
113 | require.NoError(t, err)
114 |
115 | height := uint64(20)
116 |
117 | // write abort
118 | abortIdx := walrecord.AbortSlot(height)
119 | err = w.Write(abortIdx, walrecord.KeyAbort, nil)
120 | require.NoError(t, err)
121 | w.Close()
122 |
123 | w2, err := gowal.NewWAL(gowal.Config{Dir: walDir, Prefix: "wal_", SegmentThreshold: 1024 * 1024, MaxSegments: 10})
124 | require.NoError(t, err)
125 | defer w2.Close()
126 |
127 | s, state, err := New(w2, dbDir)
128 | require.NoError(t, err)
129 | defer s.Close()
130 |
131 | // height should be 20+1 because it was resolved (Aborted)
132 | assert.Equal(t, height+1, state.Height)
133 | }
134 |
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | // Package main provides a distributed consensus system implementing Two-Phase Commit (2PC)
2 | // and Three-Phase Commit (3PC) protocols for distributed transactions.
3 | //
4 | // Committer is a Go implementation of distributed atomic commit protocols that allows
5 | // you to achieve data consistency in distributed systems using Two-Phase Commit (2PC)
6 | // and Three-Phase Commit (3PC) protocols for distributed transactions.
7 | // The system consists of coordinators that manage transactions and cohorts that
8 | // participate in the consensus process.
9 | //
10 | // Usage:
11 | //
12 | // # Start coordinator
// ./committer -role=coordinator -nodeaddr=localhost:3000 -cohorts=localhost:3001,localhost:3002
14 | //
15 | // # Start cohort
16 | // ./committer -role=cohort -coordinator=localhost:3000 -nodeaddr=localhost:3001
17 | package main
18 |
19 | import (
20 | "fmt"
21 | "log"
22 | "os"
23 | "os/signal"
24 | "syscall"
25 |
26 | "github.com/vadiminshakov/committer/config"
27 | "github.com/vadiminshakov/committer/core/cohort"
28 | "github.com/vadiminshakov/committer/core/cohort/commitalgo"
29 | "github.com/vadiminshakov/committer/core/coordinator"
30 | "github.com/vadiminshakov/committer/io/gateway/grpc/server"
31 | "github.com/vadiminshakov/committer/io/store"
32 | "github.com/vadiminshakov/gowal"
33 | )
34 |
35 | func main() {
36 | if err := run(); err != nil {
37 | log.Fatalf("committer failed: %v", err)
38 | }
39 | }
40 |
41 | func run() error {
42 | ctx := make(chan os.Signal, 1)
43 | signal.Notify(ctx, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
44 | conf := config.Get()
45 |
46 | wal, err := newWAL()
47 | if err != nil {
48 | return err
49 | }
50 | defer wal.Close()
51 |
52 | stateStore, recovery, err := newStore(conf, wal)
53 | if err != nil {
54 | return err
55 | }
56 |
57 | roles, err := buildRoles(conf, stateStore, wal, recovery.Height)
58 | if err != nil {
59 | return err
60 | }
61 |
62 | srv, err := server.New(conf, roles.cohort, roles.coordinator, stateStore)
63 | if err != nil {
64 | return fmt.Errorf("failed to create server: %w", err)
65 | }
66 |
67 | srv.Run(server.WhiteListChecker)
68 | <-ctx
69 | srv.Stop()
70 |
71 | return nil
72 | }
73 |
74 | func newWAL() (*gowal.Wal, error) {
75 | walConfig := gowal.Config{
76 | Dir: config.DefaultWalDir,
77 | Prefix: config.DefaultWalSegmentPrefix,
78 | SegmentThreshold: config.DefaultWalSegmentThreshold,
79 | MaxSegments: config.DefaultWalMaxSegments,
80 | IsInSyncDiskMode: config.DefaultWalIsInSyncDiskMode,
81 | }
82 |
83 | w, err := gowal.NewWAL(walConfig)
84 | if err != nil {
85 | return nil, fmt.Errorf("failed to create WAL: %w", err)
86 | }
87 |
88 | return w, nil
89 | }
90 |
91 | func newStore(conf *config.Config, wal *gowal.Wal) (*store.Store, *store.RecoveryState, error) {
92 | stateStore, recovery, err := store.New(wal, conf.DBPath)
93 | if err != nil {
94 | return nil, nil, fmt.Errorf("failed to initialize state store: %w", err)
95 | }
96 |
97 | log.Printf("Recovered state from WAL: next height %d, keys %d\n", recovery.Height, stateStore.Size())
98 | return stateStore, recovery, nil
99 | }
100 |
101 | type roleComponents struct {
102 | cohort server.Cohort
103 | coordinator server.Coordinator
104 | }
105 |
106 | func buildRoles(conf *config.Config, stateStore *store.Store, wal *gowal.Wal, initialHeight uint64) (*roleComponents, error) {
107 | rc := &roleComponents{}
108 | switch conf.Role {
109 | case "cohort":
110 | committer := commitalgo.NewCommitter(stateStore, conf.CommitType, wal, conf.Timeout)
111 | committer.SetHeight(initialHeight)
112 | rc.cohort = cohort.NewCohort(committer, cohort.Mode(conf.CommitType))
113 | case "coordinator":
114 | coord, err := coordinator.New(conf, wal, stateStore)
115 | if err != nil {
116 | return nil, fmt.Errorf("failed to create coordinator: %w", err)
117 | }
118 | coord.SetHeight(initialHeight)
119 | rc.coordinator = coord
120 | default:
121 | return nil, fmt.Errorf("unsupported role %q, expected coordinator or cohort", conf.Role)
122 | }
123 |
124 | return rc, nil
125 | }
126 |
--------------------------------------------------------------------------------
/mocks/mock_cohort.go:
--------------------------------------------------------------------------------
1 | // Code generated by MockGen. DO NOT EDIT.
2 | // Source: github.com/vadiminshakov/committer/io/gateway/grpc/server (interfaces: Cohort)
3 | //
4 | // Generated by this command:
5 | //
6 | // mockgen -destination=../../../../mocks/mock_cohort.go -package=mocks . Cohort
7 | //
8 |
9 | // Package mocks is a generated GoMock package.
10 | package mocks
11 |
12 | import (
13 | context "context"
14 | reflect "reflect"
15 |
16 | dto "github.com/vadiminshakov/committer/core/dto"
17 | gomock "go.uber.org/mock/gomock"
18 | )
19 |
// MockCohort is a mock of Cohort interface.
// It is produced by mockgen (see the file header); regenerate it instead of
// hand-editing so it stays in sync with the Cohort interface.
type MockCohort struct {
	ctrl     *gomock.Controller
	recorder *MockCohortMockRecorder
	isgomock struct{}
}

// MockCohortMockRecorder is the mock recorder for MockCohort.
// Its methods record expectations with the controller; they do not perform calls.
type MockCohortMockRecorder struct {
	mock *MockCohort
}

// NewMockCohort creates a new mock instance.
// Every mocked call is dispatched to ctrl via ctrl.Call.
func NewMockCohort(ctrl *gomock.Controller) *MockCohort {
	mock := &MockCohort{ctrl: ctrl}
	mock.recorder = &MockCohortMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockCohort) EXPECT() *MockCohortMockRecorder {
	return m.recorder
}
43 |
// Abort mocks base method.
// Comma-ok type assertions turn a nil or mistyped return value into the zero value.
func (m *MockCohort) Abort(ctx context.Context, req *dto.AbortRequest) (*dto.CohortResponse, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Abort", ctx, req)
	ret0, _ := ret[0].(*dto.CohortResponse)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Abort indicates an expected call of Abort.
// Arguments are typed any so either concrete values or gomock matchers can be passed.
func (mr *MockCohortMockRecorder) Abort(ctx, req any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Abort", reflect.TypeOf((*MockCohort)(nil).Abort), ctx, req)
}
58 |
// Commit mocks base method.
// Comma-ok type assertions turn a nil or mistyped return value into the zero value.
func (m *MockCohort) Commit(ctx context.Context, in *dto.CommitRequest) (*dto.CohortResponse, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Commit", ctx, in)
	ret0, _ := ret[0].(*dto.CohortResponse)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Commit indicates an expected call of Commit.
// Arguments are typed any so either concrete values or gomock matchers can be passed.
func (mr *MockCohortMockRecorder) Commit(ctx, in any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockCohort)(nil).Commit), ctx, in)
}
73 |
// Height mocks base method.
// The comma-ok type assertion turns a nil or mistyped return value into 0.
func (m *MockCohort) Height() uint64 {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Height")
	ret0, _ := ret[0].(uint64)
	return ret0
}

// Height indicates an expected call of Height.
func (mr *MockCohortMockRecorder) Height() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Height", reflect.TypeOf((*MockCohort)(nil).Height))
}
87 |
// Precommit mocks base method.
// Comma-ok type assertions turn a nil or mistyped return value into the zero value.
func (m *MockCohort) Precommit(ctx context.Context, index uint64) (*dto.CohortResponse, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Precommit", ctx, index)
	ret0, _ := ret[0].(*dto.CohortResponse)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Precommit indicates an expected call of Precommit.
// Arguments are typed any so either concrete values or gomock matchers can be passed.
func (mr *MockCohortMockRecorder) Precommit(ctx, index any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Precommit", reflect.TypeOf((*MockCohort)(nil).Precommit), ctx, index)
}
102 |
// Propose mocks base method.
// Comma-ok type assertions turn a nil or mistyped return value into the zero value.
func (m *MockCohort) Propose(ctx context.Context, req *dto.ProposeRequest) (*dto.CohortResponse, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Propose", ctx, req)
	ret0, _ := ret[0].(*dto.CohortResponse)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Propose indicates an expected call of Propose.
// Arguments are typed any so either concrete values or gomock matchers can be passed.
func (mr *MockCohortMockRecorder) Propose(ctx, req any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Propose", reflect.TypeOf((*MockCohort)(nil).Propose), ctx, req)
}
117 |
--------------------------------------------------------------------------------
/mocks/mock_committer.go:
--------------------------------------------------------------------------------
1 | // Code generated by MockGen. DO NOT EDIT.
2 | // Source: github.com/vadiminshakov/committer/core/cohort (interfaces: Committer)
3 | //
4 | // Generated by this command:
5 | //
6 | // mockgen -destination=../../mocks/mock_committer.go -package=mocks . Committer
7 | //
8 |
9 | // Package mocks is a generated GoMock package.
10 | package mocks
11 |
12 | import (
13 | context "context"
14 | reflect "reflect"
15 |
16 | hooks "github.com/vadiminshakov/committer/core/cohort/commitalgo/hooks"
17 | dto "github.com/vadiminshakov/committer/core/dto"
18 | gomock "go.uber.org/mock/gomock"
19 | )
20 |
// MockCommitter is a mock of Committer interface.
// It is produced by mockgen (see the file header); regenerate it instead of
// hand-editing so it stays in sync with the Committer interface.
type MockCommitter struct {
	ctrl     *gomock.Controller
	recorder *MockCommitterMockRecorder
	isgomock struct{}
}

// MockCommitterMockRecorder is the mock recorder for MockCommitter.
// Its methods record expectations with the controller; they do not perform calls.
type MockCommitterMockRecorder struct {
	mock *MockCommitter
}

// NewMockCommitter creates a new mock instance.
// Every mocked call is dispatched to ctrl via ctrl.Call.
func NewMockCommitter(ctrl *gomock.Controller) *MockCommitter {
	mock := &MockCommitter{ctrl: ctrl}
	mock.recorder = &MockCommitterMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockCommitter) EXPECT() *MockCommitterMockRecorder {
	return m.recorder
}
44 |
45 | // Abort mocks base method.
46 | func (m *MockCommitter) Abort(ctx context.Context, req *dto.AbortRequest) (*dto.CohortResponse, error) {
47 | m.ctrl.T.Helper()
48 | ret := m.ctrl.Call(m, "Abort", ctx, req)
49 | ret0, _ := ret[0].(*dto.CohortResponse)
50 | ret1, _ := ret[1].(error)
51 | return ret0, ret1
52 | }
53 |
54 | // Abort indicates an expected call of Abort.
55 | func (mr *MockCommitterMockRecorder) Abort(ctx, req any) *gomock.Call {
56 | mr.mock.ctrl.T.Helper()
57 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Abort", reflect.TypeOf((*MockCommitter)(nil).Abort), ctx, req)
58 | }
59 |
60 | // Commit mocks base method.
61 | func (m *MockCommitter) Commit(ctx context.Context, req *dto.CommitRequest) (*dto.CohortResponse, error) {
62 | m.ctrl.T.Helper()
63 | ret := m.ctrl.Call(m, "Commit", ctx, req)
64 | ret0, _ := ret[0].(*dto.CohortResponse)
65 | ret1, _ := ret[1].(error)
66 | return ret0, ret1
67 | }
68 |
69 | // Commit indicates an expected call of Commit.
70 | func (mr *MockCommitterMockRecorder) Commit(ctx, req any) *gomock.Call {
71 | mr.mock.ctrl.T.Helper()
72 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockCommitter)(nil).Commit), ctx, req)
73 | }
74 |
75 | // Height mocks base method.
76 | func (m *MockCommitter) Height() uint64 {
77 | m.ctrl.T.Helper()
78 | ret := m.ctrl.Call(m, "Height")
79 | ret0, _ := ret[0].(uint64)
80 | return ret0
81 | }
82 |
83 | // Height indicates an expected call of Height.
84 | func (mr *MockCommitterMockRecorder) Height() *gomock.Call {
85 | mr.mock.ctrl.T.Helper()
86 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Height", reflect.TypeOf((*MockCommitter)(nil).Height))
87 | }
88 |
89 | // Precommit mocks base method.
90 | func (m *MockCommitter) Precommit(ctx context.Context, index uint64) (*dto.CohortResponse, error) {
91 | m.ctrl.T.Helper()
92 | ret := m.ctrl.Call(m, "Precommit", ctx, index)
93 | ret0, _ := ret[0].(*dto.CohortResponse)
94 | ret1, _ := ret[1].(error)
95 | return ret0, ret1
96 | }
97 |
98 | // Precommit indicates an expected call of Precommit.
99 | func (mr *MockCommitterMockRecorder) Precommit(ctx, index any) *gomock.Call {
100 | mr.mock.ctrl.T.Helper()
101 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Precommit", reflect.TypeOf((*MockCommitter)(nil).Precommit), ctx, index)
102 | }
103 |
104 | // Propose mocks base method.
105 | func (m *MockCommitter) Propose(ctx context.Context, req *dto.ProposeRequest) (*dto.CohortResponse, error) {
106 | m.ctrl.T.Helper()
107 | ret := m.ctrl.Call(m, "Propose", ctx, req)
108 | ret0, _ := ret[0].(*dto.CohortResponse)
109 | ret1, _ := ret[1].(error)
110 | return ret0, ret1
111 | }
112 |
113 | // Propose indicates an expected call of Propose.
114 | func (mr *MockCommitterMockRecorder) Propose(ctx, req any) *gomock.Call {
115 | mr.mock.ctrl.T.Helper()
116 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Propose", reflect.TypeOf((*MockCommitter)(nil).Propose), ctx, req)
117 | }
118 |
119 | // RegisterHook mocks base method.
120 | func (m *MockCommitter) RegisterHook(hook hooks.Hook) {
121 | m.ctrl.T.Helper()
122 | m.ctrl.Call(m, "RegisterHook", hook)
123 | }
124 |
125 | // RegisterHook indicates an expected call of RegisterHook.
126 | func (mr *MockCommitterMockRecorder) RegisterHook(hook any) *gomock.Call {
127 | mr.mock.ctrl.T.Helper()
128 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterHook", reflect.TypeOf((*MockCommitter)(nil).RegisterHook), hook)
129 | }
130 |
--------------------------------------------------------------------------------
/io/store/store.go:
--------------------------------------------------------------------------------
1 | package store
2 |
3 | import (
4 | stdErrors "errors"
5 | "os"
6 | "strings"
7 | "sync"
8 |
9 | "github.com/dgraph-io/badger/v4"
10 | "github.com/pkg/errors"
11 | "github.com/vadiminshakov/committer/core/walrecord"
12 | "github.com/vadiminshakov/gowal"
13 | )
14 |
15 | // ErrNotFound returned when key does not exist in the store.
16 | var ErrNotFound = errors.New("key not found")
17 |
18 | // Store persists committed key/value pairs in BadgerDB and reconstructs them from WAL on startup.
19 | type Store struct {
20 | wal *gowal.Wal
21 | db *badger.DB
22 | mu sync.RWMutex
23 | }
24 |
25 | // Snapshot returns a shallow copy of the current state.
26 | func (s *Store) Snapshot() map[string][]byte {
27 | s.mu.RLock()
28 | defer s.mu.RUnlock()
29 |
30 | snapshot := make(map[string][]byte)
31 | _ = s.db.View(func(txn *badger.Txn) error {
32 | it := txn.NewIterator(badger.DefaultIteratorOptions)
33 | defer it.Close()
34 |
35 | for it.Rewind(); it.Valid(); it.Next() {
36 | item := it.Item()
37 | key := item.KeyCopy(nil)
38 | if err := item.Value(func(val []byte) error {
39 | snapshot[string(key)] = cloneBytes(val)
40 | return nil
41 | }); err != nil {
42 | return err
43 | }
44 | }
45 | return nil
46 | })
47 |
48 | return snapshot
49 | }
50 |
51 | // Size returns current number of keys in the store.
52 | func (s *Store) Size() int {
53 | s.mu.RLock()
54 | defer s.mu.RUnlock()
55 |
56 | count := 0
57 | _ = s.db.View(func(txn *badger.Txn) error {
58 | it := txn.NewIterator(badger.DefaultIteratorOptions)
59 | defer it.Close()
60 | for it.Rewind(); it.Valid(); it.Next() {
61 | count++
62 | }
63 | return nil
64 | })
65 |
66 | return count
67 | }
68 |
// RecoveryState summarizes what was learned from the WAL while the store was
// being rebuilt at startup.
type RecoveryState struct {
	// Height is the protocol height to resume from; the next proposal should
	// use this height.
	Height uint64
}
74 |
75 | // New creates a new WAL-backed store and reconstructs state from WAL entries using BadgerDB.
76 | func New(wal *gowal.Wal, dbPath string) (*Store, *RecoveryState, error) {
77 | if wal == nil {
78 | return nil, nil, errors.New("wal is nil")
79 | }
80 | if dbPath == "" {
81 | return nil, nil, errors.New("db path is empty")
82 | }
83 |
84 | if err := os.MkdirAll(dbPath, 0o755); err != nil {
85 | return nil, nil, errors.Wrap(err, "create badger directory")
86 | }
87 |
88 | opts := badger.DefaultOptions(dbPath)
89 | db, err := badger.Open(opts)
90 | if err != nil {
91 | return nil, nil, errors.Wrap(err, "open badger db")
92 | }
93 |
94 | state := &Store{
95 | wal: wal,
96 | db: db,
97 | }
98 |
99 | recovery, err := state.recover()
100 | if err != nil {
101 | _ = db.Close()
102 | return nil, nil, err
103 | }
104 |
105 | return state, recovery, nil
106 | }
107 |
108 | // Put stores the provided value for the key.
109 | func (s *Store) Put(key string, value []byte) error {
110 | if key == "" {
111 | return errors.New("key cannot be empty")
112 | }
113 |
114 | s.mu.Lock()
115 | defer s.mu.Unlock()
116 |
117 | return s.db.Update(func(txn *badger.Txn) error {
118 | if value == nil {
119 | if err := txn.Delete([]byte(key)); err != nil && !stdErrors.Is(err, badger.ErrKeyNotFound) {
120 | return err
121 | }
122 | return nil
123 | }
124 | return txn.Set([]byte(key), cloneBytes(value))
125 | })
126 | }
127 |
128 | // Get retrieves value by key. Returns ErrNotFound if key does not exist.
129 | func (s *Store) Get(key string) ([]byte, error) {
130 | s.mu.RLock()
131 | defer s.mu.RUnlock()
132 |
133 | var result []byte
134 | err := s.db.View(func(txn *badger.Txn) error {
135 | item, err := txn.Get([]byte(key))
136 | if err != nil {
137 | if stdErrors.Is(err, badger.ErrKeyNotFound) {
138 | return ErrNotFound
139 | }
140 | return err
141 | }
142 | result, err = item.ValueCopy(nil)
143 | return err
144 | })
145 | if err != nil {
146 | return nil, err
147 | }
148 |
149 | return cloneBytes(result), nil
150 | }
151 |
152 | // Close closes the underlying Badger database.
153 | func (s *Store) Close() error {
154 | return s.db.Close()
155 | }
156 |
157 | func (s *Store) recover() (*RecoveryState, error) {
158 | var (
159 | hasProto bool
160 | maxProtoHeight uint64
161 | maxProtoStatus string // last seen status for maxProtoHeight
162 | )
163 |
164 | for msg := range s.wal.Iterator() {
165 | if msg.Key == "" || msg.Key == "skip" {
166 | continue
167 | }
168 |
169 | // check if it is a system/protocol key
170 | if strings.HasPrefix(msg.Key, "__tx:") {
171 | hasProto = true
172 | height := msg.Idx / walrecord.Stride
173 |
174 | // track max proto height and its status
175 | if height > maxProtoHeight {
176 | maxProtoHeight = height
177 | maxProtoStatus = msg.Key
178 | } else if height == maxProtoHeight {
179 | // if multiple messages for same height, commit/abort override prepared/precommit as "final" status
180 | if msg.Key == walrecord.KeyCommit || msg.Key == walrecord.KeyAbort {
181 | maxProtoStatus = msg.Key
182 | }
183 | }
184 |
185 | // apply strictly on commit
186 | if msg.Key == walrecord.KeyCommit {
187 | walTx, err := walrecord.Decode(msg.Value)
188 | if err != nil {
189 | return nil, errors.Wrapf(err, "decode wal tx at idx %d", msg.Idx)
190 | }
191 | if err := s.Put(walTx.Key, walTx.Value); err != nil {
192 | return nil, errors.Wrapf(err, "apply committed tx at idx %d", msg.Idx)
193 | }
194 | }
195 | }
196 | }
197 |
198 | var height uint64
199 | if hasProto {
200 | if maxProtoStatus == walrecord.KeyCommit || maxProtoStatus == walrecord.KeyAbort {
201 | height = maxProtoHeight + 1
202 | } else {
203 | // incomplete transaction at maxProtoHeight
204 | height = maxProtoHeight
205 | }
206 | }
207 |
208 | return &RecoveryState{Height: height}, nil
209 | }
210 |
// cloneBytes returns an independent copy of src. A nil input yields nil, while
// an empty but non-nil input yields an empty, non-nil slice.
func cloneBytes(src []byte) []byte {
	if src == nil {
		return nil
	}
	return append(make([]byte, 0, len(src)), src...)
}
219 | }
220 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | [Go Reference](https://pkg.go.dev/github.com/vadiminshakov/committer)
3 | [Go Report Card](https://goreportcard.com/report/github.com/vadiminshakov/committer)
4 | [Mentioned in Awesome Go](https://github.com/avelino/awesome-go)
5 |
6 |
7 |
8 |
9 |
10 | # **Committer**
11 |
12 | **Committer** is a Go implementation of the **Two-Phase Commit (2PC)** and **Three-Phase Commit (3PC)** protocols for distributed systems.
13 |
14 | ## **Architecture**
15 |
16 | The system consists of two types of nodes: **Coordinator** and **Cohorts**.
17 | The **Coordinator** is responsible for initiating and managing the commit protocols (2PC or 3PC), while the **Cohorts** participate in the protocol by responding to the coordinator's requests.
18 | The communication between nodes is handled using gRPC, and the state of each node is managed using a state machine.
19 |
20 | ## **Atomic Commit Protocols**
21 |
22 | ### **Two-Phase Commit (2PC)**
23 |
24 | The Two-Phase Commit protocol ensures atomicity in distributed transactions through two distinct phases:
25 |
26 | #### **Phase 1: Voting Phase (Propose)**
27 | 1. **Coordinator** sends a `PROPOSE` request to all cohorts with transaction data
28 | 2. Each **Cohort** validates the transaction locally and responds:
29 | - `ACK` (Yes) - if ready to commit
30 | - `NACK` (No) - if unable to commit
31 | 3. **Coordinator** waits for all responses
32 |
33 | #### **Phase 2: Commit Phase**
34 | 1. If **all cohorts** voted `ACK`:
35 | - **Coordinator** sends `COMMIT` to all cohorts
36 | - Each **Cohort** commits the transaction and responds with `ACK`
37 | 2. If **any cohort** voted `NACK`:
38 | - **Coordinator** sends `ABORT` to all cohorts
39 | - Each **Cohort** aborts the transaction
40 |
41 | ### **Three-Phase Commit (3PC)**
42 |
43 | The Three-Phase Commit protocol extends 2PC with an additional phase to reduce blocking scenarios:
44 |
45 | #### **Phase 1: Voting Phase (Propose)**
46 | 1. **Coordinator** sends `PROPOSE` request to all cohorts
47 | 2. **Cohorts** respond with `ACK`/`NACK` (same as 2PC)
48 |
49 | #### **Phase 2: Preparation Phase (Precommit)**
50 | 1. If all cohorts voted `ACK`:
51 | - **Coordinator** sends `PRECOMMIT` to all cohorts
52 | - **Cohorts** acknowledge they're prepared to commit
53 | - **Timeout mechanism**: If cohort doesn't receive `COMMIT` within timeout, it auto-commits
54 | 2. If any cohort voted `NACK`:
55 | - **Coordinator** sends `ABORT` to all cohorts
56 |
57 | #### **Phase 3: Commit Phase**
58 | 1. **Coordinator** sends `COMMIT` to all cohorts
59 | 2. **Cohorts** perform the actual commit operation
60 |
61 | ## **Configuration**
62 |
63 | All configuration parameters can be set using command-line flags:
64 |
65 | | **Flag** | **Description** | **Default** | **Example** |
66 | |-----------------|---------------------------------------------------------|---------------------|-------------------------------------|
67 | | `role` | Node role: `coordinator` or `cohort` | `cohort` | `-role=coordinator` |
68 | | `nodeaddr` | Address of the current node | `localhost:3050` | `-nodeaddr=localhost:3051` |
69 | | `coordinator` | Coordinator address (required for cohorts) | `""` | `-coordinator=localhost:3050` |
70 | | `committype` | Commit protocol: `two-phase` or `three-phase` | `three-phase` | `-committype=two-phase` |
71 | | `timeout` | Timeout (ms) for unacknowledged messages (3PC only) | `1000` | `-timeout=500` |
72 | | `dbpath` | Path to the BadgerDB database on the filesystem | `./badger` | `-dbpath=/tmp/badger` |
73 | | `cohorts` | Comma-separated list of cohort addresses | `""` | `-cohorts=localhost:3052,localhost:3053` |
74 | | `whitelist` | Comma-separated list of allowed hosts | `127.0.0.1` | `-whitelist=192.168.0.1,192.168.0.2`|
75 |
76 |
77 | ## **Usage**
78 |
79 | ### **Running as a Cohort**
80 | ```bash
81 | ./committer -role=cohort -nodeaddr=localhost:3001 -coordinator=localhost:3000 -committype=three-phase -timeout=1000 -dbpath=/tmp/badger/cohort
82 | ```
83 |
84 | ### **Running as a Coordinator**
85 | ```bash
86 | ./committer -role=coordinator -nodeaddr=localhost:3000 -cohorts=localhost:3001 -committype=three-phase -timeout=1000 -dbpath=/tmp/badger/coordinator
87 | ```
88 |
89 | ## **Hooks System**
90 |
91 | The hooks system allows you to add custom validation and business logic during the **Propose** and **Commit** stages without modifying the core code. Hooks are executed in the order they were registered, and if any hook returns `false`, the operation is rejected.
92 |
93 | ```go
94 | // Default usage (with built-in default hook)
95 | committer := commitalgo.NewCommitter(database, "three-phase", wal, timeout)
96 |
97 | // With custom hooks
98 | metricsHook := hooks.NewMetricsHook()
99 | validationHook := hooks.NewValidationHook(100, 1024)
100 | auditHook := hooks.NewAuditHook("audit.log")
101 |
102 | committer = commitalgo.NewCommitter(database, "three-phase", wal, timeout,
103 | metricsHook,
104 | validationHook,
105 | auditHook,
106 | )
107 |
108 | // Dynamic registration
109 | committer.RegisterHook(myCustomHook)
110 | ```
111 |
112 | ## **Testing**
113 |
114 | ### **Run Functional Tests**
115 | ```bash
116 | make tests
117 | ```
118 |
119 | ### **Testing with Example Client**
120 | 1. Compile executables:
121 |
122 | ```bash
123 | make prepare
124 | ```
125 |
126 | 2. Run the coordinator:
127 |
128 | ```bash
129 | make run-example-coordinator
130 | ```
131 |
132 | 3. Run a cohort in another terminal:
133 |
134 | ```bash
135 | make run-example-cohort
136 | ```
137 |
138 | 4. Start the example client:
139 |
140 | ```bash
141 | make run-example-client
142 | ```
143 |
144 | Or directly:
145 |
146 | ```bash
147 | go run ./examples/client/client.go
148 | ```
149 |
150 | ## **Contributions**
151 |
152 | Contributions are welcome! Feel free to submit a PR or open an issue if you find bugs or have suggestions for improvement.
153 |
154 | ## **License**
155 |
156 | This project is licensed under the [Apache License](LICENSE).
--------------------------------------------------------------------------------
/toxiproxy_test.go:
--------------------------------------------------------------------------------
1 | //go:build chaos
2 |
3 | package main
4 |
5 | import (
6 | "bytes"
7 | "encoding/json"
8 | "fmt"
9 | "io"
10 | "net/http"
11 | "strconv"
12 | "strings"
13 | "time"
14 | )
15 |
// toxiproxyClient is a small client for the Toxiproxy HTTP API. The chaos tests
// use it to create proxies and inject faults into them.
type toxiproxyClient struct {
	BaseURL    string
	HTTPClient *http.Client
}

// proxy is the JSON representation of a Toxiproxy proxy.
type proxy struct {
	Name     string `json:"name"`
	Listen   string `json:"listen"`
	Upstream string `json:"upstream"`
	Enabled  bool   `json:"enabled"`
}

// toxic is the JSON representation of a Toxiproxy toxic.
type toxic struct {
	Name       string         `json:"name"`
	Type       string         `json:"type"`
	Stream     string         `json:"stream"`
	Toxicity   float32        `json:"toxicity"`
	Attributes map[string]any `json:"attributes"`
}

// newToxiproxyClient returns a client for the Toxiproxy API rooted at baseURL.
// Every request is bounded by a 10-second timeout.
func newToxiproxyClient(baseURL string) *toxiproxyClient {
	return &toxiproxyClient{
		BaseURL:    baseURL,
		HTTPClient: &http.Client{Timeout: 10 * time.Second},
	}
}

// do sends a request to the Toxiproxy API. It returns the status code together
// with the fully read response body. A non-nil payload is sent as JSON.
func (c *toxiproxyClient) do(method, endpoint string, payload []byte) (int, []byte, error) {
	// Keep reqBody a nil interface (not a typed nil) when there is no payload.
	var reqBody io.Reader
	if payload != nil {
		reqBody = bytes.NewReader(payload)
	}

	req, err := http.NewRequest(method, endpoint, reqBody)
	if err != nil {
		return 0, nil, err
	}
	if payload != nil {
		req.Header.Set("Content-Type", "application/json")
	}

	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return 0, nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return resp.StatusCode, nil, fmt.Errorf("read response body: %w", err)
	}
	return resp.StatusCode, body, nil
}

// createProxy creates an enabled proxy named name that listens on listen and
// forwards to upstream. It expects 201 Created and returns the proxy as
// reported back by Toxiproxy.
func (c *toxiproxyClient) createProxy(name, listen, upstream string) (*proxy, error) {
	data, err := json.Marshal(&proxy{
		Name:     name,
		Listen:   listen,
		Upstream: upstream,
		Enabled:  true,
	})
	if err != nil {
		return nil, err
	}

	status, body, err := c.do(http.MethodPost, c.BaseURL+"/proxies", data)
	if err != nil {
		return nil, err
	}
	if status != http.StatusCreated {
		return nil, fmt.Errorf("failed to create proxy (status %d): %s", status, string(body))
	}

	var created proxy
	if err := json.Unmarshal(body, &created); err != nil {
		return nil, fmt.Errorf("failed to parse created proxy response: %v, body: %s", err, string(body))
	}
	return &created, nil
}

// addToxic attaches tox to the proxy named proxyName. Both 200 OK and
// 201 Created are accepted as success.
func (c *toxiproxyClient) addToxic(proxyName string, tox *toxic) error {
	data, err := json.Marshal(tox)
	if err != nil {
		return err
	}

	endpoint := fmt.Sprintf("%s/proxies/%s/toxics", c.BaseURL, proxyName)
	status, body, err := c.do(http.MethodPost, endpoint, data)
	if err != nil {
		return err
	}
	if status != http.StatusCreated && status != http.StatusOK {
		return fmt.Errorf("failed to add toxic to proxy '%s' (status %d): %s. Request: %s", proxyName, status, string(body), string(data))
	}
	return nil
}

// deleteProxy removes the proxy named name. Success is signalled by
// 204 No Content.
func (c *toxiproxyClient) deleteProxy(name string) error {
	endpoint := fmt.Sprintf("%s/proxies/%s", c.BaseURL, name)
	status, body, err := c.do(http.MethodDelete, endpoint, nil)
	if err != nil {
		return err
	}
	if status != http.StatusNoContent {
		return fmt.Errorf("failed to delete proxy '%s' (status %d): %s", name, status, string(body))
	}
	return nil
}

// deleteToxic removes the toxic toxicName from the proxy named proxyName.
// Success is signalled by 204 No Content.
func (c *toxiproxyClient) deleteToxic(proxyName, toxicName string) error {
	endpoint := fmt.Sprintf("%s/proxies/%s/toxics/%s", c.BaseURL, proxyName, toxicName)
	status, body, err := c.do(http.MethodDelete, endpoint, nil)
	if err != nil {
		return err
	}
	if status != http.StatusNoContent {
		return fmt.Errorf("failed to delete toxic '%s' from proxy '%s' (status %d): %s", toxicName, proxyName, status, string(body))
	}
	return nil
}

// resetProxy removes every toxic currently attached to the proxy named name.
//
// The Toxiproxy API routes DELETE only for a single toxic
// (/proxies/{proxy}/toxics/{toxic}), not for the whole toxics collection. So
// the toxics are listed first and then deleted one by one.
func (c *toxiproxyClient) resetProxy(name string) error {
	endpoint := fmt.Sprintf("%s/proxies/%s/toxics", c.BaseURL, name)
	status, body, err := c.do(http.MethodGet, endpoint, nil)
	if err != nil {
		return err
	}
	if status != http.StatusOK {
		return fmt.Errorf("failed to list toxics of proxy '%s' (status %d): %s", name, status, string(body))
	}

	var active []toxic
	if err := json.Unmarshal(body, &active); err != nil {
		return fmt.Errorf("failed to parse toxics of proxy '%s': %v, body: %s", name, err, string(body))
	}

	for _, tox := range active {
		if err := c.deleteToxic(name, tox.Name); err != nil {
			return err
		}
	}
	return nil
}
132 |
133 | type chaosTestHelper struct {
134 | client *toxiproxyClient
135 | proxies map[string]*proxy
136 | proxyMapping map[string]string // original -> proxy address
137 | }
138 |
139 | func newChaosTestHelper(toxiproxyURL string) *chaosTestHelper {
140 | return &chaosTestHelper{
141 | client: newToxiproxyClient(toxiproxyURL),
142 | proxies: make(map[string]*proxy),
143 | proxyMapping: make(map[string]string),
144 | }
145 | }
146 |
147 | func (h *chaosTestHelper) setupProxies(nodeAddresses []string) error {
148 | for i, addr := range nodeAddresses {
149 | proxyName := fmt.Sprintf("node_%d", i)
150 | proxyAddr := h.generateProxyAddress(addr)
151 |
152 | proxy, err := h.client.createProxy(proxyName, proxyAddr, addr)
153 | if err != nil {
154 | return fmt.Errorf("failed to create proxy for %s: %v", addr, err)
155 | }
156 |
157 | h.proxies[proxyName] = proxy
158 | h.proxyMapping[addr] = proxyAddr
159 | }
160 |
161 | return nil
162 | }
163 |
164 | // addResetPeer simulates TCP RESET after optional timeout
165 | func (h *chaosTestHelper) addResetPeer(nodeAddr string, timeout time.Duration) error {
166 | proxyName := h.getProxyName(nodeAddr)
167 | if proxyName == "" {
168 | return fmt.Errorf("proxy not found for address %s", nodeAddr)
169 | }
170 |
171 | toxic := &toxic{
172 | Name: fmt.Sprintf("reset_peer_%s", proxyName),
173 | Type: "reset_peer",
174 | Stream: "downstream",
175 | Toxicity: 1.0,
176 | Attributes: map[string]interface{}{
177 | "timeout": int(timeout.Milliseconds()),
178 | },
179 | }
180 |
181 | return h.client.addToxic(proxyName, toxic)
182 | }
183 |
184 | // addDataLimit closes connection after transmitting specified bytes
185 | func (h *chaosTestHelper) addDataLimit(nodeAddr string, bytes int) error {
186 | proxyName := h.getProxyName(nodeAddr)
187 | if proxyName == "" {
188 | return fmt.Errorf("proxy not found for address %s", nodeAddr)
189 | }
190 |
191 | toxic := &toxic{
192 | Name: fmt.Sprintf("limit_data_%s", proxyName),
193 | Type: "limit_data",
194 | Stream: "downstream",
195 | Toxicity: 1.0,
196 | Attributes: map[string]interface{}{
197 | "bytes": bytes,
198 | },
199 | }
200 |
201 | return h.client.addToxic(proxyName, toxic)
202 | }
203 |
204 | func (h *chaosTestHelper) getProxyAddress(originalAddr string) string {
205 | return h.proxyMapping[originalAddr]
206 | }
207 |
208 | func (h *chaosTestHelper) cleanup() error {
209 | for proxyName := range h.proxies {
210 | if err := h.client.deleteProxy(proxyName); err != nil {
211 | return err
212 | }
213 | }
214 |
215 | h.proxies = make(map[string]*proxy)
216 | h.proxyMapping = make(map[string]string)
217 | return nil
218 | }
219 |
220 | func (h *chaosTestHelper) generateProxyAddress(originalAddr string) string {
221 | parts := strings.Split(originalAddr, ":")
222 | if len(parts) != 2 {
223 | return originalAddr
224 | }
225 |
226 | port, err := strconv.Atoi(parts[1])
227 | if err != nil {
228 | return originalAddr
229 | }
230 |
231 | proxyPort := port + 10000
232 | return fmt.Sprintf("%s:%d", parts[0], proxyPort)
233 | }
234 |
235 | func (h *chaosTestHelper) getProxyName(nodeAddr string) string {
236 | for name, proxy := range h.proxies {
237 | if proxy.Upstream == nodeAddr {
238 | return name
239 | }
240 | }
241 | return ""
242 | }
243 |
--------------------------------------------------------------------------------
/core/cohort/cohort_test.go:
--------------------------------------------------------------------------------
1 | package cohort
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/require"
9 | "github.com/vadiminshakov/committer/core/dto"
10 | "github.com/vadiminshakov/committer/mocks"
11 | "go.uber.org/mock/gomock"
12 | )
13 |
14 | func TestNewCohort(t *testing.T) {
15 | ctrl := gomock.NewController(t)
16 | defer ctrl.Finish()
17 |
18 | mockCommitter := mocks.NewMockCommitter(ctrl)
19 |
20 | // test creating 2PC cohort
21 | cohort := NewCohort(mockCommitter, "two-phase")
22 | require.NotNil(t, cohort)
23 | require.Equal(t, Mode("two-phase"), cohort.commitType)
24 |
25 | // test creating 3PC cohort
26 | cohort3PC := NewCohort(mockCommitter, THREE_PHASE)
27 | require.NotNil(t, cohort3PC)
28 | require.Equal(t, THREE_PHASE, cohort3PC.commitType)
29 | }
30 |
31 | func TestCohort_Height(t *testing.T) {
32 | ctrl := gomock.NewController(t)
33 | defer ctrl.Finish()
34 |
35 | mockCommitter := mocks.NewMockCommitter(ctrl)
36 | cohort := NewCohort(mockCommitter, "two-phase")
37 |
38 | // expect Height to be called and return 5
39 | mockCommitter.EXPECT().Height().Return(uint64(5))
40 |
41 | height := cohort.Height()
42 | require.Equal(t, uint64(5), height)
43 | }
44 |
45 | func TestCohort_Propose(t *testing.T) {
46 | ctrl := gomock.NewController(t)
47 | defer ctrl.Finish()
48 |
49 | mockCommitter := mocks.NewMockCommitter(ctrl)
50 | cohort := NewCohort(mockCommitter, "two-phase")
51 |
52 | ctx := context.Background()
53 | proposeReq := &dto.ProposeRequest{
54 | Height: 0,
55 | Key: "test-key",
56 | Value: []byte("test-value"),
57 | }
58 |
59 | expectedResp := &dto.CohortResponse{
60 | ResponseType: dto.ResponseTypeAck,
61 | Height: 0,
62 | }
63 |
64 | // expect Propose to be called and return success
65 | mockCommitter.EXPECT().Propose(ctx, proposeReq).Return(expectedResp, nil)
66 |
67 | resp, err := cohort.Propose(ctx, proposeReq)
68 | require.NoError(t, err)
69 | require.NotNil(t, resp)
70 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType)
71 | require.Equal(t, uint64(0), resp.Height)
72 | }
73 |
74 | func TestCohort_Propose_Error(t *testing.T) {
75 | ctrl := gomock.NewController(t)
76 | defer ctrl.Finish()
77 |
78 | mockCommitter := mocks.NewMockCommitter(ctrl)
79 | cohort := NewCohort(mockCommitter, "two-phase")
80 |
81 | ctx := context.Background()
82 | proposeReq := &dto.ProposeRequest{
83 | Height: 0,
84 | Key: "test-key",
85 | Value: []byte("test-value"),
86 | }
87 |
88 | expectedErr := fmt.Errorf("propose failed")
89 |
90 | // expect Propose to be called and return error
91 | mockCommitter.EXPECT().Propose(ctx, proposeReq).Return(nil, expectedErr)
92 |
93 | resp, err := cohort.Propose(ctx, proposeReq)
94 | require.Error(t, err)
95 | require.Nil(t, resp)
96 | require.Equal(t, expectedErr, err)
97 | }
98 |
99 | func TestCohort_Precommit_TwoPhase(t *testing.T) {
100 | ctrl := gomock.NewController(t)
101 | defer ctrl.Finish()
102 |
103 | mockCommitter := mocks.NewMockCommitter(ctrl)
104 | cohort := NewCohort(mockCommitter, "two-phase")
105 |
106 | ctx := context.Background()
107 |
108 | // test precommit in 2PC mode (should fail)
109 | _, err := cohort.Precommit(ctx, 0)
110 | require.Error(t, err)
111 | require.Contains(t, err.Error(), "precommit is allowed for 3PC mode only")
112 | }
113 |
114 | func TestCohort_Precommit_ThreePhase(t *testing.T) {
115 | ctrl := gomock.NewController(t)
116 | defer ctrl.Finish()
117 |
118 | mockCommitter := mocks.NewMockCommitter(ctrl)
119 | cohort := NewCohort(mockCommitter, THREE_PHASE)
120 |
121 | ctx := context.Background()
122 |
123 | expectedResp := &dto.CohortResponse{
124 | ResponseType: dto.ResponseTypeAck,
125 | }
126 |
127 | // expect Precommit to be called and return success
128 | mockCommitter.EXPECT().Precommit(ctx, uint64(0)).Return(expectedResp, nil)
129 |
130 | // test precommit in 3PC mode (should succeed)
131 | resp, err := cohort.Precommit(ctx, 0)
132 | require.NoError(t, err)
133 | require.NotNil(t, resp)
134 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType)
135 | }
136 |
137 | func TestCohort_Precommit_ThreePhase_Error(t *testing.T) {
138 | ctrl := gomock.NewController(t)
139 | defer ctrl.Finish()
140 |
141 | mockCommitter := mocks.NewMockCommitter(ctrl)
142 | cohort := NewCohort(mockCommitter, THREE_PHASE)
143 |
144 | ctx := context.Background()
145 | expectedErr := fmt.Errorf("precommit failed")
146 |
147 | // expect Precommit to be called and return error
148 | mockCommitter.EXPECT().Precommit(ctx, uint64(0)).Return(nil, expectedErr)
149 |
150 | resp, err := cohort.Precommit(ctx, 0)
151 | require.Error(t, err)
152 | require.Nil(t, resp)
153 | require.Equal(t, expectedErr, err)
154 | }
155 |
156 | func TestCohort_Commit(t *testing.T) {
157 | ctrl := gomock.NewController(t)
158 | defer ctrl.Finish()
159 |
160 | mockCommitter := mocks.NewMockCommitter(ctrl)
161 | cohort := NewCohort(mockCommitter, "two-phase")
162 |
163 | ctx := context.Background()
164 | commitReq := &dto.CommitRequest{Height: 0}
165 |
166 | expectedResp := &dto.CohortResponse{
167 | ResponseType: dto.ResponseTypeAck,
168 | }
169 |
170 | // expect Commit to be called and return success
171 | mockCommitter.EXPECT().Commit(ctx, commitReq).Return(expectedResp, nil)
172 |
173 | resp, err := cohort.Commit(ctx, commitReq)
174 | require.NoError(t, err)
175 | require.NotNil(t, resp)
176 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType)
177 | }
178 |
179 | func TestCohort_Commit_Error(t *testing.T) {
180 | ctrl := gomock.NewController(t)
181 | defer ctrl.Finish()
182 |
183 | mockCommitter := mocks.NewMockCommitter(ctrl)
184 | cohort := NewCohort(mockCommitter, "two-phase")
185 |
186 | ctx := context.Background()
187 | commitReq := &dto.CommitRequest{Height: 0}
188 | expectedErr := fmt.Errorf("commit failed")
189 |
190 | // expect Commit to be called and return error
191 | mockCommitter.EXPECT().Commit(ctx, commitReq).Return(nil, expectedErr)
192 |
193 | resp, err := cohort.Commit(ctx, commitReq)
194 | require.Error(t, err)
195 | require.Nil(t, resp)
196 | require.Equal(t, expectedErr, err)
197 | }
198 |
199 | func TestCohort_Commit_Nack(t *testing.T) {
200 | ctrl := gomock.NewController(t)
201 | defer ctrl.Finish()
202 |
203 | mockCommitter := mocks.NewMockCommitter(ctrl)
204 | cohort := NewCohort(mockCommitter, "two-phase")
205 |
206 | ctx := context.Background()
207 | commitReq := &dto.CommitRequest{Height: 0}
208 |
209 | expectedResp := &dto.CohortResponse{
210 | ResponseType: dto.ResponseTypeNack,
211 | }
212 |
213 | // expect Commit to be called and return NACK
214 | mockCommitter.EXPECT().Commit(ctx, commitReq).Return(expectedResp, nil)
215 |
216 | resp, err := cohort.Commit(ctx, commitReq)
217 | require.NoError(t, err)
218 | require.NotNil(t, resp)
219 | require.Equal(t, dto.ResponseTypeNack, resp.ResponseType)
220 | }
221 |
222 | func TestCohort_ModeValidation(t *testing.T) {
223 | ctrl := gomock.NewController(t)
224 | defer ctrl.Finish()
225 |
226 | mockCommitter := mocks.NewMockCommitter(ctrl)
227 |
228 | // test different modes
229 | modes := []Mode{"two-phase", THREE_PHASE, "custom-mode"}
230 |
231 | for _, mode := range modes {
232 | cohort := NewCohort(mockCommitter, mode)
233 | require.NotNil(t, cohort)
234 | require.Equal(t, mode, cohort.commitType)
235 | }
236 | }
237 |
--------------------------------------------------------------------------------
/core/coordinator/coordinator_test.go:
--------------------------------------------------------------------------------
1 | package coordinator
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "sync/atomic"
7 | "testing"
8 | "time"
9 |
10 | "github.com/stretchr/testify/require"
11 | "github.com/vadiminshakov/committer/config"
12 | "github.com/vadiminshakov/committer/core/dto"
13 | "github.com/vadiminshakov/committer/core/walrecord"
14 | "github.com/vadiminshakov/committer/io/gateway/grpc/server"
15 | "github.com/vadiminshakov/committer/mocks"
16 | "go.uber.org/mock/gomock"
17 | )
18 |
19 | func TestCoordinator_New(t *testing.T) {
20 | ctrl := gomock.NewController(t)
21 | defer ctrl.Finish()
22 |
23 | mockWAL := mocks.NewMockwal(ctrl)
24 | mockStore := mocks.NewMockStateStore(ctrl)
25 |
26 | conf := &config.Config{
27 | Nodeaddr: "localhost:8080",
28 | Role: "coordinator",
29 | Cohorts: []string{"localhost:8081", "localhost:8082"},
30 | CommitType: server.TWO_PHASE,
31 | }
32 |
33 | coord, err := New(conf, mockWAL, mockStore)
34 | require.NoError(t, err)
35 | require.NotNil(t, coord)
36 | require.Equal(t, uint64(0), coord.Height())
37 | require.Len(t, coord.cohorts, 2)
38 | }
39 |
40 | func TestCoordinator_Height(t *testing.T) {
41 | ctrl := gomock.NewController(t)
42 | defer ctrl.Finish()
43 |
44 | mockWAL := mocks.NewMockwal(ctrl)
45 | mockStore := mocks.NewMockStateStore(ctrl)
46 |
47 | conf := &config.Config{
48 | Nodeaddr: "localhost:8080",
49 | Role: "coordinator",
50 | Cohorts: []string{},
51 | CommitType: server.TWO_PHASE,
52 | }
53 |
54 | coord, err := New(conf, mockWAL, mockStore)
55 | require.NoError(t, err)
56 |
57 | // test initial height
58 | require.Equal(t, uint64(0), coord.Height())
59 |
60 | // test height increment
61 | atomic.AddUint64(&coord.height, 1)
62 | require.Equal(t, uint64(1), coord.Height())
63 | }
64 |
65 | func TestCoordinator_SyncHeight(t *testing.T) {
66 | ctrl := gomock.NewController(t)
67 | defer ctrl.Finish()
68 |
69 | mockWAL := mocks.NewMockwal(ctrl)
70 | mockStore := mocks.NewMockStateStore(ctrl)
71 |
72 | conf := &config.Config{
73 | Nodeaddr: "localhost:8080",
74 | Role: "coordinator",
75 | Cohorts: []string{},
76 | CommitType: server.TWO_PHASE,
77 | }
78 |
79 | coord, err := New(conf, mockWAL, mockStore)
80 | require.NoError(t, err)
81 |
82 | // test sync to higher height
83 | coord.syncHeight(5)
84 | require.Equal(t, uint64(5), coord.Height())
85 |
86 | // test sync to lower height (should not change)
87 | coord.syncHeight(3)
88 | require.Equal(t, uint64(5), coord.Height())
89 |
90 | // test sync to same height (should not change)
91 | coord.syncHeight(5)
92 | require.Equal(t, uint64(5), coord.Height())
93 | }
94 |
95 | func TestCoordinator_PersistMessage(t *testing.T) {
96 | ctrl := gomock.NewController(t)
97 | defer ctrl.Finish()
98 |
99 | mockWAL := mocks.NewMockwal(ctrl)
100 | mockStore := mocks.NewMockStateStore(ctrl)
101 |
102 | conf := &config.Config{
103 | Nodeaddr: "localhost:8080",
104 | Role: "coordinator",
105 | Cohorts: []string{},
106 | CommitType: server.TWO_PHASE,
107 | }
108 |
109 | coord, err := New(conf, mockWAL, mockStore)
110 | require.NoError(t, err)
111 |
112 | // test persist message success
113 | testKey := "test-key"
114 | testValue := []byte("test-value")
115 |
116 | // expect WAL.Get to return the test data
117 | // expect WAL.Get to return the test data
118 | ptx := walrecord.WalTx{Key: testKey, Value: testValue}
119 | encoded, _ := walrecord.Encode(ptx)
120 | mockWAL.EXPECT().Get(walrecord.PreparedSlot(0)).Return(walrecord.KeyPrepared, encoded, nil)
121 |
122 | // expect DB.Put to be called with the test data
123 | mockStore.EXPECT().Put(testKey, testValue).Return(nil)
124 |
125 | err = coord.persistMessage()
126 | require.NoError(t, err)
127 | }
128 |
129 | func TestCoordinator_PersistMessage_NoDataInWAL(t *testing.T) {
130 | ctrl := gomock.NewController(t)
131 | defer ctrl.Finish()
132 |
133 | mockWAL := mocks.NewMockwal(ctrl)
134 | mockStore := mocks.NewMockStateStore(ctrl)
135 |
136 | conf := &config.Config{
137 | Nodeaddr: "localhost:8080",
138 | Role: "coordinator",
139 | Cohorts: []string{},
140 | CommitType: server.TWO_PHASE,
141 | }
142 |
143 | coord, err := New(conf, mockWAL, mockStore)
144 | require.NoError(t, err)
145 |
146 | // expect WAL.Get to return no data
147 | mockWAL.EXPECT().Get(walrecord.PreparedSlot(0)).Return("", nil, nil)
148 |
149 | // test persist message when no data in WAL
150 | err = coord.persistMessage()
151 | require.Error(t, err)
152 | require.Contains(t, err.Error(), "can't find msg in wal")
153 | }
154 |
155 | func TestCoordinator_Abort(t *testing.T) {
156 | ctrl := gomock.NewController(t)
157 | defer ctrl.Finish()
158 |
159 | mockWAL := mocks.NewMockwal(ctrl)
160 | mockStore := mocks.NewMockStateStore(ctrl)
161 |
162 | conf := &config.Config{
163 | Nodeaddr: "localhost:8080",
164 | Role: "coordinator",
165 | Cohorts: []string{},
166 | CommitType: server.TWO_PHASE,
167 | }
168 |
169 | coord, err := New(conf, mockWAL, mockStore)
170 | require.NoError(t, err)
171 |
172 | ctx := context.Background()
173 |
174 | // expect WAL abort write
175 | mockWAL.EXPECT().Write(walrecord.AbortSlot(0), walrecord.KeyAbort, gomock.Any()).Return(nil)
176 |
177 | // test abort
178 | done := make(chan bool, 1)
179 | go func() {
180 | coord.abort(ctx, "test abort reason")
181 | done <- true
182 | }()
183 |
184 | select {
185 | case <-done:
186 | // abort completed successfully
187 | case <-time.After(1 * time.Second):
188 | t.Fatal("abort took too long to complete")
189 | }
190 | }
191 |
192 | func TestCoordinator_SyncHeight_Concurrent(t *testing.T) {
193 | ctrl := gomock.NewController(t)
194 | defer ctrl.Finish()
195 |
196 | mockWAL := mocks.NewMockwal(ctrl)
197 | mockStore := mocks.NewMockStateStore(ctrl)
198 |
199 | conf := &config.Config{
200 | Nodeaddr: "localhost:8080",
201 | Role: "coordinator",
202 | Cohorts: []string{},
203 | CommitType: server.TWO_PHASE,
204 | }
205 |
206 | coord, err := New(conf, mockWAL, mockStore)
207 | require.NoError(t, err)
208 |
209 | // test concurrent height sync
210 | done := make(chan bool, 10)
211 |
212 | for i := 0; i < 10; i++ {
213 | go func(height uint64) {
214 | coord.syncHeight(height)
215 | done <- true
216 | }(uint64(i + 1))
217 | }
218 |
219 | // wait for all goroutines to complete
220 | for i := 0; i < 10; i++ {
221 | <-done
222 | }
223 |
224 | // height should be the maximum value
225 | require.Equal(t, uint64(10), coord.Height())
226 | }
227 |
228 | func TestNackResponse(t *testing.T) {
229 | testErr := fmt.Errorf("test error")
230 | testMsg := "test message"
231 |
232 | resp, err := nackResponse(testErr, testMsg)
233 |
234 | require.NotNil(t, resp)
235 | require.Equal(t, dto.ResponseTypeNack, resp.Type)
236 | require.Error(t, err)
237 | require.Contains(t, err.Error(), testMsg)
238 | require.Contains(t, err.Error(), "test error")
239 | }
240 |
241 | func TestCoordinator_Config_Validation(t *testing.T) {
242 | ctrl := gomock.NewController(t)
243 | defer ctrl.Finish()
244 |
245 | mockWAL := mocks.NewMockwal(ctrl)
246 | mockStore := mocks.NewMockStateStore(ctrl)
247 |
248 | // test with empty cohorts list
249 | conf := &config.Config{
250 | Nodeaddr: "localhost:8080",
251 | Role: "coordinator",
252 | Cohorts: []string{},
253 | CommitType: server.TWO_PHASE,
254 | }
255 |
256 | coord, err := New(conf, mockWAL, mockStore)
257 | require.NoError(t, err)
258 | require.NotNil(t, coord)
259 | require.Len(t, coord.cohorts, 0)
260 |
261 | // test with multiple cohorts
262 | conf.Cohorts = []string{"localhost:8081", "localhost:8082", "localhost:8083"}
263 | coord, err = New(conf, mockWAL, mockStore)
264 | require.NoError(t, err)
265 | require.NotNil(t, coord)
266 | require.Len(t, coord.cohorts, 3)
267 | }
268 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
2 | github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
5 | github.com/dgraph-io/badger/v4 v4.8.0 h1:JYph1ChBijCw8SLeybvPINizbDKWZ5n/GYbz2yhN/bs=
6 | github.com/dgraph-io/badger/v4 v4.8.0/go.mod h1:U6on6e8k/RTbUWxqKR0MvugJuVmkxSNc79ap4917h4w=
7 | github.com/dgraph-io/ristretto/v2 v2.2.0 h1:bkY3XzJcXoMuELV8F+vS8kzNgicwQFAaGINAEJdWGOM=
8 | github.com/dgraph-io/ristretto/v2 v2.2.0/go.mod h1:RZrm63UmcBAaYWC1DotLYBmTvgkrs0+XhBd7Npn7/zI=
9 | github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38=
10 | github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
11 | github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
12 | github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
13 | github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
14 | github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
15 | github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
16 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
17 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
18 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
19 | github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
20 | github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
21 | github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q=
22 | github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
23 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
24 | github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
25 | github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
26 | github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
27 | github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
28 | github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk=
29 | github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
30 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
31 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
32 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
33 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
34 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
35 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
36 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
37 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
38 | github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
39 | github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
40 | github.com/sirupsen/logrus v1.2.0 h1:juTguoYk5qI21pwyTXY3B3Y5cOTH3ZUyZCg1v/mihuo=
41 | github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
42 | github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
43 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
44 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
45 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
46 | github.com/vadiminshakov/gowal v0.0.4 h1:99mJRkHL7GFNo7f6TMXkUEkRdeVeNIpmywkbVnPaPMM=
47 | github.com/vadiminshakov/gowal v0.0.4/go.mod h1:NMvNH0xjRPYSrcaIg/JhqB5uTneFdgo45+Fs2sk+Esc=
48 | github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
49 | github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
50 | github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
51 | github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
52 | go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
53 | go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
54 | go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
55 | go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
56 | go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
57 | go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
58 | go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
59 | go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
60 | go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
61 | go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
62 | golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
63 | golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
64 | golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
65 | golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
66 | golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
67 | golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
68 | golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
69 | golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
70 | golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
71 | golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
72 | golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
73 | golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
74 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
75 | google.golang.org/genproto v0.0.0-20221010155953-15ba04fc1c0e h1:halCgTFuLWDRD61piiNSxPsARANGD3Xl16hPrLgLiIg=
76 | google.golang.org/genproto v0.0.0-20221010155953-15ba04fc1c0e/go.mod h1:3526vdqwhZAwq4wsRUaVG555sVgsNmIjRtO7t/JH29U=
77 | google.golang.org/grpc v1.50.0 h1:fPVVDxY9w++VjTZsYvXWqEf9Rqar/e+9zYfxKK+W+YU=
78 | google.golang.org/grpc v1.50.0/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI=
79 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
80 | google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
81 | google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
82 | google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
83 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
84 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
85 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
86 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
87 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
88 |
--------------------------------------------------------------------------------
/io/gateway/grpc/server/server.go:
--------------------------------------------------------------------------------
1 | // Package server provides gRPC server implementation for the committer service.
2 | //
3 | // This package implements both internal commit API for node-to-node communication
4 | // and client API for external interactions with the distributed consensus system.
5 | package server
6 |
7 | import (
8 | "context"
9 | "errors"
10 | "net"
11 | "time"
12 |
13 | log "github.com/sirupsen/logrus"
14 | "github.com/vadiminshakov/committer/config"
15 | "github.com/vadiminshakov/committer/core/dto"
16 | "github.com/vadiminshakov/committer/io/gateway/grpc/proto"
17 | "github.com/vadiminshakov/committer/io/store"
18 | "google.golang.org/grpc"
19 | "google.golang.org/grpc/codes"
20 | "google.golang.org/grpc/status"
21 | "google.golang.org/protobuf/types/known/emptypb"
22 | )
23 |
// TWO_PHASE is the commit-type value that selects the two-phase commit protocol.
const TWO_PHASE = "two-phase"

// THREE_PHASE is the commit-type value that selects the three-phase commit protocol.
const THREE_PHASE = "three-phase"
30 |
31 | // Coordinator defines the interface for coordinator operations.
32 | type Coordinator interface {
33 | Broadcast(ctx context.Context, req dto.BroadcastRequest) (*dto.BroadcastResponse, error)
34 | Height() uint64
35 | SetHeight(height uint64)
36 | }
37 |
38 | // Cohort defines the interface for cohort operations.
39 | //
40 | //go:generate mockgen -destination=../../../../mocks/mock_cohort.go -package=mocks . Cohort
41 | type Cohort interface {
42 | Propose(ctx context.Context, req *dto.ProposeRequest) (*dto.CohortResponse, error)
43 | Precommit(ctx context.Context, index uint64) (*dto.CohortResponse, error)
44 | Commit(ctx context.Context, in *dto.CommitRequest) (*dto.CohortResponse, error)
45 | Abort(ctx context.Context, req *dto.AbortRequest) (*dto.CohortResponse, error)
46 | Height() uint64
47 | }
48 |
// Server is the gRPC front end of a committer node. It holds the role
// implementations (cohort and/or coordinator), the state store, and the node
// configuration. Run always registers the client API on it, and registers the
// internal commit API only when a cohort implementation is set.
type Server struct {
	// Embedded generated stubs; presumably they supply default implementations
	// for any RPC of these services that Server does not define itself.
	proto.UnimplementedInternalCommitAPIServer
	proto.UnimplementedClientAPIServer

	cohort      Cohort                                // cohort implementation; nil when the node does not act as a cohort
	store       *store.Store                          // state store read by Get and closed by Stop
	coordinator Coordinator                           // coordinator implementation; nil when the node does not act as a coordinator
	GRPCServer  *grpc.Server                          // created by Run; nil until Run has been called
	Config      *config.Config                        // node configuration passed to New
	ProposeHook func(req *proto.ProposeRequest) bool // hook for the propose phase; not used in this file — TODO confirm meaning of the bool result
	CommitHook  func(req *proto.CommitRequest) bool  // hook for the commit phase; not used in this file — TODO confirm meaning of the bool result
	Addr        string                                // TCP address Run listens on (set from conf.Nodeaddr)
}
63 |
64 | func (s *Server) Propose(ctx context.Context, req *proto.ProposeRequest) (*proto.Response, error) {
65 | if s.cohort == nil {
66 | return nil, status.Error(codes.FailedPrecondition, "cohort role not enabled on this node")
67 | }
68 | resp, err := s.cohort.Propose(ctx, proposeRequestPbToEntity(req))
69 | return cohortResponseToProto(resp), err
70 | }
71 |
72 | func (s *Server) Precommit(ctx context.Context, req *proto.PrecommitRequest) (*proto.Response, error) {
73 | if s.cohort == nil {
74 | return nil, status.Error(codes.FailedPrecondition, "cohort role not enabled on this node")
75 | }
76 | resp, err := s.cohort.Precommit(ctx, req.Index)
77 | return cohortResponseToProto(resp), err
78 | }
79 |
80 | func (s *Server) Commit(ctx context.Context, req *proto.CommitRequest) (*proto.Response, error) {
81 | if s.cohort == nil {
82 | return nil, status.Error(codes.FailedPrecondition, "cohort role not enabled on this node")
83 | }
84 | resp, err := s.cohort.Commit(ctx, commitRequestPbToEntity(req))
85 | return cohortResponseToProto(resp), err
86 | }
87 |
88 | func (s *Server) Abort(ctx context.Context, req *proto.AbortRequest) (*proto.Response, error) {
89 | if s.cohort == nil {
90 | return nil, status.Error(codes.FailedPrecondition, "cohort role not enabled on this node")
91 | }
92 | abortReq := &dto.AbortRequest{
93 | Height: req.Height,
94 | Reason: req.Reason,
95 | }
96 | resp, err := s.cohort.Abort(ctx, abortReq)
97 | return cohortResponseToProto(resp), err
98 | }
99 |
100 | func (s *Server) Get(ctx context.Context, req *proto.Msg) (*proto.Value, error) {
101 | value, err := s.store.Get(req.Key)
102 | if err != nil {
103 | return nil, err
104 | }
105 | return &proto.Value{Value: value}, nil
106 | }
107 |
108 | // Put initiates a distributed transaction to store a key-value pair.
109 | func (s *Server) Put(ctx context.Context, req *proto.Entry) (*proto.Response, error) {
110 | if s.coordinator == nil {
111 | return nil, status.Error(codes.FailedPrecondition, "coordinator role not enabled on this node")
112 | }
113 | resp, err := s.coordinator.Broadcast(ctx, dto.BroadcastRequest{
114 | Key: req.Key,
115 | Value: req.Value,
116 | })
117 | if err != nil {
118 | return nil, err
119 | }
120 |
121 | return &proto.Response{
122 | Type: proto.Type(resp.Type),
123 | Index: resp.Height,
124 | }, nil
125 | }
126 |
127 | // NodeInfo returns information about the current node.
128 | func (s *Server) NodeInfo(ctx context.Context, req *emptypb.Empty) (*proto.Info, error) {
129 | switch {
130 | case s.cohort != nil:
131 | return &proto.Info{Height: s.cohort.Height()}, nil
132 | case s.coordinator != nil:
133 | return &proto.Info{Height: s.coordinator.Height()}, nil
134 | default:
135 | return nil, status.Error(codes.FailedPrecondition, "node has neither cohort nor coordinator role configured")
136 | }
137 | }
138 |
139 | // New creates a new Server instance with the specified configuration.
140 | func New(conf *config.Config, cohort Cohort, coordinator Coordinator, stateStore *store.Store) (*Server, error) {
141 | log.SetFormatter(&log.TextFormatter{
142 | ForceColors: true, // Seems like automatic color detection doesn't work on windows terminals
143 | FullTimestamp: true,
144 | TimestampFormat: time.RFC822,
145 | })
146 |
147 | server := &Server{
148 | Addr: conf.Nodeaddr,
149 | cohort: cohort,
150 | coordinator: coordinator,
151 | store: stateStore,
152 | Config: conf,
153 | }
154 |
155 | if server.Config.CommitType == TWO_PHASE {
156 | log.Info("two-phase-commit mode enabled")
157 | } else {
158 | log.Info("three-phase-commit mode enabled")
159 | }
160 | err := checkServerFields(server)
161 | return server, err
162 | }
163 |
164 | func checkServerFields(server *Server) error {
165 | if server.store == nil {
166 | return errors.New("store is not configured")
167 | }
168 | if server.Config.Role == "cohort" && server.cohort == nil {
169 | return errors.New("cohort role selected but cohort implementation is nil")
170 | }
171 | if server.Config.Role == "coordinator" && server.coordinator == nil {
172 | return errors.New("coordinator role selected but coordinator implementation is nil")
173 | }
174 | return nil
175 | }
176 |
177 | // Run starts the gRPC server in a non-blocking manner.
178 | func (s *Server) Run(opts ...grpc.UnaryServerInterceptor) {
179 | var err error
180 | s.GRPCServer = grpc.NewServer(grpc.ChainUnaryInterceptor(opts...))
181 | if s.cohort != nil {
182 | proto.RegisterInternalCommitAPIServer(s.GRPCServer, s)
183 | }
184 | proto.RegisterClientAPIServer(s.GRPCServer, s)
185 |
186 | l, err := net.Listen("tcp", s.Addr)
187 | if err != nil {
188 | log.Fatalf("failed to listen: %v", err)
189 | }
190 | log.Infof("listening on tcp://%s", s.Addr)
191 |
192 | go s.GRPCServer.Serve(l)
193 | }
194 |
195 | // Stop gracefully stops the gRPC server.
196 | func (s *Server) Stop() {
197 | log.Info("stopping server")
198 | s.GRPCServer.GracefulStop()
199 | if s.store != nil {
200 | if err := s.store.Close(); err != nil {
201 | log.Infof("failed to close store: %s\n", err)
202 | }
203 | }
204 | log.Info("server stopped")
205 | }
206 |
--------------------------------------------------------------------------------
/main_test.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "os"
7 | "path/filepath"
8 | "strconv"
9 | "testing"
10 | "time"
11 |
12 | log "github.com/sirupsen/logrus"
13 | "github.com/stretchr/testify/require"
14 | "github.com/vadiminshakov/committer/config"
15 | "github.com/vadiminshakov/committer/core/cohort"
16 | "github.com/vadiminshakov/committer/core/cohort/commitalgo"
17 | "github.com/vadiminshakov/committer/core/coordinator"
18 | "github.com/vadiminshakov/committer/io/gateway/grpc/client"
19 | pb "github.com/vadiminshakov/committer/io/gateway/grpc/proto"
20 | "github.com/vadiminshakov/committer/io/gateway/grpc/server"
21 | "github.com/vadiminshakov/committer/io/store"
22 | "github.com/vadiminshakov/gowal"
23 | )
24 |
// Role keys of the nodes table.
const (
	COORDINATOR_TYPE = "coordinator"
	COHORT_TYPE      = "cohort"
)

// BADGER_DIR is the path prefix for the per-run badger directories created by startnodes.
const BADGER_DIR = "/tmp/badger"
30 |
31 | var (
32 | whitelist = []string{"127.0.0.1"}
33 | nodes = map[string][]*config.Config{
34 | COORDINATOR_TYPE: {
35 | {Nodeaddr: "localhost:2938", Role: "coordinator",
36 | Cohorts: []string{"localhost:2345", "localhost:2384", "localhost:7532", "localhost:5743", "localhost:4991"},
37 | Whitelist: whitelist, CommitType: "two-phase", Timeout: 100},
38 | {Nodeaddr: "localhost:5002", Role: "coordinator",
39 | Cohorts: []string{"localhost:2345", "localhost:2384", "localhost:7532", "localhost:5743", "localhost:4991"},
40 | Whitelist: whitelist, CommitType: "three-phase", Timeout: 100},
41 | },
42 | COHORT_TYPE: {
43 | &config.Config{Nodeaddr: "localhost:2345", Role: "cohort", Coordinator: "localhost:2938", Whitelist: whitelist, Timeout: 800, CommitType: "three-phase"},
44 | &config.Config{Nodeaddr: "localhost:2384", Role: "cohort", Coordinator: "localhost:2938", Whitelist: whitelist, Timeout: 800, CommitType: "three-phase"},
45 | &config.Config{Nodeaddr: "localhost:7532", Role: "cohort", Coordinator: "localhost:2938", Whitelist: whitelist, Timeout: 800, CommitType: "three-phase"},
46 | &config.Config{Nodeaddr: "localhost:5743", Role: "cohort", Coordinator: "localhost:2938", Whitelist: whitelist, Timeout: 800, CommitType: "three-phase"},
47 | &config.Config{Nodeaddr: "localhost:4991", Role: "cohort", Coordinator: "localhost:2938", Whitelist: whitelist, Timeout: 800, CommitType: "three-phase"},
48 | },
49 | }
50 | )
51 |
// testtable holds the key/value pairs TestHappyPath writes through the
// coordinator and then expects every cohort to serve.
var testtable = map[string][]byte{
	"key3": []byte("value3"),
	"key2": []byte("value2"),
	"key1": []byte("value1"),
}
57 |
58 | func TestHappyPath(t *testing.T) {
59 | log.SetLevel(log.InfoLevel)
60 |
61 | var canceller func() error
62 |
63 | var height uint64 = 0
64 | coordConfig := nodes[COORDINATOR_TYPE][0]
65 | if coordConfig.CommitType == "two-phase" {
66 | canceller = startnodes(pb.CommitType_TWO_PHASE_COMMIT)
67 | log.Println("***\nTEST IN TWO-PHASE MODE\n***")
68 | } else {
69 | canceller = startnodes(pb.CommitType_THREE_PHASE_COMMIT)
70 | log.Println("***\nTEST IN THREE-PHASE MODE\n***")
71 | }
72 |
73 | defer canceller()
74 |
75 | c, err := client.NewClientAPI(coordConfig.Nodeaddr)
76 | if err != nil {
77 | t.Error(err)
78 | }
79 |
80 | for key, val := range testtable {
81 | resp, err := c.Put(context.Background(), key, val)
82 | if err != nil {
83 | t.Error(err)
84 | }
85 |
86 | if resp.Type != pb.Type_ACK {
87 | t.Error("msg is not acknowledged")
88 | }
89 | // ok, value is added, let's increment height counter
90 | height++
91 | }
92 |
93 | // wait for rollback on cohorts
94 | time.Sleep(1 * time.Second)
95 |
96 | // connect to cohorts and check that them added key-value
97 | for _, node := range nodes[COHORT_TYPE] {
98 | cli, err := client.NewClientAPI(node.Nodeaddr)
99 | require.NoError(t, err, "err not nil")
100 |
101 | for key, val := range testtable {
102 | // check values added by nodes
103 | resp, err := cli.Get(context.Background(), key)
104 | require.NoError(t, err, "err not nil")
105 | require.Equal(t, resp.Value, val)
106 |
107 | // check height of node
108 | nodeInfo, err := cli.NodeInfo(context.Background())
109 | require.NoError(t, err, "err not nil")
110 | require.Equal(t, height, nodeInfo.Height, "node %s ahead, %d commits behind (current height is %d)", node.Nodeaddr, nodeInfo.Height-height, nodeInfo.Height)
111 | }
112 | }
113 |
114 | require.NoError(t, canceller())
115 | }
116 |
117 | func startnodes(commitType pb.CommitType) func() error {
118 | COORDINATOR_BADGER := fmt.Sprintf("%s%s%d", BADGER_DIR, "coordinator", time.Now().UnixNano())
119 | COHORT_BADGER := fmt.Sprintf("%s%s%d", BADGER_DIR, "cohort", time.Now().UnixNano())
120 |
121 | // check dir exists
122 | if _, err := os.Stat(COORDINATOR_BADGER); !os.IsNotExist(err) {
123 | // del dir
124 | err := os.RemoveAll(COORDINATOR_BADGER)
125 | failfast(err)
126 | }
127 | if _, err := os.Stat(COHORT_BADGER); !os.IsNotExist(err) {
128 | // del dir
129 | failfast(os.RemoveAll(COHORT_BADGER))
130 | }
131 | if _, err := os.Stat("./tmp"); !os.IsNotExist(err) {
132 | // del dir
133 | failfast(os.RemoveAll("./tmp"))
134 |
135 | }
136 |
137 | {
138 | // nolint:errcheck
139 | failfast(os.Mkdir(COORDINATOR_BADGER, os.FileMode(0777)))
140 | failfast(os.Mkdir(COHORT_BADGER, os.FileMode(0777)))
141 | failfast(os.Mkdir("./tmp", os.FileMode(0777)))
142 | failfast(os.Mkdir("./tmp/cohort", os.FileMode(0777)))
143 | failfast(os.Mkdir("./tmp/coord", os.FileMode(0777)))
144 | }
145 |
146 | // start cohorts
147 | stopfuncs := make([]func(), 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
148 | for i, node := range nodes[COHORT_TYPE] {
149 | if commitType == pb.CommitType_THREE_PHASE_COMMIT {
150 | node.Coordinator = nodes[COORDINATOR_TYPE][1].Nodeaddr
151 | }
152 | node.DBPath = filepath.Join(COHORT_BADGER, strconv.Itoa(i))
153 |
154 | failfast(os.MkdirAll(node.DBPath, os.FileMode(0o777)))
155 | walConfig := gowal.Config{
156 | Dir: "./tmp/cohort/" + strconv.Itoa(i),
157 | Prefix: "msgs_",
158 | SegmentThreshold: 100,
159 | MaxSegments: 100,
160 | IsInSyncDiskMode: false,
161 | }
162 | w, err := gowal.NewWAL(walConfig)
163 | failfast(err)
164 |
165 | stateStore, recovery, err := store.New(w, node.DBPath)
166 | failfast(err)
167 |
168 | ct := server.TWO_PHASE
169 | if commitType == pb.CommitType_THREE_PHASE_COMMIT {
170 | ct = server.THREE_PHASE
171 | }
172 |
173 | committer := commitalgo.NewCommitter(stateStore, ct, w, node.Timeout)
174 | committer.SetHeight(recovery.Height)
175 | cohortImpl := cohort.NewCohort(committer, cohort.Mode(node.CommitType))
176 |
177 | cohortServer, err := server.New(node, cohortImpl, nil, stateStore)
178 | failfast(err)
179 |
180 | go cohortServer.Run(server.WhiteListChecker)
181 |
182 | stopfuncs = append(stopfuncs, cohortServer.Stop)
183 | }
184 |
185 | // start coordinators (in two- and three-phase modes)
186 | for i, coordConfig := range nodes[COORDINATOR_TYPE] {
187 | coordConfig.DBPath = filepath.Join(COORDINATOR_BADGER, strconv.Itoa(i))
188 | failfast(os.MkdirAll(coordConfig.DBPath, os.FileMode(0o777)))
189 | walConfig := gowal.Config{
190 | Dir: "./tmp/coord/msgs" + strconv.Itoa(i),
191 | Prefix: "msgs",
192 | SegmentThreshold: 100,
193 | MaxSegments: 100,
194 | IsInSyncDiskMode: false,
195 | }
196 |
197 | c, err := gowal.NewWAL(walConfig)
198 | failfast(err)
199 |
200 | stateStore, recovery, err := store.New(c, coordConfig.DBPath)
201 | failfast(err)
202 |
203 | coord, err := coordinator.New(coordConfig, c, stateStore)
204 | failfast(err)
205 | coord.SetHeight(recovery.Height)
206 |
207 | coordServer, err := server.New(coordConfig, nil, coord, stateStore)
208 | failfast(err)
209 |
210 | go coordServer.Run(server.WhiteListChecker)
211 | time.Sleep(100 * time.Millisecond)
212 | stopfuncs = append(stopfuncs, coordServer.Stop)
213 | }
214 |
215 | return func() error {
216 | for _, f := range stopfuncs {
217 | f()
218 | }
219 | failfast(os.RemoveAll("./tmp"))
220 | return os.RemoveAll(BADGER_DIR)
221 | }
222 | }
223 |
// failfast panics with err when it is non-nil and does nothing otherwise.
// Test setup uses it to abort immediately on any unexpected error.
func failfast(err error) {
	if err == nil {
		return
	}
	panic(err)
}
229 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
--------------------------------------------------------------------------------
/core/coordinator/coordinator.go:
--------------------------------------------------------------------------------
1 | // Package coordinator implements the coordinator role in distributed atomic commit protocols.
2 | // The coordinator is responsible for managing the atomic commit process by
3 | // sending prepare requests to cohorts and collecting their responses.
4 | package coordinator
5 |
import (
	"context"
	"fmt"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/pkg/errors"
	log "github.com/sirupsen/logrus"
	"github.com/vadiminshakov/committer/config"
	"github.com/vadiminshakov/committer/core/dto"
	"github.com/vadiminshakov/committer/core/walrecord"
	"github.com/vadiminshakov/committer/io/gateway/grpc/client"
	pb "github.com/vadiminshakov/committer/io/gateway/grpc/proto"
	"github.com/vadiminshakov/committer/io/gateway/grpc/server"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)
24 |
// wal is the coordinator's write-ahead log. The coordinator writes Prepared,
// Commit and Abort records into it at slot indexes derived from the
// transaction height (see walrecord.PreparedSlot / CommitSlot / AbortSlot).
//
//go:generate mockgen -destination=../../mocks/mock_wal.go -package=mocks . wal
type wal interface {
	// Write stores value under key at the given slot index.
	Write(index uint64, key string, value []byte) error
	// WriteTombstone writes a tombstone record at the given slot index.
	WriteTombstone(index uint64) error
	// Get returns the key and value stored at the given slot index.
	Get(index uint64) (string, []byte, error)
	// Close releases the log.
	Close() error
}
32 |
// StateStore defines the interface for persistent state storage.
// The coordinator applies each committed key/value pair to it via Put.
//
//go:generate mockgen -destination=../../mocks/mock_state_store.go -package=mocks . StateStore
type StateStore interface {
	// Put stores value under key.
	Put(key string, value []byte) error
	// Close releases the store.
	Close() error
}
40 |
// coordinator drives distributed commit rounds (2PC or 3PC) against a fixed
// set of cohorts and records its own decisions in a WAL.
type coordinator struct {
	wal        wal                                     // coordinator-side log of Prepared/Commit/Abort records
	store      StateStore                              // local store each committed key/value is applied to
	cohorts    map[string]*client.InternalCommitClient // keyed by the cohort entry from conf.Cohorts
	config     *config.Config                          // configuration the coordinator was created with
	commitType pb.CommitType                           // commit type sent to cohorts with every proposal
	threePhase bool                                    // true when the precommit phase (3PC) is enabled
	height     uint64                                  // current transaction height; accessed via sync/atomic
	mu         sync.Mutex // serialize broadcast to keep height/WAL alignment
}
51 |
52 | // New creates a new coordinator instance with the specified configuration.
53 | func New(conf *config.Config, wal wal, store StateStore) (*coordinator, error) {
54 | cohorts := make(map[string]*client.InternalCommitClient, len(conf.Cohorts))
55 | for _, f := range conf.Cohorts {
56 | cl, err := client.NewInternalClient(f)
57 | if err != nil {
58 | return nil, err
59 | }
60 |
61 | cohorts[f] = cl
62 | }
63 |
64 | threePhase := conf.CommitType == server.THREE_PHASE
65 | commitType := pb.CommitType_TWO_PHASE_COMMIT
66 | if threePhase {
67 | commitType = pb.CommitType_THREE_PHASE_COMMIT
68 | }
69 |
70 | return &coordinator{
71 | wal: wal,
72 | store: store,
73 | cohorts: cohorts,
74 | config: conf,
75 | commitType: commitType,
76 | threePhase: threePhase,
77 | }, nil
78 | }
79 |
80 | // Broadcast executes the complete distributed consensus algorithm (2PC or 3PC) for a transaction.
81 | // It runs through all phases: propose, precommit (if 3PC), commit, and persistence.
82 | func (c *coordinator) Broadcast(ctx context.Context, req dto.BroadcastRequest) (*dto.BroadcastResponse, error) {
83 | c.mu.Lock()
84 | defer c.mu.Unlock()
85 |
86 | log.Infof("Proposing key %s", req.Key)
87 | if err := c.propose(ctx, req); err != nil {
88 | return nackResponse(err, "failed to send propose")
89 | }
90 |
91 | if c.threePhase {
92 | log.Infof("Precommitting key %s", req.Key)
93 | if err := c.preCommit(ctx); err != nil {
94 | return nackResponse(err, "failed to send precommit")
95 | }
96 | }
97 |
98 | log.Infof("Committing key %s", req.Key)
99 | if err := c.commit(ctx); err != nil {
100 | s, ok := status.FromError(err)
101 | if !ok {
102 | return &dto.BroadcastResponse{Type: dto.ResponseTypeNack}, fmt.Errorf("failed to extract grpc status code from err: %s", err)
103 | }
104 | if s.Code() == codes.AlreadyExists {
105 | return &dto.BroadcastResponse{Type: dto.ResponseTypeNack}, nil
106 | }
107 | return nackResponse(err, "failed to send commit")
108 | }
109 |
110 | log.Infof("coordinator committed key %s", req.Key)
111 |
112 | newHeight := atomic.AddUint64(&c.height, 1)
113 | return &dto.BroadcastResponse{Type: dto.ResponseTypeAck, Height: newHeight}, nil
114 | }
115 |
116 | func (c *coordinator) propose(ctx context.Context, req dto.BroadcastRequest) error {
117 | for name, cohort := range c.cohorts {
118 | if err := c.sendProposal(ctx, cohort, name, req, c.commitType); err != nil {
119 | return err
120 | }
121 | }
122 |
123 | currentHeight := atomic.LoadUint64(&c.height)
124 |
125 | // write Prepared to WAL to persist payload
126 | ptx := walrecord.WalTx{Key: req.Key, Value: req.Value}
127 | pbBytes, err := walrecord.Encode(ptx)
128 | if err != nil {
129 | return err
130 | }
131 | return c.wal.Write(walrecord.PreparedSlot(currentHeight), walrecord.KeyPrepared, pbBytes)
132 | }
133 |
134 | func (c *coordinator) sendProposal(ctx context.Context, cohort *client.InternalCommitClient, name string, req dto.BroadcastRequest, commitType pb.CommitType) error {
135 | var (
136 | resp *pb.Response
137 | err error
138 | )
139 |
140 | for {
141 | currentHeight := atomic.LoadUint64(&c.height)
142 | resp, err = cohort.Propose(ctx, &pb.ProposeRequest{
143 | Key: req.Key,
144 | Value: req.Value,
145 | CommitType: commitType,
146 | Index: currentHeight,
147 | })
148 |
149 | if err == nil && resp != nil && resp.Type == pb.Type_ACK {
150 | break // success
151 | }
152 |
153 | // if cohort has bigger height, update coordinator's height and retry
154 | if resp != nil && resp.Index > currentHeight {
155 | c.syncHeight(resp.Index)
156 | continue
157 | }
158 | if err != nil {
159 | // send abort to all cohorts on error
160 | c.abort(ctx, fmt.Sprintf("node %s rejected proposed msg: %v", name, err))
161 | return fmt.Errorf("node %s rejected proposed msg: %w", name, err)
162 | }
163 |
164 | // send abort to all cohorts on NACK
165 | c.abort(ctx, fmt.Sprintf("cohort %s sent NACK for propose", name))
166 | return fmt.Errorf("cohort %s not acknowledged msg %v", name, req)
167 | }
168 | return nil
169 | }
170 |
171 | func (c *coordinator) preCommit(ctx context.Context) error {
172 | currentHeight := atomic.LoadUint64(&c.height)
173 | for name, cohort := range c.cohorts {
174 | resp, err := cohort.Precommit(ctx, &pb.PrecommitRequest{Index: currentHeight})
175 | if err != nil {
176 | c.abort(ctx, fmt.Sprintf("cohort %s precommit error: %v", name, err))
177 | return status.Error(codes.FailedPrecondition, "cohort not acknowledged msg")
178 | }
179 | if !isAck(resp) {
180 | c.abort(ctx, fmt.Sprintf("cohort %s sent NACK for precommit", name))
181 | return status.Error(codes.FailedPrecondition, "cohort not acknowledged msg")
182 | }
183 | }
184 |
185 | return nil
186 | }
187 |
188 | func (c *coordinator) commit(ctx context.Context) error {
189 | currentHeight := atomic.LoadUint64(&c.height)
190 | var errs []string
191 | for name, cohort := range c.cohorts {
192 | resp, err := cohort.Commit(ctx, &pb.CommitRequest{Index: currentHeight})
193 | if err != nil {
194 | errs = append(errs, fmt.Sprintf("%s: %v", name, err))
195 | continue
196 | }
197 | if !isAck(resp) {
198 | errs = append(errs, fmt.Sprintf("%s: NACK", name))
199 | }
200 | }
201 |
202 | if len(errs) > 0 {
203 | return status.Error(codes.FailedPrecondition, "cohort not acknowledged msg: "+strings.Join(errs, "; "))
204 | }
205 |
206 | // Persist Decision (Commit) and Apply
207 | // 1. read Payload from Prepared
208 | k, v, err := c.wal.Get(walrecord.PreparedSlot(currentHeight))
209 | if err != nil {
210 | return status.Errorf(codes.Internal, "failed to read prepared tx: %v", err)
211 | }
212 | if k != walrecord.KeyPrepared {
213 | return status.Errorf(codes.Internal, "expected prepared tx, got %s", k)
214 | }
215 |
216 | // 2. write commit
217 | if err := c.wal.Write(walrecord.CommitSlot(currentHeight), walrecord.KeyCommit, v); err != nil {
218 | return status.Errorf(codes.Internal, "failed to write commit val: %v", err)
219 | }
220 |
221 | // 3. apply
222 | walTx, err := walrecord.Decode(v)
223 | if err != nil {
224 | return status.Errorf(codes.Internal, "failed to decode tx: %v", err)
225 | }
226 | if err := c.store.Put(walTx.Key, walTx.Value); err != nil {
227 | log.Errorf("failed to apply to store: %v", err)
228 | return err // committed but not applied? critical error
229 | }
230 |
231 | return nil
232 | }
233 |
234 | func isAck(resp *pb.Response) bool {
235 | return resp != nil && resp.Type == pb.Type_ACK
236 | }
237 |
238 | // syncHeight atomically updates coordinator height to match cohort height if needed
239 | func (c *coordinator) syncHeight(cohortHeight uint64) {
240 | for {
241 | currentHeight := atomic.LoadUint64(&c.height)
242 | if cohortHeight <= currentHeight {
243 | return // height is already up to date
244 | }
245 |
246 | if atomic.CompareAndSwapUint64(&c.height, currentHeight, cohortHeight) {
247 | log.Warnf("Updating coordinator height: %d -> %d", currentHeight, cohortHeight)
248 | return
249 | }
250 | }
251 | }
252 |
253 | // Height returns the current transaction height.
254 | func (c *coordinator) Height() uint64 {
255 | return atomic.LoadUint64(&c.height)
256 | }
257 |
258 | // SetHeight initializes coordinator height during recovery.
259 | func (c *coordinator) SetHeight(height uint64) {
260 | atomic.StoreUint64(&c.height, height)
261 | }
262 |
263 | // abort sends abort requests to all cohorts in a fire-and-forget manner
264 | func (c *coordinator) abort(ctx context.Context, reason string) {
265 | currentHeight := atomic.LoadUint64(&c.height)
266 | log.Warnf("Aborting transaction at height %d: %s", currentHeight, reason)
267 |
268 | if err := c.wal.Write(walrecord.AbortSlot(currentHeight), walrecord.KeyAbort, nil); err != nil {
269 | log.Errorf("Failed to write abort to WAL: %v", err)
270 | }
271 |
272 | for name, cohort := range c.cohorts {
273 | go func(name string, cohort *client.InternalCommitClient) {
274 | if _, err := cohort.Abort(ctx, &dto.AbortRequest{Height: currentHeight, Reason: reason}); err != nil {
275 | log.Errorf("Failed to send abort to cohort %s: %v", name, err)
276 | }
277 | }(name, cohort)
278 | }
279 | }
280 |
281 | func nackResponse(err error, msg string) (*dto.BroadcastResponse, error) {
282 | return &dto.BroadcastResponse{Type: dto.ResponseTypeNack}, errors.Wrap(err, msg)
283 | }
284 |
285 | func (c *coordinator) persistMessage() error {
286 | currentHeight := atomic.LoadUint64(&c.height)
287 |
288 | // read prepared
289 | k, v, err := c.wal.Get(walrecord.PreparedSlot(currentHeight))
290 | if err != nil {
291 | return status.Error(codes.Internal, fmt.Sprintf("failed to read msg at height %d from wal: %v", currentHeight, err))
292 | }
293 | if v == nil {
294 | return status.Error(codes.Internal, "can't find msg in wal")
295 | }
296 | if k != walrecord.KeyPrepared {
297 | return status.Error(codes.Internal, fmt.Sprintf("expected prepared key, got %s", k))
298 | }
299 |
300 | walTx, err := walrecord.Decode(v)
301 | if err != nil {
302 | return status.Error(codes.Internal, "failed to decode tx")
303 | }
304 |
305 | // save
306 | return c.store.Put(walTx.Key, walTx.Value)
307 | }
308 |
--------------------------------------------------------------------------------
/io/gateway/grpc/proto/schema_grpc.pb.go:
--------------------------------------------------------------------------------
1 | // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
2 |
3 | package proto
4 |
5 | import (
6 | context "context"
7 | grpc "google.golang.org/grpc"
8 | codes "google.golang.org/grpc/codes"
9 | status "google.golang.org/grpc/status"
10 | emptypb "google.golang.org/protobuf/types/known/emptypb"
11 | )
12 |
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7

// InternalCommitAPIClient is the client API for InternalCommitAPI service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type InternalCommitAPIClient interface {
	Propose(ctx context.Context, in *ProposeRequest, opts ...grpc.CallOption) (*Response, error)
	Precommit(ctx context.Context, in *PrecommitRequest, opts ...grpc.CallOption) (*Response, error)
	Commit(ctx context.Context, in *CommitRequest, opts ...grpc.CallOption) (*Response, error)
	Abort(ctx context.Context, in *AbortRequest, opts ...grpc.CallOption) (*Response, error)
}

// internalCommitAPIClient implements InternalCommitAPIClient by invoking each
// unary RPC on cc.
type internalCommitAPIClient struct {
	cc grpc.ClientConnInterface
}

// NewInternalCommitAPIClient returns an InternalCommitAPIClient that issues its
// RPCs over cc.
func NewInternalCommitAPIClient(cc grpc.ClientConnInterface) InternalCommitAPIClient {
	return &internalCommitAPIClient{cc}
}

// Propose invokes /schema.InternalCommitAPI/Propose. On error the returned
// *Response is nil.
func (c *internalCommitAPIClient) Propose(ctx context.Context, in *ProposeRequest, opts ...grpc.CallOption) (*Response, error) {
	out := new(Response)
	err := c.cc.Invoke(ctx, "/schema.InternalCommitAPI/Propose", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// Precommit invokes /schema.InternalCommitAPI/Precommit. On error the returned
// *Response is nil.
func (c *internalCommitAPIClient) Precommit(ctx context.Context, in *PrecommitRequest, opts ...grpc.CallOption) (*Response, error) {
	out := new(Response)
	err := c.cc.Invoke(ctx, "/schema.InternalCommitAPI/Precommit", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// Commit invokes /schema.InternalCommitAPI/Commit. On error the returned
// *Response is nil.
func (c *internalCommitAPIClient) Commit(ctx context.Context, in *CommitRequest, opts ...grpc.CallOption) (*Response, error) {
	out := new(Response)
	err := c.cc.Invoke(ctx, "/schema.InternalCommitAPI/Commit", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// Abort invokes /schema.InternalCommitAPI/Abort. On error the returned
// *Response is nil.
func (c *internalCommitAPIClient) Abort(ctx context.Context, in *AbortRequest, opts ...grpc.CallOption) (*Response, error) {
	out := new(Response)
	err := c.cc.Invoke(ctx, "/schema.InternalCommitAPI/Abort", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}
71 |
// InternalCommitAPIServer is the server API for InternalCommitAPI service.
// All implementations must embed UnimplementedInternalCommitAPIServer
// for forward compatibility
type InternalCommitAPIServer interface {
	Propose(context.Context, *ProposeRequest) (*Response, error)
	Precommit(context.Context, *PrecommitRequest) (*Response, error)
	Commit(context.Context, *CommitRequest) (*Response, error)
	Abort(context.Context, *AbortRequest) (*Response, error)
	mustEmbedUnimplementedInternalCommitAPIServer()
}

// UnimplementedInternalCommitAPIServer must be embedded to have forward compatible implementations.
// Each of its methods returns a codes.Unimplemented status error.
type UnimplementedInternalCommitAPIServer struct {
}

func (UnimplementedInternalCommitAPIServer) Propose(context.Context, *ProposeRequest) (*Response, error) {
	return nil, status.Errorf(codes.Unimplemented, "method Propose not implemented")
}
func (UnimplementedInternalCommitAPIServer) Precommit(context.Context, *PrecommitRequest) (*Response, error) {
	return nil, status.Errorf(codes.Unimplemented, "method Precommit not implemented")
}
func (UnimplementedInternalCommitAPIServer) Commit(context.Context, *CommitRequest) (*Response, error) {
	return nil, status.Errorf(codes.Unimplemented, "method Commit not implemented")
}
func (UnimplementedInternalCommitAPIServer) Abort(context.Context, *AbortRequest) (*Response, error) {
	return nil, status.Errorf(codes.Unimplemented, "method Abort not implemented")
}
func (UnimplementedInternalCommitAPIServer) mustEmbedUnimplementedInternalCommitAPIServer() {}

// UnsafeInternalCommitAPIServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to InternalCommitAPIServer will
// result in compilation errors.
type UnsafeInternalCommitAPIServer interface {
	mustEmbedUnimplementedInternalCommitAPIServer()
}

// RegisterInternalCommitAPIServer registers srv as the InternalCommitAPI
// implementation on s.
func RegisterInternalCommitAPIServer(s grpc.ServiceRegistrar, srv InternalCommitAPIServer) {
	s.RegisterService(&InternalCommitAPI_ServiceDesc, srv)
}
111 |
// _InternalCommitAPI_Propose_Handler decodes a ProposeRequest and dispatches it
// to srv.Propose, routing through interceptor when one is installed.
func _InternalCommitAPI_Propose_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(ProposeRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(InternalCommitAPIServer).Propose(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: "/schema.InternalCommitAPI/Propose",
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(InternalCommitAPIServer).Propose(ctx, req.(*ProposeRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _InternalCommitAPI_Precommit_Handler decodes a PrecommitRequest and dispatches
// it to srv.Precommit, routing through interceptor when one is installed.
func _InternalCommitAPI_Precommit_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(PrecommitRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(InternalCommitAPIServer).Precommit(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: "/schema.InternalCommitAPI/Precommit",
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(InternalCommitAPIServer).Precommit(ctx, req.(*PrecommitRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _InternalCommitAPI_Commit_Handler decodes a CommitRequest and dispatches it
// to srv.Commit, routing through interceptor when one is installed.
func _InternalCommitAPI_Commit_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(CommitRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(InternalCommitAPIServer).Commit(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: "/schema.InternalCommitAPI/Commit",
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(InternalCommitAPIServer).Commit(ctx, req.(*CommitRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _InternalCommitAPI_Abort_Handler decodes an AbortRequest and dispatches it to
// srv.Abort, routing through interceptor when one is installed.
func _InternalCommitAPI_Abort_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(AbortRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(InternalCommitAPIServer).Abort(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: "/schema.InternalCommitAPI/Abort",
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(InternalCommitAPIServer).Abort(ctx, req.(*AbortRequest))
	}
	return interceptor(ctx, in, info, handler)
}
183 |
// InternalCommitAPI_ServiceDesc is the grpc.ServiceDesc for InternalCommitAPI service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
// It maps each unary method name to its generated handler; the service has no
// streaming methods.
var InternalCommitAPI_ServiceDesc = grpc.ServiceDesc{
	ServiceName: "schema.InternalCommitAPI",
	HandlerType: (*InternalCommitAPIServer)(nil),
	Methods: []grpc.MethodDesc{
		{
			MethodName: "Propose",
			Handler:    _InternalCommitAPI_Propose_Handler,
		},
		{
			MethodName: "Precommit",
			Handler:    _InternalCommitAPI_Precommit_Handler,
		},
		{
			MethodName: "Commit",
			Handler:    _InternalCommitAPI_Commit_Handler,
		},
		{
			MethodName: "Abort",
			Handler:    _InternalCommitAPI_Abort_Handler,
		},
	},
	Streams:  []grpc.StreamDesc{},
	Metadata: "schema.proto",
}
211 |
// ClientAPIClient is the client API for ClientAPI service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type ClientAPIClient interface {
	Put(ctx context.Context, in *Entry, opts ...grpc.CallOption) (*Response, error)
	Get(ctx context.Context, in *Msg, opts ...grpc.CallOption) (*Value, error)
	NodeInfo(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*Info, error)
}

// clientAPIClient implements ClientAPIClient by invoking each unary RPC on cc.
type clientAPIClient struct {
	cc grpc.ClientConnInterface
}

// NewClientAPIClient returns a ClientAPIClient that issues its RPCs over cc.
func NewClientAPIClient(cc grpc.ClientConnInterface) ClientAPIClient {
	return &clientAPIClient{cc}
}

// Put invokes /schema.ClientAPI/Put. On error the returned *Response is nil.
func (c *clientAPIClient) Put(ctx context.Context, in *Entry, opts ...grpc.CallOption) (*Response, error) {
	out := new(Response)
	err := c.cc.Invoke(ctx, "/schema.ClientAPI/Put", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// Get invokes /schema.ClientAPI/Get. On error the returned *Value is nil.
func (c *clientAPIClient) Get(ctx context.Context, in *Msg, opts ...grpc.CallOption) (*Value, error) {
	out := new(Value)
	err := c.cc.Invoke(ctx, "/schema.ClientAPI/Get", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// NodeInfo invokes /schema.ClientAPI/NodeInfo. On error the returned *Info is nil.
func (c *clientAPIClient) NodeInfo(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*Info, error) {
	out := new(Info)
	err := c.cc.Invoke(ctx, "/schema.ClientAPI/NodeInfo", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}
255 |
// ClientAPIServer is the server API for ClientAPI service.
// All implementations must embed UnimplementedClientAPIServer
// for forward compatibility
type ClientAPIServer interface {
	Put(context.Context, *Entry) (*Response, error)
	Get(context.Context, *Msg) (*Value, error)
	NodeInfo(context.Context, *emptypb.Empty) (*Info, error)
	mustEmbedUnimplementedClientAPIServer()
}

// UnimplementedClientAPIServer must be embedded to have forward compatible implementations.
// Each of its methods returns a codes.Unimplemented status error.
type UnimplementedClientAPIServer struct {
}

func (UnimplementedClientAPIServer) Put(context.Context, *Entry) (*Response, error) {
	return nil, status.Errorf(codes.Unimplemented, "method Put not implemented")
}
func (UnimplementedClientAPIServer) Get(context.Context, *Msg) (*Value, error) {
	return nil, status.Errorf(codes.Unimplemented, "method Get not implemented")
}
func (UnimplementedClientAPIServer) NodeInfo(context.Context, *emptypb.Empty) (*Info, error) {
	return nil, status.Errorf(codes.Unimplemented, "method NodeInfo not implemented")
}
func (UnimplementedClientAPIServer) mustEmbedUnimplementedClientAPIServer() {}

// UnsafeClientAPIServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to ClientAPIServer will
// result in compilation errors.
type UnsafeClientAPIServer interface {
	mustEmbedUnimplementedClientAPIServer()
}

// RegisterClientAPIServer registers srv as the ClientAPI implementation on s.
func RegisterClientAPIServer(s grpc.ServiceRegistrar, srv ClientAPIServer) {
	s.RegisterService(&ClientAPI_ServiceDesc, srv)
}
291 |
// _ClientAPI_Put_Handler decodes an Entry and dispatches it to srv.Put,
// routing through interceptor when one is installed.
func _ClientAPI_Put_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(Entry)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(ClientAPIServer).Put(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: "/schema.ClientAPI/Put",
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(ClientAPIServer).Put(ctx, req.(*Entry))
	}
	return interceptor(ctx, in, info, handler)
}

// _ClientAPI_Get_Handler decodes a Msg and dispatches it to srv.Get,
// routing through interceptor when one is installed.
func _ClientAPI_Get_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(Msg)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(ClientAPIServer).Get(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: "/schema.ClientAPI/Get",
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(ClientAPIServer).Get(ctx, req.(*Msg))
	}
	return interceptor(ctx, in, info, handler)
}

// _ClientAPI_NodeInfo_Handler decodes an empty request and dispatches it to
// srv.NodeInfo, routing through interceptor when one is installed.
func _ClientAPI_NodeInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(emptypb.Empty)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(ClientAPIServer).NodeInfo(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: "/schema.ClientAPI/NodeInfo",
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(ClientAPIServer).NodeInfo(ctx, req.(*emptypb.Empty))
	}
	return interceptor(ctx, in, info, handler)
}
345 |
// ClientAPI_ServiceDesc is the grpc.ServiceDesc for ClientAPI service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy).
//
// It maps the three unary RPCs of "schema.ClientAPI" to their handlers; the
// service declares no streaming methods.
var ClientAPI_ServiceDesc = grpc.ServiceDesc{
	ServiceName: "schema.ClientAPI",
	// registration checks srv against this interface type
	HandlerType: (*ClientAPIServer)(nil),
	Methods: []grpc.MethodDesc{
		{
			MethodName: "Put",
			Handler:    _ClientAPI_Put_Handler,
		},
		{
			MethodName: "Get",
			Handler:    _ClientAPI_Get_Handler,
		},
		{
			MethodName: "NodeInfo",
			Handler:    _ClientAPI_NodeInfo_Handler,
		},
	},
	Streams:  []grpc.StreamDesc{},
	Metadata: "schema.proto",
}
369 |
--------------------------------------------------------------------------------
/core/cohort/commitalgo/commitalgo.go:
--------------------------------------------------------------------------------
1 | // Package commitalgo implements the core commit algorithms for 2PC and 3PC protocols.
2 | //
3 | // This package provides the finite state machine logic and transaction handling
4 | // for cohort nodes participating in distributed consensus.
5 | package commitalgo
6 |
7 | import (
8 | "context"
9 | "sync"
10 | "sync/atomic"
11 | "time"
12 |
13 | log "github.com/sirupsen/logrus"
14 | "github.com/vadiminshakov/committer/core/cohort/commitalgo/hooks"
15 | "github.com/vadiminshakov/committer/core/dto"
16 | "github.com/vadiminshakov/committer/core/walrecord"
17 | "google.golang.org/grpc/codes"
18 | "google.golang.org/grpc/status"
19 | )
20 |
// wal defines the interface for write-ahead log operations.
//
// Records are addressed by a uint64 slot index and carry a string key plus a
// byte payload. In this file the indices come from the walrecord slot helpers
// (PreparedSlot, PrecommitSlot, CommitSlot, AbortSlot) and the keys from the
// walrecord.Key* constants.
type wal interface {
	// Write stores key/value at index.
	Write(index uint64, key string, value []byte) error
	// WriteTombstone writes a tombstone at index (not called in this file).
	WriteTombstone(index uint64) error
	// Get returns the key and value stored at index.
	Get(index uint64) (string, []byte, error)
	// Close closes the log.
	Close() error
}
28 |
// StateStore defines the interface for state storage.
//
// In this package it receives the decoded key/value of every committed
// transaction via Put. Close is part of the contract but is not called here.
//
//go:generate mockgen -destination=../../../mocks/mock_commitalgo_state_store.go -package=mocks -mock_names=StateStore=MockCommitalgoStateStore . StateStore
type StateStore interface {
	// Put stores value under key.
	Put(key string, value []byte) error
	// Close closes the store.
	Close() error
}
36 |
// CommitterImpl implements the commit algorithm with state machine and hooks.
//
// The protocol handlers (Propose, Precommit, Commit, Abort) and the timeout
// handlers serialize on mu. height is additionally read/written through
// sync/atomic so Height can be called without taking mu.
//
// NOTE(review): height uses 64-bit atomics; if fields are reordered keep it
// 8-byte aligned for 32-bit platforms (see sync/atomic "Bugs").
type CommitterImpl struct {
	store        StateStore      // receives committed key/value pairs
	wal          wal             // prepared/precommit/commit/abort records
	hookRegistry *hooks.Registry // hooks consulted on propose and commit
	state        *stateMachine   // current protocol stage
	height       uint64          // next height to commit; accessed atomically
	timeout      uint64          // propose/precommit timeout, in milliseconds
	mu           sync.Mutex      // serializes protocol and timeout handlers
}
47 |
48 | // NewCommitter creates a new committer instance with the specified configuration.
49 | func NewCommitter(store StateStore, commitType string, wal wal, timeout uint64, customHooks ...hooks.Hook) *CommitterImpl {
50 | registry := hooks.NewRegistry()
51 |
52 | for _, hook := range customHooks {
53 | registry.Register(hook)
54 | }
55 |
56 | if len(customHooks) == 0 {
57 | registry.Register(hooks.NewDefaultHook())
58 | }
59 |
60 | return &CommitterImpl{
61 | hookRegistry: registry,
62 | store: store,
63 | wal: wal,
64 | timeout: timeout,
65 | state: newStateMachine(mode(commitType)),
66 | }
67 | }
68 |
69 | // RegisterHook adds a new hook to the committer.
70 | func (c *CommitterImpl) RegisterHook(hook hooks.Hook) {
71 | c.hookRegistry.Register(hook)
72 | }
73 |
74 | // Height returns the current transaction height.
75 | func (c *CommitterImpl) Height() uint64 {
76 | return atomic.LoadUint64(&c.height)
77 | }
78 |
79 | // SetHeight initializes committer height from recovered WAL state.
80 | func (c *CommitterImpl) SetHeight(height uint64) {
81 | atomic.StoreUint64(&c.height, height)
82 | }
83 |
84 | func (c *CommitterImpl) getCurrentState() string {
85 | return c.state.getCurrentState()
86 | }
87 |
88 | // Propose handles the propose phase of the commit protocol.
89 | func (c *CommitterImpl) Propose(ctx context.Context, req *dto.ProposeRequest) (*dto.CohortResponse, error) {
90 | c.mu.Lock()
91 | defer c.mu.Unlock()
92 |
93 | currentHeight := atomic.LoadUint64(&c.height)
94 | if currentHeight > req.Height {
95 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeNack, Height: currentHeight}, nil
96 | }
97 |
98 | if !c.hookRegistry.ExecutePropose(req) {
99 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeNack, Height: req.Height}, nil
100 | }
101 |
102 | if err := c.state.Transition(proposeStage); err != nil {
103 | return nil, err
104 | }
105 |
106 | // prepare payload
107 | ptx := walrecord.WalTx{Key: req.Key, Value: req.Value}
108 | pb, err := walrecord.Encode(ptx)
109 | if err != nil {
110 | return nil, status.Errorf(codes.Internal, "encode error: %v", err)
111 | }
112 |
113 | // save
114 | if err := c.wal.Write(walrecord.PreparedSlot(req.Height), walrecord.KeyPrepared, pb); err != nil {
115 | if terr := c.state.Transition(proposeStage); terr != nil {
116 | log.Errorf("failed to reset state to propose after WAL error: %v", terr)
117 | }
118 | return nil, status.Errorf(codes.Internal, "failed to write wal on index %d: %v", req.Height, err)
119 | }
120 |
121 | if c.state.mode == twophase {
122 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck, Height: req.Height}, nil
123 | }
124 |
125 | go c.handleProposeTimeout(req.Height)
126 |
127 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck, Height: req.Height}, nil
128 | }
129 |
130 | func (c *CommitterImpl) handleProposeTimeout(height uint64) {
131 | timer := time.NewTimer(time.Duration(c.timeout) * time.Millisecond)
132 | defer timer.Stop()
133 | <-timer.C
134 |
135 | c.mu.Lock()
136 | defer c.mu.Unlock()
137 |
138 | currentState := c.state.getCurrentState()
139 | currentHeight := atomic.LoadUint64(&c.height)
140 |
141 | log.Debugf("propose timeout handler: state=%s, height=%d, index=%d", currentState, currentHeight, height)
142 |
143 | if currentState != proposeStage || currentHeight != height {
144 | log.Debugf("skipping propose timeout handling for height %d: state=%s, currentHeight=%d", height, currentState, currentHeight)
145 | return
146 | }
147 |
148 | if err := c.wal.Write(walrecord.AbortSlot(height), walrecord.KeyAbort, nil); err != nil {
149 | log.Errorf("failed to write skip record for height %d: %v", height, err)
150 | } else {
151 | log.Warnf("skip proposed message after timeout for height %d", height)
152 | }
153 | }
154 |
155 | // Precommit handles the precommit phase of the three-phase commit protocol.
156 | func (c *CommitterImpl) Precommit(ctx context.Context, index uint64) (*dto.CohortResponse, error) {
157 | c.mu.Lock()
158 | defer c.mu.Unlock()
159 |
160 | currentHeight := atomic.LoadUint64(&c.height)
161 | if index != currentHeight {
162 | return nil, status.Errorf(codes.FailedPrecondition, "invalid precommit height: expected %d, got %d", currentHeight, index)
163 | }
164 |
165 | if c.state.getCurrentState() != proposeStage {
166 | return nil, status.Errorf(codes.FailedPrecondition, "precommit allowed only from propose state, current: %s", c.state.getCurrentState())
167 | }
168 |
169 | // check if already aborted
170 | if k, _, _ := c.wal.Get(walrecord.AbortSlot(index)); k == walrecord.KeyAbort {
171 | return nil, status.Errorf(codes.Aborted, "transaction %d was aborted", index)
172 | }
173 |
174 | // read prepared to ensure we have data
175 | pSlot := walrecord.PreparedSlot(index)
176 | k, val, err := c.wal.Get(pSlot)
177 | if err != nil {
178 | return nil, status.Errorf(codes.Internal, "failed to read prepared slot %d: %v", pSlot, err)
179 | }
180 | if k != walrecord.KeyPrepared {
181 | return nil, status.Errorf(codes.FailedPrecondition, "not prepared (key=%s)", k)
182 | }
183 |
184 | if err := c.wal.Write(walrecord.PrecommitSlot(index), walrecord.KeyPrecommit, val); err != nil {
185 | return nil, status.Errorf(codes.Internal, "failed to write precommit: %v", err)
186 | }
187 |
188 | c.state.Transition(precommitStage)
189 |
190 | go c.handlePrecommitTimeout(index)
191 |
192 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck, Height: index}, nil
193 | }
194 |
195 | func (c *CommitterImpl) handlePrecommitTimeout(height uint64) {
196 | timer := time.NewTimer(time.Duration(c.timeout) * time.Millisecond)
197 | defer timer.Stop()
198 | <-timer.C
199 |
200 | c.mu.Lock()
201 | defer c.mu.Unlock()
202 |
203 | currentState := c.state.getCurrentState()
204 | currentHeight := atomic.LoadUint64(&c.height)
205 |
206 | log.Debugf("precommit timeout handler: state=%s, height=%d, index=%d", currentState, currentHeight, height)
207 |
208 | if currentState != precommitStage || currentHeight != height {
209 | log.Debugf("skipping autocommit for height %d: state=%s, currentHeight=%d", height, currentState, currentHeight)
210 | return
211 | }
212 |
213 | // check abort
214 | if k, _, _ := c.wal.Get(walrecord.AbortSlot(height)); k == walrecord.KeyAbort {
215 | log.Infof("found abort record for height %d during precommit timeout", height)
216 | c.resetToPropose(height, "abort record found")
217 | return
218 | }
219 |
220 | // check data (prepared/precommit)
221 | key, value, err := c.wal.Get(walrecord.PreparedSlot(height))
222 | if err != nil {
223 | log.Errorf("failed to read WAL for height %d during precommit timeout: %v", height, err)
224 | c.resetToPropose(height, "WAL read error")
225 | return
226 | }
227 | if key != walrecord.KeyPrepared {
228 | log.Warnf("unexpected key at prepared slot: %s", key)
229 | c.resetToPropose(height, "invalid prepared key")
230 | return
231 | }
232 | if value == nil {
233 | log.Errorf("no data found in WAL for height %d during precommit timeout", height)
234 | c.resetToPropose(height, "no WAL data")
235 | return
236 | }
237 |
238 | log.Warnf("performing autocommit after precommit timeout for height %d", height)
239 |
240 | response, err := c.commit(height)
241 | if err != nil {
242 | log.Errorf("autocommit failed for height %d: %v", height, err)
243 | c.resetToPropose(height, "autocommit failed")
244 | return
245 | }
246 |
247 | if response != nil && response.ResponseType == dto.ResponseTypeNack {
248 | log.Warnf("autocommit returned NACK for height %d", height)
249 | c.resetToPropose(height, "autocommit NACK")
250 | return
251 | }
252 |
253 | log.Infof("successfully autocommitted height %d after precommit timeout", height)
254 | }
255 |
256 | func (c *CommitterImpl) resetToPropose(height uint64, reason string) {
257 | log.Debugf("resetting state to propose for height %d: %s", height, reason)
258 |
259 | currentState := c.state.getCurrentState()
260 |
261 | if c.state.GetMode() == threephase && currentState == precommitStage {
262 | if err := c.state.Transition(commitStage); err != nil {
263 | log.Errorf("failed to transition to commit state during reset for height %d: %v", height, err)
264 | return
265 | }
266 | }
267 |
268 | if err := c.state.Transition(proposeStage); err != nil {
269 | log.Errorf("failed to reset to propose state for height %d: %v (current: %s)", height, err, c.state.getCurrentState())
270 | }
271 | }
272 |
273 | // Commit handles the commit phase of the atomic commit protocol.
274 | func (c *CommitterImpl) Commit(ctx context.Context, req *dto.CommitRequest) (*dto.CohortResponse, error) {
275 | c.mu.Lock()
276 | defer c.mu.Unlock()
277 | return c.commit(req.Height)
278 | }
279 |
280 | func (c *CommitterImpl) commit(height uint64) (*dto.CohortResponse, error) {
281 | currentHeight := atomic.LoadUint64(&c.height)
282 |
283 | // idempotent commit: if already applied, return ACK without error
284 | if height < currentHeight {
285 | log.Debugf("commit for height %d already applied (current height: %d)", height, currentHeight)
286 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck, Height: height}, nil
287 | }
288 |
289 | // future height: reject
290 | if height > currentHeight {
291 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeNack, Height: currentHeight}, nil
292 | }
293 |
294 | // height == currentHeight: process the commit
295 | currentState := c.state.getCurrentState()
296 | expectedState := c.getExpectedCommitState()
297 |
298 | if currentState != expectedState {
299 | return nil, status.Errorf(codes.FailedPrecondition,
300 | "invalid state for commit: expected %s for %s mode, but current state is %s",
301 | expectedState, c.state.GetMode(), currentState)
302 | }
303 |
304 | // check if already Aborted
305 | if k, _, _ := c.wal.Get(walrecord.AbortSlot(height)); k == walrecord.KeyAbort {
306 | log.Warnf("rejecting commit for aborted height %d", height)
307 | c.resetToPropose(height, "wal aborted")
308 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeNack}, nil
309 | }
310 |
311 | if err := c.state.Transition(commitStage); err != nil {
312 | return nil, status.Errorf(codes.FailedPrecondition, "invalid state transition to commit: %v", err)
313 | }
314 |
315 | if !c.hookRegistry.ExecuteCommit(&dto.CommitRequest{Height: height}) {
316 | if terr := c.state.Transition(proposeStage); terr != nil {
317 | log.Errorf("failed to reset state after hook failure: %v", terr)
318 | }
319 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeNack}, nil
320 | }
321 |
322 | // retrieve payload (try Precommit, then Prepared)
323 | var payload []byte
324 | if k, v, err := c.wal.Get(walrecord.PrecommitSlot(height)); err == nil && k == walrecord.KeyPrecommit {
325 | payload = v
326 | } else if k, v, err := c.wal.Get(walrecord.PreparedSlot(height)); err == nil && k == walrecord.KeyPrepared {
327 | payload = v
328 | } else {
329 | c.resetToPropose(height, "no prepared/precommit data found")
330 | return nil, status.Errorf(codes.FailedPrecondition, "no prepared/precommit data found")
331 | }
332 |
333 | // decode
334 | walTx, err := walrecord.Decode(payload)
335 | if err != nil {
336 | return nil, status.Errorf(codes.Internal, "decode error: %v", err)
337 | }
338 |
339 | // write commit to WAL
340 | if err := c.wal.Write(walrecord.CommitSlot(height), walrecord.KeyCommit, payload); err != nil {
341 | c.resetToPropose(height, "wal write failed")
342 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeNack}, err
343 | }
344 |
345 | // 2. apply to storage
346 | if err := c.store.Put(walTx.Key, walTx.Value); err != nil {
347 | log.Errorf("CRITICAL: failed to apply committed tx to store: %v", err)
348 | return nil, err
349 | }
350 |
351 | atomic.StoreUint64(&c.height, currentHeight+1)
352 |
353 | if terr := c.state.Transition(proposeStage); terr != nil {
354 | log.Errorf("failed to transition back to propose state after successful commit: %v", terr)
355 | }
356 |
357 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck}, nil
358 | }
359 |
360 | func (c *CommitterImpl) getExpectedCommitState() string {
361 | if c.state.GetMode() == twophase {
362 | return proposeStage
363 | }
364 | return precommitStage
365 | }
366 |
367 | // Abort handles abort requests from coordinator.
368 | func (c *CommitterImpl) Abort(ctx context.Context, req *dto.AbortRequest) (*dto.CohortResponse, error) {
369 | log.Warnf("received abort request for height %d: %s", req.Height, req.Reason)
370 |
371 | c.mu.Lock()
372 | defer c.mu.Unlock()
373 |
374 | currentHeight := atomic.LoadUint64(&c.height)
375 |
376 | if req.Height > currentHeight {
377 | log.Debugf("ignoring abort for future height %d (current: %d)", req.Height, currentHeight)
378 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck}, nil
379 | }
380 |
381 | if req.Height < currentHeight {
382 | log.Debugf("ignoring abort for past height %d (current: %d)", req.Height, currentHeight)
383 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck}, nil
384 | }
385 |
386 | log.Infof("processing abort for current height %d", req.Height)
387 |
388 | if err := c.wal.Write(walrecord.AbortSlot(req.Height), walrecord.KeyAbort, nil); err != nil {
389 | log.Errorf("failed to write abort record for aborted transaction at height %d: %v", req.Height, err)
390 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeNack}, err
391 | }
392 |
393 | log.Infof("successfully wrote tombstone record for height %d", req.Height)
394 |
395 | c.resetToPropose(req.Height, "abort request")
396 |
397 | log.Infof("successfully processed abort for height %d", req.Height)
398 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck}, nil
399 | }
400 |
--------------------------------------------------------------------------------
/core/cohort/commitalgo/committer_test.go:
--------------------------------------------------------------------------------
1 | package commitalgo
2 |
3 | import (
4 | "context"
5 | "path/filepath"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/require"
9 | "github.com/vadiminshakov/committer/core/cohort/commitalgo/hooks"
10 | "github.com/vadiminshakov/committer/core/dto"
11 | "github.com/vadiminshakov/committer/core/walrecord"
12 | "github.com/vadiminshakov/committer/io/store"
13 | "github.com/vadiminshakov/gowal"
14 | )
15 |
16 | // testHook for testing purposes
17 | type testHook struct {
18 | proposeResult bool
19 | commitResult bool
20 | proposeCalled bool
21 | commitCalled bool
22 | }
23 |
24 | func (t *testHook) OnPropose(req *dto.ProposeRequest) bool {
25 | t.proposeCalled = true
26 | return t.proposeResult
27 | }
28 |
29 | func (t *testHook) OnCommit(req *dto.CommitRequest) bool {
30 | t.commitCalled = true
31 | return t.commitResult
32 | }
33 |
34 | func openTestWAL(t *testing.T, walPath string) *gowal.Wal {
35 | walConfig := gowal.Config{
36 | Dir: walPath,
37 | Prefix: "test",
38 | SegmentThreshold: 1024,
39 | MaxSegments: 10,
40 | IsInSyncDiskMode: false,
41 | }
42 |
43 | wal, err := gowal.NewWAL(walConfig)
44 | require.NoError(t, err)
45 | t.Cleanup(func() { wal.Close() })
46 |
47 | return wal
48 | }
49 |
50 | func newStateStore(t *testing.T, wal *gowal.Wal) (*store.Store, *store.RecoveryState) {
51 | dbPath := filepath.Join(t.TempDir(), "badger")
52 | stateStore, recovery, err := store.New(wal, dbPath)
53 | require.NoError(t, err)
54 | t.Cleanup(func() { stateStore.Close() })
55 | return stateStore, recovery
56 | }
57 |
58 | func prepareCommitter(t *testing.T, walPath, commitType string, timeout uint64, hooks ...hooks.Hook) (*CommitterImpl, *store.Store, *gowal.Wal, *store.RecoveryState) {
59 | wal := openTestWAL(t, walPath)
60 | stateStore, recovery := newStateStore(t, wal)
61 |
62 | committer := NewCommitter(stateStore, commitType, wal, timeout, hooks...)
63 | committer.SetHeight(recovery.Height)
64 |
65 | return committer, stateStore, wal, recovery
66 | }
67 |
68 | func TestNewCommitter_DefaultHook(t *testing.T) {
69 | tempDir := t.TempDir()
70 | committer, _, _, _ := prepareCommitter(t, filepath.Join(tempDir, "wal"), "3pc", 5000)
71 |
72 | require.Equal(t, 1, committer.hookRegistry.Count(), "Expected 1 default hook")
73 | }
74 |
75 | func TestNewCommitter_CustomHooks(t *testing.T) {
76 | tempDir := t.TempDir()
77 | // test: create committer with custom hooks
78 | testHook1 := &testHook{proposeResult: true, commitResult: true}
79 | testHook2 := &testHook{proposeResult: true, commitResult: true}
80 |
81 | committer, _, _, _ := prepareCommitter(t, filepath.Join(tempDir, "wal"), "3pc", 5000, testHook1, testHook2)
82 |
83 | require.Equal(t, 2, committer.hookRegistry.Count(), "Expected 2 custom hooks")
84 | }
85 |
86 | func TestNewCommitter_DynamicHookRegistration(t *testing.T) {
87 | tempDir := t.TempDir()
88 | // test: create committer and add hooks dynamically
89 | committer, _, _, _ := prepareCommitter(t, filepath.Join(tempDir, "wal"), "3pc", 5000)
90 |
91 | require.Equal(t, 1, committer.hookRegistry.Count(), "Expected 1 default hook initially")
92 |
93 | // add hooks dynamically
94 | testHook := &testHook{proposeResult: true, commitResult: true}
95 | committer.RegisterHook(testHook)
96 |
97 | require.Equal(t, 2, committer.hookRegistry.Count(), "Expected 2 hooks after registration")
98 | }
99 |
100 | func TestNewCommitter_BuiltinHooks(t *testing.T) {
101 | tempDir := t.TempDir()
102 |
103 | metricsHook := hooks.NewMetricsHook()
104 | validationHook := hooks.NewValidationHook(100, 1024)
105 | auditHook := hooks.NewAuditHook("test_audit.log")
106 |
107 | committer, _, _, _ := prepareCommitter(t, filepath.Join(tempDir, "wal"), "3pc", 5000,
108 | metricsHook,
109 | validationHook,
110 | auditHook,
111 | )
112 |
113 | require.Equal(t, 3, committer.hookRegistry.Count(), "Expected 3 built-in hooks")
114 |
115 | proposeCount, commitCount, _ := metricsHook.GetStats()
116 | require.Equal(t, uint64(0), proposeCount, "Expected initial propose count to be 0")
117 | require.Equal(t, uint64(0), commitCount, "Expected initial commit count to be 0")
118 | }
// TestCommit_StateValidation_2PC checks that in two-phase mode a commit is
// accepted straight from the propose state and that the committer returns to
// propose once the commit has been applied.
func TestCommit_StateValidation_2PC(t *testing.T) {
	tempDir := t.TempDir()
	walPath := filepath.Join(tempDir, "wal")
	wal := openTestWAL(t, walPath)

	stateStore, recovery := newStateStore(t, wal)

	// create 2PC committer
	committer := NewCommitter(stateStore, "two-phase", wal, 5000)
	committer.SetHeight(recovery.Height)

	require.Equal(t, "propose", committer.getCurrentState())

	// first, propose a transaction
	proposeReq := &dto.ProposeRequest{
		Height: 0,
		Key:    "test-key",
		Value:  []byte("test-value"),
	}

	_, err := committer.Propose(context.Background(), proposeReq)
	require.NoError(t, err)

	// 2PC has no precommit round, so propose leaves the state unchanged
	require.Equal(t, "propose", committer.getCurrentState())

	// commit should work from propose state
	commitReq := &dto.CommitRequest{Height: 0}
	resp, err := committer.Commit(context.Background(), commitReq)
	require.NoError(t, err)
	require.Equal(t, dto.ResponseTypeAck, resp.ResponseType)

	// should return to propose state after commit
	require.Equal(t, "propose", committer.getCurrentState())
}
154 |
// TestCommit_StateValidation_3PC checks that in three-phase mode a commit is
// rejected from propose (leaving the state untouched) and accepted only after
// Precommit, after which the committer returns to propose.
//
// The 5000ms timeout keeps the propose/precommit timeout goroutines started
// by Propose and Precommit from firing during the test.
func TestCommit_StateValidation_3PC(t *testing.T) {
	tempDir := t.TempDir()
	walPath := filepath.Join(tempDir, "wal")

	wal := openTestWAL(t, walPath)
	stateStore, recovery := newStateStore(t, wal)

	// create 3PC committer
	committer := NewCommitter(stateStore, "three-phase", wal, 5000)
	committer.SetHeight(recovery.Height)

	require.Equal(t, "propose", committer.getCurrentState())

	// first, propose a transaction
	proposeReq := &dto.ProposeRequest{
		Height: 0,
		Key:    "test-key",
		Value:  []byte("test-value"),
	}

	_, err := committer.Propose(context.Background(), proposeReq)
	require.NoError(t, err)

	// should still be in propose state
	require.Equal(t, "propose", committer.getCurrentState())

	// commit should fail from propose state in 3PC mode
	commitReq := &dto.CommitRequest{Height: 0}
	_, err = committer.Commit(context.Background(), commitReq)
	require.Error(t, err)
	require.Contains(t, err.Error(), "invalid state for commit: expected precommit for three-phase mode, but current state is propose")

	// state should remain in propose after failed commit
	require.Equal(t, "propose", committer.getCurrentState())

	// now go through proper 3PC flow: propose -> precommit -> commit
	_, err = committer.Precommit(context.Background(), 0)
	require.NoError(t, err)
	require.Equal(t, "precommit", committer.getCurrentState())

	// now commit should work from precommit state
	resp, err := committer.Commit(context.Background(), commitReq)
	require.NoError(t, err)
	require.Equal(t, dto.ResponseTypeAck, resp.ResponseType)

	// should return to propose state after successful commit
	require.Equal(t, "propose", committer.getCurrentState())
}
203 |
// TestCommit_StateRestoration_OnErrors checks that a three-phase commit that
// is refused (by a hook, or because an abort record exists for the height)
// returns a NACK, puts the committer back in propose, and leaves the height
// and the store untouched.
func TestCommit_StateRestoration_OnErrors(t *testing.T) {
	t.Run("Hook failure", func(t *testing.T) {
		tempDir := t.TempDir()
		walPath := filepath.Join(tempDir, "wal")
		wal := openTestWAL(t, walPath)

		stateStore, recovery := newStateStore(t, wal)

		// create 3PC committer with a hook that will fail commit
		failingHook := &testHook{proposeResult: true, commitResult: false}
		committer := NewCommitter(stateStore, "three-phase", wal, 5000, failingHook)
		committer.SetHeight(recovery.Height)

		// go through proper 3PC flow: propose -> precommit
		proposeReq := &dto.ProposeRequest{
			Height: 0,
			Key:    "test-key",
			Value:  []byte("test-value"),
		}

		_, err := committer.Propose(context.Background(), proposeReq)
		require.NoError(t, err)

		_, err = committer.Precommit(context.Background(), 0)
		require.NoError(t, err)
		require.Equal(t, "precommit", committer.getCurrentState())

		// commit should fail due to hook, but state should be restored
		commitReq := &dto.CommitRequest{Height: 0}
		resp, err := committer.Commit(context.Background(), commitReq)
		require.NoError(t, err)
		require.Equal(t, dto.ResponseTypeNack, resp.ResponseType)

		// state should be restored to propose after failed commit
		require.Equal(t, "propose", committer.getCurrentState())
		// height should not be incremented on hook failure
		require.Equal(t, uint64(0), committer.Height())
	})

	t.Run("Skip record (tombstone)", func(t *testing.T) {
		tempDir := t.TempDir()
		walPath := filepath.Join(tempDir, "wal")
		wal := openTestWAL(t, walPath)

		stateStore, recovery := newStateStore(t, wal)

		// create 3PC committer
		committer := NewCommitter(stateStore, "three-phase", wal, 5000)
		committer.SetHeight(recovery.Height)

		// first propose a normal transaction
		proposeReq := &dto.ProposeRequest{
			Height: 0,
			Key:    "test-key",
			Value:  []byte("test-value"),
		}

		_, err := committer.Propose(context.Background(), proposeReq)
		require.NoError(t, err)

		_, err = committer.Precommit(context.Background(), 0)
		require.NoError(t, err)
		require.Equal(t, "precommit", committer.getCurrentState())

		// simulate abort by writing abort record to WAL (this would normally be done by Abort method)
		err = wal.Write(walrecord.AbortSlot(0), walrecord.KeyAbort, nil)
		require.NoError(t, err)

		// commit should detect the abort record and restore state
		commitReq := &dto.CommitRequest{Height: 0}
		resp, err := committer.Commit(context.Background(), commitReq)
		require.NoError(t, err)
		require.Equal(t, dto.ResponseTypeNack, resp.ResponseType)

		// state should be restored to propose after detecting the abort record
		require.Equal(t, "propose", committer.getCurrentState())
		// height should not be incremented for aborted heights
		require.Equal(t, uint64(0), committer.Height())

		// data should not be in database (commit was skipped)
		value, err := stateStore.Get("test-key")
		require.Error(t, err) // should not exist
		require.Nil(t, value)
	})
}
289 |
290 | func TestGetExpectedCommitState(t *testing.T) {
291 | tempDir := t.TempDir()
292 | walPath := filepath.Join(tempDir, "wal")
293 | wal := openTestWAL(t, walPath)
294 |
295 | stateStore, _ := newStateStore(t, wal)
296 |
297 | // test 2PC mode
298 | committer2PC := NewCommitter(stateStore, "two-phase", wal, 5000)
299 | require.Equal(t, "propose", committer2PC.getExpectedCommitState())
300 |
301 | // test 3PC mode
302 | committer3PC := NewCommitter(stateStore, "three-phase", wal, 5000)
303 | require.Equal(t, "precommit", committer3PC.getExpectedCommitState())
304 | }
305 | func TestPrecommitTimeout_StateValidation(t *testing.T) {
306 | tempDir := t.TempDir()
307 | walPath := filepath.Join(tempDir, "wal")
308 | wal := openTestWAL(t, walPath)
309 |
310 | stateStore, recovery := newStateStore(t, wal)
311 |
312 | // сreate 3PC committer with short timeout for testing
313 | committer := NewCommitter(stateStore, "three-phase", wal, 50) // 50ms timeout
314 | committer.SetHeight(recovery.Height)
315 |
316 | // test case 1: should skip autocommit when in commit state
317 | // first go to precommit, then to commit
318 | committer.state.Transition(precommitStage)
319 | committer.state.Transition(commitStage)
320 | committer.handlePrecommitTimeout(0)
321 | require.Equal(t, "commit", committer.getCurrentState()) // state unchanged
322 |
323 | // test case 2: should skip autocommit when in propose state
324 | // transition back to propose (commit -> propose is valid)
325 | committer.state.Transition(proposeStage)
326 | committer.handlePrecommitTimeout(0)
327 | require.Equal(t, "propose", committer.getCurrentState()) // state unchanged
328 |
329 | // test case 3: should skip autocommit when height doesn't match
330 | committer.state.Transition(precommitStage)
331 | committer.handlePrecommitTimeout(999) // wrong height
332 | require.Equal(t, "precommit", committer.getCurrentState()) // state unchanged
333 | }
334 |
// TestPrecommitTimeout_AutocommitSuccess checks that, after a valid
// propose/precommit, the precommit timeout handler commits on its own,
// returning the committer to propose with the height advanced to 1.
//
// Precommit also starts its own timeout goroutine; whichever handler runs
// first commits, and the other then sees height 1 and skips.
func TestPrecommitTimeout_AutocommitSuccess(t *testing.T) {
	tempDir := t.TempDir()
	walPath := filepath.Join(tempDir, "wal")
	wal := openTestWAL(t, walPath)

	stateStore, recovery := newStateStore(t, wal)

	// create 3PC committer
	committer := NewCommitter(stateStore, "three-phase", wal, 50)
	committer.SetHeight(recovery.Height)

	ctx := context.Background()

	// first propose to get data in WAL
	proposeReq := &dto.ProposeRequest{
		Height: 0,
		Key:    "test-key",
		Value:  []byte("test-value"),
	}
	_, err := committer.Propose(ctx, proposeReq)
	require.NoError(t, err)

	// move to precommit state
	_, err = committer.Precommit(ctx, 0)
	require.NoError(t, err)
	require.Equal(t, "precommit", committer.getCurrentState())

	// test successful autocommit
	committer.handlePrecommitTimeout(0)

	// should be back in propose state after successful autocommit
	require.Equal(t, "propose", committer.getCurrentState())
	require.Equal(t, uint64(1), committer.Height()) // height should be incremented
}
369 |
370 | func TestPrecommitTimeout_AutocommitWithSkipRecord(t *testing.T) {
371 | tempDir := t.TempDir()
372 | walPath := filepath.Join(tempDir, "wal")
373 | wal := openTestWAL(t, walPath)
374 |
375 | stateStore, recovery := newStateStore(t, wal)
376 |
377 | // create 3PC committer
378 | committer := NewCommitter(stateStore, "three-phase", wal, 50)
379 | committer.SetHeight(recovery.Height)
380 |
381 | // write abort record directly to WAL
382 | err := wal.Write(walrecord.AbortSlot(0), walrecord.KeyAbort, nil)
383 | require.NoError(t, err)
384 |
385 | // move to precommit state
386 | committer.state.Transition(precommitStage)
387 | require.Equal(t, "precommit", committer.getCurrentState())
388 |
389 | // test autocommit with skip record - should recover to propose
390 | committer.handlePrecommitTimeout(0)
391 |
392 | // should be back in propose state after recovery
393 | require.Equal(t, "propose", committer.getCurrentState())
394 | require.Equal(t, uint64(0), committer.Height()) // Hhight should not be incremented for skip
395 | }
396 |
397 | func TestPrecommitTimeout_AutocommitFailure(t *testing.T) {
398 | tempDir := t.TempDir()
399 | walPath := filepath.Join(tempDir, "wal")
400 | wal := openTestWAL(t, walPath)
401 |
402 | stateStore, recovery := newStateStore(t, wal)
403 |
404 | // create 3PC committer with failing hook
405 | failingHook := &testHook{proposeResult: true, commitResult: false}
406 | committer := NewCommitter(stateStore, "three-phase", wal, 50, failingHook)
407 | committer.SetHeight(recovery.Height)
408 |
409 | ctx := context.Background()
410 |
411 | // first propose to get data in WAL
412 | proposeReq := &dto.ProposeRequest{
413 | Height: 0,
414 | Key: "test-key",
415 | Value: []byte("test-value"),
416 | }
417 | _, err := committer.Propose(ctx, proposeReq)
418 | require.NoError(t, err)
419 |
420 | // move to precommit state
421 | _, err = committer.Precommit(ctx, 0)
422 | require.NoError(t, err)
423 | require.Equal(t, "precommit", committer.getCurrentState())
424 |
425 | // test autocommit failure - should recover to propose
426 | committer.handlePrecommitTimeout(0)
427 |
428 | // should be back in propose state after recovery from failed autocommit
429 | require.Equal(t, "propose", committer.getCurrentState())
430 | require.Equal(t, uint64(0), committer.Height()) // height should not be incremented on failure
431 | }
432 |
433 | func TestPrecommitTimeout_NoDataInWAL(t *testing.T) {
434 | tempDir := t.TempDir()
435 | walPath := filepath.Join(tempDir, "wal")
436 | wal := openTestWAL(t, walPath)
437 |
438 | stateStore, recovery := newStateStore(t, wal)
439 |
440 | // create 3PC committer
441 | committer := NewCommitter(stateStore, "three-phase", wal, 50)
442 | committer.SetHeight(recovery.Height)
443 |
444 | // set up state without data in WAL
445 |
446 | // move to precommit state without proposing first
447 | committer.state.Transition(precommitStage)
448 | require.Equal(t, "precommit", committer.getCurrentState())
449 |
450 | // test autocommit with no data in WAL - should recover to propose
451 | committer.handlePrecommitTimeout(0)
452 |
453 | // should be back in propose state after recovery
454 | require.Equal(t, "propose", committer.getCurrentState())
455 | require.Equal(t, uint64(0), committer.Height()) // height should not be incremented
456 | }
457 |
458 | func TestRecoverToPropose(t *testing.T) {
459 | tempDir := t.TempDir()
460 | walPath := filepath.Join(tempDir, "wal")
461 | wal := openTestWAL(t, walPath)
462 |
463 | stateStore, recovery := newStateStore(t, wal)
464 |
465 | // create 3PC committer
466 | committer := NewCommitter(stateStore, "three-phase", wal, 50)
467 | committer.SetHeight(recovery.Height)
468 |
469 | // test recovery from precommit state
470 | committer.state.Transition(precommitStage)
471 | require.Equal(t, "precommit", committer.getCurrentState())
472 |
473 | committer.resetToPropose(0, "test")
474 | require.Equal(t, "propose", committer.getCurrentState())
475 |
476 | // test recovery from commit state (should transition to propose)
477 | // first to precommit, then to commit
478 | committer.state.Transition(precommitStage)
479 | committer.state.Transition(commitStage)
480 | require.Equal(t, "commit", committer.getCurrentState())
481 |
482 | committer.resetToPropose(0, "test")
483 | // should transition to propose since commit -> propose is valid in FSM
484 | require.Equal(t, "propose", committer.getCurrentState())
485 | }
486 |
487 | func TestAbort_CurrentHeight(t *testing.T) {
488 | tempDir := t.TempDir()
489 | walPath := filepath.Join(tempDir, "wal")
490 | wal := openTestWAL(t, walPath)
491 |
492 | stateStore, recovery := newStateStore(t, wal)
493 |
494 | // create 3PC committer
495 | committer := NewCommitter(stateStore, "three-phase", wal, 5000)
496 | committer.SetHeight(recovery.Height)
497 |
498 | // set up a transaction at current height
499 | ctx := context.Background()
500 | proposeReq := &dto.ProposeRequest{
501 | Height: 0,
502 | Key: "test-key",
503 | Value: []byte("test-value"),
504 | }
505 |
506 | _, err := committer.Propose(ctx, proposeReq)
507 | require.NoError(t, err)
508 |
509 | // move to precommit state
510 | _, err = committer.Precommit(ctx, 0)
511 | require.NoError(t, err)
512 | require.Equal(t, "precommit", committer.getCurrentState())
513 |
514 | // test abort for current height
515 | abortReq := &dto.AbortRequest{
516 | Height: 0,
517 | Reason: "Test abort",
518 | }
519 |
520 | resp, err := committer.Abort(ctx, abortReq)
521 | require.NoError(t, err)
522 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType)
523 |
524 | // should be back in propose state after abort
525 | require.Equal(t, "propose", committer.getCurrentState())
526 |
527 | // should have the original data in WAL (Prepared)
528 | k, val, err := wal.Get(walrecord.PreparedSlot(0))
529 | require.NoError(t, err)
530 | require.Equal(t, walrecord.KeyPrepared, k)
531 |
532 | // verify payload
533 | walTx, err := walrecord.Decode(val)
534 | require.NoError(t, err)
535 | require.Equal(t, "test-key", walTx.Key)
536 |
537 | // verify Abort
538 | k, _, err = wal.Get(walrecord.AbortSlot(0))
539 | require.NoError(t, err)
540 | require.Equal(t, walrecord.KeyAbort, k)
541 |
542 | value, err := stateStore.Get("test-key")
543 | require.Error(t, err, "Original value should not exist in database after abort")
544 | require.Nil(t, value)
545 | }
546 |
547 | func TestAbort_FutureHeight(t *testing.T) {
548 | tempDir := t.TempDir()
549 | walPath := filepath.Join(tempDir, "wal")
550 | wal := openTestWAL(t, walPath)
551 |
552 | stateStore, recovery := newStateStore(t, wal)
553 |
554 | // create committer
555 | committer := NewCommitter(stateStore, "three-phase", wal, 5000)
556 | committer.SetHeight(recovery.Height)
557 |
558 | // test abort for future height (should be ignored)
559 | ctx := context.Background()
560 | abortReq := &dto.AbortRequest{
561 | Height: 10, // future height
562 | Reason: "Test abort future",
563 | }
564 |
565 | resp, err := committer.Abort(ctx, abortReq)
566 | require.NoError(t, err)
567 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType)
568 |
569 | // state should remain unchanged
570 | require.Equal(t, "propose", committer.getCurrentState())
571 | require.Equal(t, uint64(0), committer.Height())
572 |
573 | key, val, err := wal.Get(0)
574 | require.NoError(t, err)
575 | require.Equal(t, "", key)
576 | require.Nil(t, val, "WAL should not have entry for current height without propose")
577 | }
578 |
579 | func TestAbort_PastHeight(t *testing.T) {
580 | tempDir := t.TempDir()
581 | walPath := filepath.Join(tempDir, "wal")
582 | wal := openTestWAL(t, walPath)
583 |
584 | stateStore, recovery := newStateStore(t, wal)
585 |
586 | // create committer and advance height
587 | committer := NewCommitter(stateStore, "two-phase", wal, 5000)
588 | committer.SetHeight(recovery.Height)
589 |
590 | // complete a transaction to advance height
591 | ctx := context.Background()
592 | proposeReq := &dto.ProposeRequest{
593 | Height: 0,
594 | Key: "test-key",
595 | Value: []byte("test-value"),
596 | }
597 |
598 | _, err := committer.Propose(ctx, proposeReq)
599 | require.NoError(t, err)
600 |
601 | commitReq := &dto.CommitRequest{Height: 0}
602 | _, err = committer.Commit(ctx, commitReq)
603 | require.NoError(t, err)
604 |
605 | // height should now be 1
606 | require.Equal(t, uint64(1), committer.Height())
607 |
608 | // test abort for past height (should be ignored)
609 | abortReq := &dto.AbortRequest{
610 | Height: 0, // Past height
611 | Reason: "Test abort past",
612 | }
613 |
614 | resp, err := committer.Abort(ctx, abortReq)
615 | require.NoError(t, err)
616 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType)
617 |
618 | // state should remain unchanged
619 | require.Equal(t, "propose", committer.getCurrentState())
620 | require.Equal(t, uint64(1), committer.Height())
621 |
622 | // check wal unchanged
623 | // check Prepared
624 | k, _, err := wal.Get(walrecord.PreparedSlot(0))
625 | require.NoError(t, err)
626 | require.Equal(t, walrecord.KeyPrepared, k)
627 |
628 | // check Commit
629 | k, _, err = wal.Get(walrecord.CommitSlot(0))
630 | require.NoError(t, err)
631 | require.Equal(t, walrecord.KeyCommit, k)
632 |
633 | // check normal data is in db
634 | value, err := stateStore.Get("test-key")
635 | require.NoError(t, err, "Data should exist in database for past committed transaction")
636 | require.Equal(t, "test-value", string(value))
637 | }
638 |
639 | func TestAbort_StateRecovery_3PC(t *testing.T) {
640 | tempDir := t.TempDir()
641 | walPath := filepath.Join(tempDir, "wal")
642 | wal := openTestWAL(t, walPath)
643 |
644 | stateStore, recovery := newStateStore(t, wal)
645 |
646 | // create 3PC committer
647 | committer := NewCommitter(stateStore, "three-phase", wal, 5000)
648 | committer.SetHeight(recovery.Height)
649 |
650 | // set up transaction and move to precommit state
651 | ctx := context.Background()
652 | proposeReq := &dto.ProposeRequest{
653 | Height: 0,
654 | Key: "test-key",
655 | Value: []byte("test-value"),
656 | }
657 |
658 | _, err := committer.Propose(ctx, proposeReq)
659 | require.NoError(t, err)
660 |
661 | _, err = committer.Precommit(ctx, 0)
662 | require.NoError(t, err)
663 | require.Equal(t, "precommit", committer.getCurrentState())
664 |
665 | // test abort from precommit state
666 | abortReq := &dto.AbortRequest{
667 | Height: 0,
668 | Reason: "Test 3PC abort",
669 | }
670 |
671 | resp, err := committer.Abort(ctx, abortReq)
672 | require.NoError(t, err)
673 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType)
674 |
675 | // should be back in propose state
676 | require.Equal(t, "propose", committer.getCurrentState())
677 |
678 | // check wal
679 | k, _, err := wal.Get(walrecord.AbortSlot(0))
680 | require.NoError(t, err)
681 | require.Equal(t, walrecord.KeyAbort, k)
682 | }
683 |
684 | func TestAbort_StateRecovery_2PC(t *testing.T) {
685 | tempDir := t.TempDir()
686 | walPath := filepath.Join(tempDir, "wal")
687 | wal := openTestWAL(t, walPath)
688 |
689 | stateStore, recovery := newStateStore(t, wal)
690 |
691 | // create 2PC committer
692 | committer := NewCommitter(stateStore, "two-phase", wal, 5000)
693 | committer.SetHeight(recovery.Height)
694 |
695 | // set up transaction (in 2PC, we stay in propose state)
696 | ctx := context.Background()
697 | proposeReq := &dto.ProposeRequest{
698 | Height: 0,
699 | Key: "test-key",
700 | Value: []byte("test-value"),
701 | }
702 |
703 | _, err := committer.Propose(ctx, proposeReq)
704 | require.NoError(t, err)
705 | require.Equal(t, "propose", committer.getCurrentState())
706 |
707 | // test abort from propose state in 2PC
708 | abortReq := &dto.AbortRequest{
709 | Height: 0,
710 | Reason: "Test 2PC abort",
711 | }
712 |
713 | resp, err := committer.Abort(ctx, abortReq)
714 | require.NoError(t, err)
715 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType)
716 |
717 | // should remain in propose state
718 | require.Equal(t, "propose", committer.getCurrentState())
719 |
720 | // check wal
721 | k, _, err := wal.Get(walrecord.AbortSlot(0))
722 | require.NoError(t, err)
723 | require.Equal(t, walrecord.KeyAbort, k)
724 | }
725 |
--------------------------------------------------------------------------------
/chaos_test.go:
--------------------------------------------------------------------------------
1 | //go:build chaos
2 |
// To run these tests you need to install toxiproxy-server.
//
// # macOS (Intel); on Linux replace "darwin-amd64" with "linux-amd64" in the URL
// curl -L -o toxiproxy-server https://github.com/Shopify/toxiproxy/releases/download/v2.12.0/toxiproxy-server-darwin-amd64
// chmod +x toxiproxy-server
// mv toxiproxy-server ~/go/bin/
//
// Then run `make test-chaos`.
11 | package main
12 |
13 | import (
14 | "context"
15 | "fmt"
16 | "os"
17 | "path/filepath"
18 | "strconv"
19 | "testing"
20 | "time"
21 |
22 | log "github.com/sirupsen/logrus"
23 | "github.com/stretchr/testify/require"
24 | "github.com/vadiminshakov/committer/core/cohort"
25 | "github.com/vadiminshakov/committer/core/cohort/commitalgo"
26 | "github.com/vadiminshakov/committer/core/coordinator"
27 | "github.com/vadiminshakov/committer/io/gateway/grpc/client"
28 | pb "github.com/vadiminshakov/committer/io/gateway/grpc/proto"
29 | "github.com/vadiminshakov/committer/io/gateway/grpc/server"
30 | "github.com/vadiminshakov/committer/io/store"
31 | "github.com/vadiminshakov/gowal"
32 | )
33 |
// TOXIPROXY_URL is the base URL of the toxiproxy-server HTTP API (8474 is
// toxiproxy's default API port). Every chaos subtest passes it to
// newChaosTestHelper.
const TOXIPROXY_URL = "http://localhost:8474"
35 |
36 | func TestChaosFollowerFailure(t *testing.T) {
37 | log.SetLevel(log.InfoLevel)
38 |
39 | // immediate connection reset
40 | t.Run("immediate_reset", func(t *testing.T) {
41 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL)
42 | defer chaosHelper.cleanup()
43 |
44 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
45 | for _, node := range nodes[COHORT_TYPE] {
46 | allAddresses = append(allAddresses, node.Nodeaddr)
47 | }
48 | for _, node := range nodes[COORDINATOR_TYPE] {
49 | allAddresses = append(allAddresses, node.Nodeaddr)
50 | }
51 |
52 | require.NoError(t, chaosHelper.setupProxies(allAddresses))
53 | require.NoError(t, chaosHelper.addResetPeer(nodes[COHORT_TYPE][0].Nodeaddr, 0))
54 |
55 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT)
56 | defer canceller()
57 |
58 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr
59 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" {
60 | coordAddr = proxyAddr
61 | }
62 |
63 | c, err := client.NewClientAPI(coordAddr)
64 | require.NoError(t, err)
65 |
66 | _, err = c.Put(context.Background(), "reset_test", []byte("value"))
67 | require.Error(t, err)
68 | require.Contains(t, err.Error(), "failed to send propose")
69 | })
70 |
71 | // connection drops after 10 bytes of data
72 | t.Run("cohort failure after 10 bytes", func(t *testing.T) {
73 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL)
74 | defer chaosHelper.cleanup()
75 |
76 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
77 | for _, node := range nodes[COHORT_TYPE] {
78 | allAddresses = append(allAddresses, node.Nodeaddr)
79 | }
80 | for _, node := range nodes[COORDINATOR_TYPE] {
81 | allAddresses = append(allAddresses, node.Nodeaddr)
82 | }
83 |
84 | require.NoError(t, chaosHelper.setupProxies(allAddresses))
85 | require.NoError(t, chaosHelper.addDataLimit(nodes[COHORT_TYPE][0].Nodeaddr, 10)) // 10 bytes
86 |
87 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT)
88 | defer canceller()
89 |
90 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr
91 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" {
92 | coordAddr = proxyAddr
93 | }
94 |
95 | c, err := client.NewClientAPI(coordAddr)
96 | require.NoError(t, err)
97 |
98 | _, err = c.Put(context.Background(), "early_fail_test", []byte("value"))
99 | require.Error(t, err)
100 | require.Contains(t, err.Error(), "failed to send propose")
101 | })
102 |
103 | // connection drops after 50 bytes of data
104 | t.Run("cohort failure after 50 bytes", func(t *testing.T) {
105 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL)
106 | defer chaosHelper.cleanup()
107 |
108 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
109 | for _, node := range nodes[COHORT_TYPE] {
110 | allAddresses = append(allAddresses, node.Nodeaddr)
111 | }
112 | for _, node := range nodes[COORDINATOR_TYPE] {
113 | allAddresses = append(allAddresses, node.Nodeaddr)
114 | }
115 |
116 | require.NoError(t, chaosHelper.setupProxies(allAddresses))
117 | require.NoError(t, chaosHelper.addDataLimit(nodes[COHORT_TYPE][0].Nodeaddr, 50)) // 50 bytes
118 |
119 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT)
120 | defer canceller()
121 |
122 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr
123 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" {
124 | coordAddr = proxyAddr
125 | }
126 |
127 | c, err := client.NewClientAPI(coordAddr)
128 | require.NoError(t, err)
129 |
130 | _, err = c.Put(context.Background(), "early_fail_test", []byte("value"))
131 | require.Error(t, err)
132 | require.Contains(t, err.Error(), "failed to send propose")
133 | })
134 |
135 | // connection drops after 100 bytes of data
136 | t.Run("cohort failure after 100 bytes", func(t *testing.T) {
137 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL)
138 | defer chaosHelper.cleanup()
139 |
140 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
141 | for _, node := range nodes[COHORT_TYPE] {
142 | allAddresses = append(allAddresses, node.Nodeaddr)
143 | }
144 | for _, node := range nodes[COORDINATOR_TYPE] {
145 | allAddresses = append(allAddresses, node.Nodeaddr)
146 | }
147 |
148 | require.NoError(t, chaosHelper.setupProxies(allAddresses))
149 | require.NoError(t, chaosHelper.addDataLimit(nodes[COHORT_TYPE][0].Nodeaddr, 100)) // 100 bytes
150 |
151 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT)
152 | defer canceller()
153 |
154 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr
155 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" {
156 | coordAddr = proxyAddr
157 | }
158 |
159 | c, err := client.NewClientAPI(coordAddr)
160 | require.NoError(t, err)
161 |
162 | _, err = c.Put(context.Background(), "early_fail_test", []byte("value"))
163 | require.Error(t, err)
164 | require.Contains(t, err.Error(), "failed to send propose")
165 | })
166 |
167 | // connection drops after 150 bytes of data
168 | t.Run("cohort failure after 150 bytes", func(t *testing.T) {
169 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL)
170 | defer chaosHelper.cleanup()
171 |
172 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
173 | for _, node := range nodes[COHORT_TYPE] {
174 | allAddresses = append(allAddresses, node.Nodeaddr)
175 | }
176 | for _, node := range nodes[COORDINATOR_TYPE] {
177 | allAddresses = append(allAddresses, node.Nodeaddr)
178 | }
179 |
180 | require.NoError(t, chaosHelper.setupProxies(allAddresses))
181 | require.NoError(t, chaosHelper.addDataLimit(nodes[COHORT_TYPE][0].Nodeaddr, 150)) // 150 bytes
182 |
183 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT)
184 | defer canceller()
185 |
186 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr
187 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" {
188 | coordAddr = proxyAddr
189 | }
190 |
191 | c, err := client.NewClientAPI(coordAddr)
192 | require.NoError(t, err)
193 |
194 | _, err = c.Put(context.Background(), "early_fail_test", []byte("value"))
195 | require.Error(t, err)
196 | require.Contains(t, err.Error(), "failed to send precommit")
197 | })
198 |
199 | // connection drops after 200 bytes of data
200 | t.Run("cohort failure after 200 bytes", func(t *testing.T) {
201 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL)
202 | defer chaosHelper.cleanup()
203 |
204 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
205 | for _, node := range nodes[COHORT_TYPE] {
206 | allAddresses = append(allAddresses, node.Nodeaddr)
207 | }
208 | for _, node := range nodes[COORDINATOR_TYPE] {
209 | allAddresses = append(allAddresses, node.Nodeaddr)
210 | }
211 |
212 | require.NoError(t, chaosHelper.setupProxies(allAddresses))
213 | require.NoError(t, chaosHelper.addDataLimit(nodes[COHORT_TYPE][0].Nodeaddr, 200)) // 200 bytes
214 |
215 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT)
216 | defer canceller()
217 |
218 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr
219 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" {
220 | coordAddr = proxyAddr
221 | }
222 |
223 | c, err := client.NewClientAPI(coordAddr)
224 | require.NoError(t, err)
225 |
226 | _, err = c.Put(context.Background(), "early_fail_test", []byte("value"))
227 | require.Error(t, err)
228 | require.Contains(t, err.Error(), "failed to send precommit")
229 | })
230 |
231 | // connection drops after 250 bytes of data
232 | t.Run("cohort failure after 250 bytes", func(t *testing.T) {
233 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL)
234 | defer chaosHelper.cleanup()
235 |
236 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
237 | for _, node := range nodes[COHORT_TYPE] {
238 | allAddresses = append(allAddresses, node.Nodeaddr)
239 | }
240 | for _, node := range nodes[COORDINATOR_TYPE] {
241 | allAddresses = append(allAddresses, node.Nodeaddr)
242 | }
243 |
244 | require.NoError(t, chaosHelper.setupProxies(allAddresses))
245 | require.NoError(t, chaosHelper.addDataLimit(nodes[COHORT_TYPE][0].Nodeaddr, 250)) // 250 bytes
246 |
247 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT)
248 | defer canceller()
249 |
250 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr
251 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" {
252 | coordAddr = proxyAddr
253 | }
254 |
255 | c, err := client.NewClientAPI(coordAddr)
256 | require.NoError(t, err)
257 |
258 | _, err = c.Put(context.Background(), "commit_fail_test", []byte("test_value_250"))
259 | if err != nil {
260 | require.Contains(t, err.Error(), "failed to send commit")
261 | // if operation failed, check that value was committed on healthy nodes
262 | if checkValueOnCoordinator(t, "commit_fail_test", []byte("test_value_250")) {
263 | checkValueOnCohorts(t, "commit_fail_test", []byte("test_value_250"), 0) // skip failed cohort (index 0)
264 | checkValueNotOnNode(t, nodes[COHORT_TYPE][0].Nodeaddr, "commit_fail_test")
265 | }
266 | } else {
267 | // if operation succeeded despite limits, all nodes should have the value
268 | t.Log("operation succeeded despite network limits (250 bytes was sufficient)")
269 | checkValueOnCoordinator(t, "commit_fail_test", []byte("test_value_250"))
270 | checkValueOnAllCohorts(t, "commit_fail_test", []byte("test_value_250"))
271 | }
272 | })
273 |
274 | // connection drops after 500 bytes of data
275 | t.Run("cohort failure after 500 bytes", func(t *testing.T) {
276 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL)
277 | defer chaosHelper.cleanup()
278 |
279 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
280 | for _, node := range nodes[COHORT_TYPE] {
281 | allAddresses = append(allAddresses, node.Nodeaddr)
282 | }
283 | for _, node := range nodes[COORDINATOR_TYPE] {
284 | allAddresses = append(allAddresses, node.Nodeaddr)
285 | }
286 |
287 | require.NoError(t, chaosHelper.setupProxies(allAddresses))
288 | require.NoError(t, chaosHelper.addDataLimit(nodes[COHORT_TYPE][0].Nodeaddr, 500)) // 500 bytes
289 |
290 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT)
291 | defer canceller()
292 |
293 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr
294 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" {
295 | coordAddr = proxyAddr
296 | }
297 |
298 | c, err := client.NewClientAPI(coordAddr)
299 | require.NoError(t, err)
300 |
301 | _, err = c.Put(context.Background(), "success_test", []byte("test_value_500"))
302 | require.NoError(t, err)
303 |
304 | // check if value was committed on all nodes
305 | checkValueOnCoordinator(t, "success_test", []byte("test_value_500"))
306 | checkValueOnAllCohorts(t, "success_test", []byte("test_value_500"))
307 | })
308 | }
309 |
310 | func TestChaosCoordinatorFailure(t *testing.T) {
311 | log.SetLevel(log.InfoLevel)
312 |
313 | // immediate connection reset
314 | t.Run("coordinator_immediate_reset", func(t *testing.T) {
315 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL)
316 | defer chaosHelper.cleanup()
317 |
318 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
319 | for _, node := range nodes[COHORT_TYPE] {
320 | allAddresses = append(allAddresses, node.Nodeaddr)
321 | }
322 | for _, node := range nodes[COORDINATOR_TYPE] {
323 | allAddresses = append(allAddresses, node.Nodeaddr)
324 | }
325 |
326 | require.NoError(t, chaosHelper.setupProxies(allAddresses))
327 | require.NoError(t, chaosHelper.addResetPeer(nodes[COORDINATOR_TYPE][1].Nodeaddr, 0))
328 |
329 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT)
330 | defer canceller()
331 |
332 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr
333 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" {
334 | coordAddr = proxyAddr
335 | }
336 |
337 | c, err := client.NewClientAPI(coordAddr)
338 | require.NoError(t, err)
339 |
340 | _, err = c.Put(context.Background(), "coord_reset_test", []byte("value"))
341 | require.Error(t, err)
342 | require.Contains(t, err.Error(), "connection closed before server preface received")
343 | })
344 |
345 | // coordinator fails after 50 bytes
346 | t.Run("coordinator_failure_after_50_bytes", func(t *testing.T) {
347 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL)
348 | defer chaosHelper.cleanup()
349 |
350 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
351 | for _, node := range nodes[COHORT_TYPE] {
352 | allAddresses = append(allAddresses, node.Nodeaddr)
353 | }
354 | for _, node := range nodes[COORDINATOR_TYPE] {
355 | allAddresses = append(allAddresses, node.Nodeaddr)
356 | }
357 |
358 | require.NoError(t, chaosHelper.setupProxies(allAddresses))
359 | require.NoError(t, chaosHelper.addDataLimit(nodes[COORDINATOR_TYPE][1].Nodeaddr, 50))
360 |
361 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT)
362 | defer canceller()
363 |
364 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr
365 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" {
366 | coordAddr = proxyAddr
367 | }
368 |
369 | c, err := client.NewClientAPI(coordAddr)
370 | require.NoError(t, err)
371 |
372 | _, err = c.Put(context.Background(), "coord_50_test", []byte("value"))
373 | require.Error(t, err)
374 | })
375 |
376 | // coordinator fails after 100 bytes
377 | t.Run("coordinator_failure_after_100_bytes", func(t *testing.T) {
378 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL)
379 | defer chaosHelper.cleanup()
380 |
381 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
382 | for _, node := range nodes[COHORT_TYPE] {
383 | allAddresses = append(allAddresses, node.Nodeaddr)
384 | }
385 | for _, node := range nodes[COORDINATOR_TYPE] {
386 | allAddresses = append(allAddresses, node.Nodeaddr)
387 | }
388 |
389 | require.NoError(t, chaosHelper.setupProxies(allAddresses))
390 | require.NoError(t, chaosHelper.addDataLimit(nodes[COORDINATOR_TYPE][1].Nodeaddr, 100))
391 |
392 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT)
393 | defer canceller()
394 |
395 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr
396 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" {
397 | coordAddr = proxyAddr
398 | }
399 |
400 | c, err := client.NewClientAPI(coordAddr)
401 | require.NoError(t, err)
402 |
403 | _, err = c.Put(context.Background(), "coord_100_test", []byte("value"))
404 | require.Error(t, err)
405 | })
406 |
407 | // coordinator fails after 200 bytes
408 | t.Run("coordinator_failure_after_200_bytes", func(t *testing.T) {
409 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL)
410 | defer chaosHelper.cleanup()
411 |
412 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
413 | for _, node := range nodes[COHORT_TYPE] {
414 | allAddresses = append(allAddresses, node.Nodeaddr)
415 | }
416 | for _, node := range nodes[COORDINATOR_TYPE] {
417 | allAddresses = append(allAddresses, node.Nodeaddr)
418 | }
419 |
420 | require.NoError(t, chaosHelper.setupProxies(allAddresses))
421 | require.NoError(t, chaosHelper.addDataLimit(nodes[COORDINATOR_TYPE][1].Nodeaddr, 200))
422 |
423 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT)
424 | defer canceller()
425 |
426 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr
427 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" {
428 | coordAddr = proxyAddr
429 | }
430 |
431 | c, err := client.NewClientAPI(coordAddr)
432 | require.NoError(t, err)
433 |
434 | _, err = c.Put(context.Background(), "coord_200_test", []byte("value"))
435 | // may succeed or fail depending on when exactly coordinator fails
436 | // the main point is to check cohort consistency afterwards
437 | if err != nil {
438 | t.Logf("coordinator operation failed as expected: %v", err)
439 | } else {
440 | t.Log("coordinator operation succeeded despite limits")
441 | }
442 |
443 | t.Log("checking cohort states after coordinator failure")
444 | checkFollowerStatesAfterCoordinatorFailure(t, "coord_200_test", []byte("value"))
445 | })
446 |
447 | // coordinator fails during commit phase (after 300 bytes)
448 | t.Run("coordinator_failure_during_commit", func(t *testing.T) {
449 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL)
450 | defer chaosHelper.cleanup()
451 |
452 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
453 | for _, node := range nodes[COHORT_TYPE] {
454 | allAddresses = append(allAddresses, node.Nodeaddr)
455 | }
456 | for _, node := range nodes[COORDINATOR_TYPE] {
457 | allAddresses = append(allAddresses, node.Nodeaddr)
458 | }
459 |
460 | require.NoError(t, chaosHelper.setupProxies(allAddresses))
461 | require.NoError(t, chaosHelper.addDataLimit(nodes[COORDINATOR_TYPE][1].Nodeaddr, 300))
462 |
463 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT)
464 | defer canceller()
465 |
466 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr
467 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" {
468 | coordAddr = proxyAddr
469 | }
470 |
471 | c, err := client.NewClientAPI(coordAddr)
472 | require.NoError(t, err)
473 |
474 | _, err = c.Put(context.Background(), "coord_commit_test", []byte("commit_value"))
475 | // may succeed or fail depending on timing
476 | if err != nil {
477 | t.Logf("coordinator operation failed: %v", err)
478 | } else {
479 | t.Log("coordinator operation completed successfully")
480 | }
481 |
482 | t.Log("checking cohort states after coordinator failure during commit")
483 | checkFollowerStatesAfterCoordinatorFailure(t, "coord_commit_test", []byte("commit_value"))
484 | })
485 | }
486 |
487 | // startnodesChaos starts nodes with Toxiproxy support
488 | func startnodesChaos(helper *chaosTestHelper, commitType pb.CommitType) func() error {
489 | COORDINATOR_BADGER := fmt.Sprintf("%s%s%d", BADGER_DIR, "coordinator", time.Now().UnixNano())
490 | COHORT_BADGER := fmt.Sprintf("%s%s%d", BADGER_DIR, "cohort", time.Now().UnixNano())
491 |
492 | // cleanup dirs
493 | cleanupDirs := []string{COORDINATOR_BADGER, COHORT_BADGER, "./tmp"}
494 | for _, dir := range cleanupDirs {
495 | if _, err := os.Stat(dir); !os.IsNotExist(err) {
496 | failfast(os.RemoveAll(dir))
497 | }
498 | }
499 |
500 | // create dirs
501 | createDirs := []string{COORDINATOR_BADGER, COHORT_BADGER, "./tmp", "./tmp/cohort", "./tmp/coord"}
502 | for _, dir := range createDirs {
503 | failfast(os.Mkdir(dir, os.FileMode(0777)))
504 | }
505 |
506 | stopfuncs := make([]func(), 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
507 |
508 | // start cohorts
509 | for i, node := range nodes[COHORT_TYPE] {
510 | if commitType == pb.CommitType_THREE_PHASE_COMMIT {
511 | // use proxy address of coordinator
512 | if proxyAddr := helper.getProxyAddress(nodes[COORDINATOR_TYPE][1].Nodeaddr); proxyAddr != "" {
513 | node.Coordinator = proxyAddr
514 | } else {
515 | node.Coordinator = nodes[COORDINATOR_TYPE][1].Nodeaddr
516 | }
517 | }
518 | node.DBPath = filepath.Join(COHORT_BADGER, strconv.Itoa(i))
519 | failfast(os.MkdirAll(node.DBPath, os.FileMode(0o777)))
520 |
521 | walConfig := gowal.Config{
522 | Dir: "./tmp/cohort/" + strconv.Itoa(i),
523 | Prefix: "msgs_",
524 | SegmentThreshold: 100,
525 | MaxSegments: 100,
526 | IsInSyncDiskMode: false,
527 | }
528 | c, err := gowal.NewWAL(walConfig)
529 | failfast(err)
530 |
531 | stateStore, recovery, err := store.New(c, node.DBPath)
532 | failfast(err)
533 |
534 | ct := server.TWO_PHASE
535 | if commitType == pb.CommitType_THREE_PHASE_COMMIT {
536 | ct = server.THREE_PHASE
537 | }
538 |
539 | committer := commitalgo.NewCommitter(stateStore, ct, c, node.Timeout)
540 | committer.SetHeight(recovery.Height)
541 | cohortImpl := cohort.NewCohort(committer, cohort.Mode(node.CommitType))
542 |
543 | cohortServer, err := server.New(node, cohortImpl, nil, stateStore)
544 | failfast(err)
545 |
546 | go cohortServer.Run(server.WhiteListChecker)
547 | stopfuncs = append(stopfuncs, cohortServer.Stop)
548 | }
549 |
550 | // start coordinators
551 | for i, coordConfig := range nodes[COORDINATOR_TYPE] {
552 | coordConfig.DBPath = filepath.Join(COORDINATOR_BADGER, strconv.Itoa(i))
553 | failfast(os.MkdirAll(coordConfig.DBPath, os.FileMode(0o777)))
554 | // update cohorts addresses to use proxies
555 | updatedCohorts := make([]string, len(coordConfig.Cohorts))
556 | for j, cohortAddr := range coordConfig.Cohorts {
557 | if proxyAddr := helper.getProxyAddress(cohortAddr); proxyAddr != "" {
558 | updatedCohorts[j] = proxyAddr
559 | } else {
560 | updatedCohorts[j] = cohortAddr
561 | }
562 | }
563 | coordConfig.Cohorts = updatedCohorts
564 |
565 | walConfig := gowal.Config{
566 | Dir: "./tmp/coord/msgs" + strconv.Itoa(i),
567 | Prefix: "msgs",
568 | SegmentThreshold: 100,
569 | MaxSegments: 100,
570 | IsInSyncDiskMode: false,
571 | }
572 |
573 | c, err := gowal.NewWAL(walConfig)
574 | failfast(err)
575 |
576 | stateStore, recovery, err := store.New(c, coordConfig.DBPath)
577 | failfast(err)
578 |
579 | coord, err := coordinator.New(coordConfig, c, stateStore)
580 | failfast(err)
581 | coord.SetHeight(recovery.Height)
582 |
583 | coordServer, err := server.New(coordConfig, nil, coord, stateStore)
584 | failfast(err)
585 |
586 | go coordServer.Run(server.WhiteListChecker)
587 | time.Sleep(100 * time.Millisecond)
588 | stopfuncs = append(stopfuncs, coordServer.Stop)
589 | }
590 |
591 | return func() error {
592 | for _, f := range stopfuncs {
593 | f()
594 | }
595 | failfast(os.RemoveAll("./tmp"))
596 | return os.RemoveAll(BADGER_DIR)
597 | }
598 | }
599 |
600 | // checkValueOnCoordinator checks if a value exists on coordinator
601 | func checkValueOnCoordinator(t *testing.T, key string, expectedValue []byte) bool {
602 | t.Helper()
603 |
604 | coordClient, err := client.NewClientAPI(nodes[COORDINATOR_TYPE][1].Nodeaddr)
605 | require.NoError(t, err)
606 |
607 | coordValue, err := coordClient.Get(context.Background(), key)
608 | if err != nil {
609 | t.Logf("coordinator does not have value for key %s: %v", key, err)
610 | return false
611 | }
612 |
613 | require.Equal(t, expectedValue, coordValue.Value)
614 | t.Logf("coordinator has correct value for key %s", key)
615 | return true
616 | }
617 |
618 | // checkValueOnCohorts checks if a value exists on working cohorts (excluding failed one)
619 | func checkValueOnCohorts(t *testing.T, key string, expectedValue []byte, skipFailedIndex int) int {
620 | t.Helper()
621 |
622 | successCount := 0
623 | for i, cohortAddr := range nodes[COHORT_TYPE] {
624 | if i == skipFailedIndex {
625 | continue
626 | }
627 |
628 | cohortClient, err := client.NewClientAPI(cohortAddr.Nodeaddr)
629 | require.NoError(t, err)
630 |
631 | cohortValue, err := cohortClient.Get(context.Background(), key)
632 | require.NoError(t, err)
633 | require.Equal(t, expectedValue, cohortValue.Value)
634 | successCount++
635 | }
636 | return successCount
637 | }
638 |
639 | // checkValueOnAllCohorts checks if a value exists on ALL cohorts
640 | func checkValueOnAllCohorts(t *testing.T, key string, expectedValue []byte) {
641 | t.Helper()
642 |
643 | for _, cohortAddr := range nodes[COHORT_TYPE] {
644 | cohortClient, err := client.NewClientAPI(cohortAddr.Nodeaddr)
645 | require.NoError(t, err)
646 |
647 | cohortValue, err := cohortClient.Get(context.Background(), key)
648 | require.NoError(t, err)
649 | require.Equal(t, expectedValue, cohortValue.Value)
650 | }
651 | }
652 |
653 | // checkValueNotOnNode checks that a value does NOT exist on specified node
654 | func checkValueNotOnNode(t *testing.T, nodeAddr string, key string) {
655 | t.Helper()
656 |
657 | nodeClient, err := client.NewClientAPI(nodeAddr)
658 | require.NoError(t, err)
659 |
660 | _, err = nodeClient.Get(context.Background(), key)
661 | if err != nil {
662 | t.Logf("node %s correctly does not have value (as expected for failed node)", nodeAddr)
663 | } else {
664 | t.Logf("WARNING: node %s unexpectedly has the value (should not happen for failed node)", nodeAddr)
665 | }
666 | }
667 |
668 | // checkFollowerStatesAfterCoordinatorFailure checks the state of cohort nodes after coordinator failure
669 | func checkFollowerStatesAfterCoordinatorFailure(t *testing.T, key string, expectedValue []byte) {
670 | t.Helper()
671 |
672 | committedCount := 0
673 | notCommittedCount := 0
674 |
675 | for _, cohortAddr := range nodes[COHORT_TYPE] {
676 | cohortClient, err := client.NewClientAPI(cohortAddr.Nodeaddr)
677 | require.NoError(t, err)
678 |
679 | cohortValue, err := cohortClient.Get(context.Background(), key)
680 | if err == nil {
681 | require.Equal(t, expectedValue, cohortValue.Value)
682 | committedCount++
683 | } else {
684 | notCommittedCount++
685 | }
686 | }
687 |
688 | totalCohorts := len(nodes[COHORT_TYPE])
689 | t.Logf("cohort states: %d committed, %d not committed", committedCount, notCommittedCount)
690 |
691 | // validation: after coordinator failure, cohorts must be in consistent state
692 | // either all committed or none committed (no partial commits allowed)
693 | if committedCount > 0 && committedCount < totalCohorts {
694 | t.Errorf("inconsistent state detected: %d cohorts committed, %d did not commit. This violates consistency!",
695 | committedCount, notCommittedCount)
696 | } else {
697 | t.Logf("consistent state maintained: all cohorts are in the same state")
698 | }
699 | }
700 |
--------------------------------------------------------------------------------