├── .gitignore ├── committer.png ├── io ├── gateway │ └── grpc │ │ ├── client │ │ ├── common.go │ │ ├── client_api.go │ │ └── internal_client.go │ │ ├── server │ │ ├── dto.go │ │ ├── interceptors.go │ │ └── server.go │ │ └── proto │ │ ├── schema.proto │ │ └── schema_grpc.pb.go └── store │ ├── store_test.go │ └── store.go ├── examples └── client │ └── client.go ├── core ├── cohort │ ├── commitalgo │ │ ├── hooks │ │ │ ├── hooks.go │ │ │ ├── registry.go │ │ │ ├── registry_test.go │ │ │ └── examples.go │ │ ├── fsm.go │ │ ├── commitalgo.go │ │ └── committer_test.go │ ├── cohort.go │ └── cohort_test.go ├── dto │ └── dto.go ├── walrecord │ └── walrecord.go └── coordinator │ ├── coordinator_test.go │ └── coordinator.go ├── .github └── workflows │ └── tests.yml ├── Makefile ├── go.mod ├── mocks ├── mock_state_store.go ├── mock_commitalgo_state_store.go ├── mock_wal.go ├── mock_cohort.go └── mock_committer.go ├── config └── config.go ├── main.go ├── README.md ├── toxiproxy_test.go ├── go.sum ├── main_test.go ├── LICENSE └── chaos_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea/ -------------------------------------------------------------------------------- /committer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vadiminshakov/committer/HEAD/committer.png -------------------------------------------------------------------------------- /io/gateway/grpc/client/common.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/pkg/errors" 8 | "google.golang.org/grpc" 9 | "google.golang.org/grpc/backoff" 10 | "google.golang.org/grpc/credentials/insecure" 11 | ) 12 | 13 | func createConnection(addr string) (*grpc.ClientConn, error) { 14 | connParams := grpc.ConnectParams{ 15 | Backoff: backoff.Config{ 16 | BaseDelay: 100 * time.Millisecond, 17 | MaxDelay: 10 * time.Second, 18 | }, 19 | MinConnectTimeout: 200 * time.Millisecond, 20 | } 21 | 22 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 23 | defer cancel() // release the timeout context once dialing finishes 24 | conn, err := grpc.DialContext(ctx, addr, grpc.WithConnectParams(connParams), grpc.WithTransportCredentials(insecure.NewCredentials())) 25 | if err != nil { 26 | return nil, errors.Wrap(err, "failed to connect") 27 | } 28 | 29 | return conn, nil 30 | } 31 | -------------------------------------------------------------------------------- /io/gateway/grpc/server/dto.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/vadiminshakov/committer/core/dto" 5 | "github.com/vadiminshakov/committer/io/gateway/grpc/proto" 6 | ) 7 | 8 | func proposeRequestPbToEntity(request *proto.ProposeRequest) *dto.ProposeRequest { 9 | if request == nil { 10 | return nil 11 | } 12 | 13 | return &dto.ProposeRequest{ 14 | Key: request.Key, 15 | Value: request.Value, 16 | Height: request.Index, 17 | } 18 | } 19 | 20 | func commitRequestPbToEntity(request *proto.CommitRequest) *dto.CommitRequest { 21 | if request == nil { 22 | return nil 23 | } 24 | 25 | return &dto.CommitRequest{ 26 | Height: request.Index, 27 | IsRollback: request.IsRollback, 28 | } 29 | } 30 | 31 | func cohortResponseToProto(e *dto.CohortResponse) *proto.Response { 32 | if e == nil { 33 | return nil 34 | } 35 | return &proto.Response{ 36 | Type: proto.Type(e.ResponseType), 37 | Index: e.Height, 38 | } 39 | } 40 | 
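A note on wiring: server.go (listed in the tree above) is not reproduced in this dump, so here is a rough sketch of how the mappers in dto.go are typically used inside a package server handler. The Server struct's cohort field and the exact method shape below are assumptions for illustration, not code from the repository:

func (s *Server) Propose(ctx context.Context, req *proto.ProposeRequest) (*proto.Response, error) {
	// translate the wire request into the core DTO, delegate to the cohort,
	// then translate the cohort's answer back into the wire format.
	// s.cohort is a hypothetical field; the real Server struct lives in server.go.
	resp, err := s.cohort.Propose(ctx, proposeRequestPbToEntity(req))
	if err != nil {
		return nil, err
	}
	return cohortResponseToProto(resp), nil
}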
-------------------------------------------------------------------------------- /examples/client/client.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strconv" 7 | "github.com/vadiminshakov/committer/io/gateway/grpc/client" 8 | pb "github.com/vadiminshakov/committer/io/gateway/grpc/proto" 9 | ) 10 | 11 | const coordinatorAddr = "0.0.0.0:3000" 12 | 13 | func main() { 14 | key, value := "somekey", "somevalue" 15 | 16 | // create a client for interaction with the coordinator 17 | cli, err := client.NewClientAPI(coordinatorAddr) 18 | if err != nil { 19 | panic(err) 20 | } 21 | 22 | for i := 0; i < 5; i++ { 23 | // put a key-value pair 24 | resp, err := cli.Put(context.Background(), key+strconv.Itoa(i), []byte(value+strconv.Itoa(i))) 25 | if err != nil { 26 | panic(err) 27 | } 28 | if resp.Type != pb.Type_ACK { 29 | panic("msg is not acknowledged") 30 | } 31 | 32 | // read the committed key 33 | v, err := cli.Get(context.Background(), key+strconv.Itoa(i)) 34 | if err != nil { 35 | panic(err) 36 | } 37 | fmt.Printf("got value for key '%s': %s\n", key+strconv.Itoa(i), string(v.Value)) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /core/cohort/commitalgo/hooks/hooks.go: -------------------------------------------------------------------------------- 1 | // Package hooks provides an extensible hook system for commit algorithms. 2 | // 3 | // Hooks allow custom validation, metrics collection, and business logic 4 | // to be executed during propose and commit phases without modifying core logic. 5 | package hooks 6 | 7 | import ( 8 | log "github.com/sirupsen/logrus" 9 | "github.com/vadiminshakov/committer/core/dto" 10 | ) 11 | 12 | // DefaultHook provides the default logging behavior 13 | type DefaultHook struct{} 14 | 15 | // NewDefaultHook creates a new default hook instance 16 | func NewDefaultHook() *DefaultHook { 17 | return &DefaultHook{} 18 | } 19 | 20 | // OnPropose implements the Hook interface for propose operations 21 | func (h *DefaultHook) OnPropose(req *dto.ProposeRequest) bool { 22 | log.Infof("propose hook on height %d is OK", req.Height) 23 | return true 24 | } 25 | 26 | // OnCommit implements the Hook interface for commit operations 27 | func (h *DefaultHook) OnCommit(req *dto.CommitRequest) bool { 28 | log.Infof("commit hook on height %d is OK", req.Height) 29 | return true 30 | } 31 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | # This is a basic workflow to help you get started with Actions 2 | 3 | name: tests 4 | 5 | # Controls when the action will run. 
Triggers the workflow on push or pull request 6 | # events but only for the master branch 7 | on: 8 | push: 9 | branches: [ master ] 10 | pull_request: 11 | branches: [ master ] 12 | 13 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 14 | jobs: 15 | # This workflow contains a single job called "build" 16 | run-tests: 17 | # The type of runner that the job will run on 18 | runs-on: ubuntu-latest 19 | 20 | # Steps represent a sequence of tasks that will be executed as part of the job 21 | steps: 22 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it 23 | - uses: actions/checkout@v2 24 | 25 | # Set up Go 26 | - name: Set up Go 27 | uses: actions/setup-go@v4 28 | with: 29 | go-version: '1.25' 30 | 31 | # Runs a single command using the runners shell 32 | - name: Run tests 33 | run: make prepare && make tests 34 | -------------------------------------------------------------------------------- /io/gateway/grpc/server/interceptors.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "net" 6 | 7 | "google.golang.org/grpc" 8 | codes "google.golang.org/grpc/codes" 9 | "google.golang.org/grpc/peer" 10 | status "google.golang.org/grpc/status" 11 | ) 12 | 13 | // WhiteListChecker intercepts RPC and checks that the caller is whitelisted. 14 | func WhiteListChecker(ctx context.Context, 15 | req interface{}, 16 | info *grpc.UnaryServerInfo, 17 | handler grpc.UnaryHandler) (interface{}, error) { 18 | peerinfo, ok := peer.FromContext(ctx) 19 | if !ok { 20 | return nil, status.Errorf(codes.Internal, "failed to retrieve peer info") 21 | } 22 | 23 | host, _, err := net.SplitHostPort(peerinfo.Addr.String()) 24 | if err != nil { 25 | return nil, status.Errorf(codes.Internal, err.Error()) 26 | } 27 | 28 | serv := info.Server.(*Server) 29 | if !includes(serv.Config.Whitelist, host) { 30 | return nil, status.Errorf(codes.PermissionDenied, "host %s is not in whitelist", host) 31 | } 32 | 33 | // calls the handler 34 | h, err := handler(ctx, req) 35 | 36 | return h, err 37 | } 38 | 39 | // includes checks that the 'arr' includes 'value' 40 | func includes(arr []string, value string) bool { 41 | for i := range arr { 42 | if arr[i] == value { 43 | return true 44 | } 45 | } 46 | return false 47 | } -------------------------------------------------------------------------------- /core/cohort/commitalgo/hooks/registry.go: -------------------------------------------------------------------------------- 1 | package hooks 2 | 3 | import ( 4 | "github.com/vadiminshakov/committer/core/dto" 5 | ) 6 | 7 | // Hook defines the interface for commit algorithm hooks. 8 | type Hook interface { 9 | OnPropose(req *dto.ProposeRequest) bool 10 | OnCommit(req *dto.CommitRequest) bool 11 | } 12 | 13 | // Registry manages a collection of hooks. 14 | type Registry struct { 15 | hooks []Hook 16 | } 17 | 18 | // NewRegistry creates a new hook registry. 19 | func NewRegistry() *Registry { 20 | return &Registry{ 21 | hooks: make([]Hook, 0), 22 | } 23 | } 24 | 25 | // Register adds a new hook to the registry. 26 | func (r *Registry) Register(hook Hook) { 27 | r.hooks = append(r.hooks, hook) 28 | } 29 | 30 | // ExecutePropose runs all registered propose hooks. 31 | // Returns false if any hook returns false. 
32 | func (r *Registry) ExecutePropose(req *dto.ProposeRequest) bool { 33 | for _, hook := range r.hooks { 34 | if !hook.OnPropose(req) { 35 | return false 36 | } 37 | } 38 | return true 39 | } 40 | 41 | // ExecuteCommit runs all registered commit hooks. 42 | // Returns false if any hook returns false. 43 | func (r *Registry) ExecuteCommit(req *dto.CommitRequest) bool { 44 | for _, hook := range r.hooks { 45 | if !hook.OnCommit(req) { 46 | return false 47 | } 48 | } 49 | return true 50 | } 51 | 52 | // Count returns the number of registered hooks 53 | func (r *Registry) Count() int { 54 | return len(r.hooks) 55 | } 56 | -------------------------------------------------------------------------------- /io/gateway/grpc/proto/schema.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | package schema; 3 | import "google/protobuf/empty.proto"; 4 | option go_package = "github.com/vadiminshakov/committer/proto"; 5 | 6 | service InternalCommitAPI { 7 | rpc Propose(ProposeRequest) returns (Response); 8 | rpc Precommit(PrecommitRequest) returns (Response); 9 | rpc Commit(CommitRequest) returns (Response); 10 | rpc Abort(AbortRequest) returns (Response); 11 | } 12 | 13 | service ClientAPI { 14 | rpc Put(Entry) returns (Response); 15 | rpc Get(Msg) returns (Value); 16 | rpc NodeInfo(google.protobuf.Empty) returns (Info); 17 | } 18 | 19 | message ProposeRequest { 20 | string Key = 1; 21 | bytes Value = 2; 22 | CommitType CommitType = 3; 23 | uint64 index = 4; 24 | } 25 | 26 | enum CommitType { 27 | TWO_PHASE_COMMIT = 0; 28 | THREE_PHASE_COMMIT = 1; 29 | } 30 | 31 | enum Type { 32 | ACK = 0; 33 | NACK = 1; 34 | } 35 | 36 | message Response { 37 | Type type = 1; 38 | uint64 index = 2; 39 | } 40 | 41 | message PrecommitRequest { 42 | uint64 index = 1; 43 | } 44 | 45 | message CommitRequest { 46 | uint64 index = 1; 47 | bool isRollback = 2; 48 | } 49 | 50 | message Entry { 51 | string key = 1; 52 | bytes value = 2; 53 | } 54 | 55 | message Msg { 56 | string key = 1; 57 | } 58 | 59 | message Value { 60 | bytes value = 1; 61 | } 62 | 63 | message Info { 64 | uint64 height = 1; 65 | } 66 | 67 | message AbortRequest { 68 | uint64 height = 1; 69 | string reason = 2; 70 | } -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | prepare: 2 | @rm -rf ./badger 3 | @mkdir ./badger 4 | @mkdir ./badger/coordinator 5 | @mkdir ./badger/cohort 6 | 7 | run-example-coordinator: 8 | @rm -rf ./badger/coordinator 9 | @go run . -role=coordinator -nodeaddr=localhost:3000 -cohorts=localhost:3001 -committype=three-phase -timeout=1000 -dbpath=./badger/coordinator -whitelist=127.0.0.1 10 | 11 | run-example-cohort: 12 | @rm -rf ./badger/cohort 13 | @go run . -role=cohort -coordinator=localhost:3000 -nodeaddr=localhost:3001 -committype=three-phase -timeout=1000 -dbpath=./badger/cohort -whitelist=127.0.0.1 14 | 15 | run-example-client: 16 | @go run ./examples/client 17 | 18 | tests: 19 | @go test ./... 20 | 21 | start-toxiproxy: 22 | @echo "Starting Toxiproxy server..." 23 | @pkill toxiproxy-server || true 24 | @toxiproxy-server > /dev/null 2>&1 & 25 | @echo "Toxiproxy server started in background" 26 | 27 | stop-toxiproxy: 28 | @echo "Stopping Toxiproxy server..." 29 | @pkill toxiproxy-server || true 30 | 31 | test-chaos: start-toxiproxy 32 | @echo "Waiting for Toxiproxy to start..." 33 | @sleep 2 34 | @echo "Running chaos tests..." 
35 | @go test -v -tags=chaos -run "TestChaos" ./... 36 | @$(MAKE) stop-toxiproxy 37 | 38 | proto-gen: 39 | @echo "Generating proto files..." 40 | @protoc --go_out=io/gateway/grpc/proto --go_opt=paths=source_relative \ 41 | --go-grpc_out=io/gateway/grpc/proto --go-grpc_opt=paths=source_relative \ 42 | --proto_path=io/gateway/grpc/proto io/gateway/grpc/proto/schema.proto 43 | @echo "Proto files generated successfully" 44 | 45 | generate: proto-gen 46 | @echo "Generating mocks..." 47 | @go generate ./... 48 | @echo "All files generated successfully" -------------------------------------------------------------------------------- /io/gateway/grpc/client/client_api.go: -------------------------------------------------------------------------------- 1 | // Package client provides gRPC client implementations for communicating with committer nodes. 2 | // 3 | // This package contains both internal client for node-to-node communication 4 | // and external client API for application integration. 5 | package client 6 | 7 | import ( 8 | "context" 9 | 10 | "github.com/vadiminshakov/committer/io/gateway/grpc/proto" 11 | "google.golang.org/protobuf/types/known/emptypb" 12 | ) 13 | 14 | // ClientAPIClient provides access to the client API. 15 | type ClientAPIClient struct { 16 | Connection proto.ClientAPIClient 17 | } 18 | 19 | // NewClientAPI creates an instance of the client API client. 20 | // The addr parameter should be the network address of the coordinator (host + port). 21 | func NewClientAPI(addr string) (*ClientAPIClient, error) { 22 | conn, err := createConnection(addr) 23 | if err != nil { 24 | return nil, err 25 | } 26 | return &ClientAPIClient{Connection: proto.NewClientAPIClient(conn)}, nil 27 | } 28 | 29 | // Put sends a put request to the client API. 30 | func (client *ClientAPIClient) Put(ctx context.Context, key string, value []byte) (*proto.Response, error) { 31 | return client.Connection.Put(ctx, &proto.Entry{Key: key, Value: value}) 32 | } 33 | 34 | // NodeInfo gets the current height of the node. 35 | func (client *ClientAPIClient) NodeInfo(ctx context.Context) (*proto.Info, error) { 36 | return client.Connection.NodeInfo(ctx, &emptypb.Empty{}) 37 | } 38 | 39 | // Get gets the value by the specified key. 
40 | func (client *ClientAPIClient) Get(ctx context.Context, key string) (*proto.Value, error) { 41 | return client.Connection.Get(ctx, &proto.Msg{Key: key}) 42 | } 43 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/vadiminshakov/committer 2 | 3 | go 1.23.0 4 | 5 | require ( 6 | github.com/dgraph-io/badger/v4 v4.8.0 7 | github.com/pkg/errors v0.9.1 8 | github.com/sirupsen/logrus v1.2.0 9 | github.com/stretchr/testify v1.10.0 10 | github.com/vadiminshakov/gowal v0.0.4 11 | go.uber.org/mock v0.6.0 12 | google.golang.org/grpc v1.50.0 13 | google.golang.org/protobuf v1.36.6 14 | ) 15 | 16 | require ( 17 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 18 | github.com/davecgh/go-spew v1.1.1 // indirect 19 | github.com/dgraph-io/ristretto/v2 v2.2.0 // indirect 20 | github.com/dustin/go-humanize v1.0.1 // indirect 21 | github.com/go-logr/logr v1.4.3 // indirect 22 | github.com/go-logr/stdr v1.2.2 // indirect 23 | github.com/golang/protobuf v1.5.2 // indirect 24 | github.com/google/flatbuffers v25.2.10+incompatible // indirect 25 | github.com/klauspost/compress v1.18.0 // indirect 26 | github.com/konsorten/go-windows-terminal-sequences v1.0.1 // indirect 27 | github.com/pmezard/go-difflib v1.0.0 // indirect 28 | github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect 29 | github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect 30 | go.opentelemetry.io/auto/sdk v1.1.0 // indirect 31 | go.opentelemetry.io/otel v1.37.0 // indirect 32 | go.opentelemetry.io/otel/metric v1.37.0 // indirect 33 | go.opentelemetry.io/otel/trace v1.37.0 // indirect 34 | golang.org/x/crypto v0.39.0 // indirect 35 | golang.org/x/net v0.41.0 // indirect 36 | golang.org/x/sys v0.34.0 // indirect 37 | golang.org/x/term v0.32.0 // indirect 38 | golang.org/x/text v0.26.0 // indirect 39 | google.golang.org/genproto v0.0.0-20221010155953-15ba04fc1c0e // indirect 40 | gopkg.in/yaml.v3 v3.0.1 // indirect 41 | ) 42 | -------------------------------------------------------------------------------- /core/dto/dto.go: -------------------------------------------------------------------------------- 1 | // Package dto defines the data transfer objects used for communication in 2 | // distributed atomic commit protocols between coordinators and cohorts. 3 | // These structures represent the messages exchanged during the two-phase 4 | // and three-phase commit protocols. 5 | package dto 6 | 7 | // ProposeRequest represents a proposal for a new transaction. 8 | type ProposeRequest struct { 9 | Key string // Key to be stored 10 | Value []byte // Value to be stored 11 | Height uint64 // Transaction height/sequence number 12 | } 13 | 14 | // CommitRequest represents a commit or rollback request. 15 | type CommitRequest struct { 16 | Height uint64 // Transaction height/sequence number 17 | IsRollback bool // Whether this is a rollback request 18 | } 19 | 20 | // ResponseType represents the type of response from a cohort. 21 | type ResponseType int32 22 | 23 | const ( 24 | // ResponseTypeAck indicates successful acknowledgment. 25 | ResponseTypeAck ResponseType = iota 26 | // ResponseTypeNack indicates negative acknowledgment (rejection). 27 | ResponseTypeNack 28 | ) 29 | 30 | // CohortResponse represents a response from a cohort node. 
31 | type CohortResponse struct { 32 | ResponseType 33 | Height uint64 // Current height of the cohort 34 | } 35 | 36 | // BroadcastRequest represents a request to be broadcast to all cohorts. 37 | type BroadcastRequest struct { 38 | Key string // Key to be stored 39 | Value []byte // Value to be stored 40 | } 41 | 42 | // BroadcastResponse represents a response to a broadcast request. 43 | type BroadcastResponse struct { 44 | Type ResponseType // Response type (ACK/NACK) 45 | Height uint64 // Height of the committed transaction 46 | } 47 | 48 | // AbortRequest represents a request to abort a transaction. 49 | type AbortRequest struct { 50 | Height uint64 // Transaction height to abort 51 | Reason string // Reason for the abort 52 | } 53 | -------------------------------------------------------------------------------- /core/cohort/commitalgo/fsm.go: -------------------------------------------------------------------------------- 1 | package commitalgo 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | ) 7 | 8 | type mode string 9 | 10 | const ( 11 | twophase = "two-phase" 12 | threephase = "three-phase" 13 | ) 14 | const ( 15 | proposeStage = "propose" 16 | precommitStage = "precommit" 17 | commitStage = "commit" 18 | ) 19 | 20 | type stateMachine struct { 21 | mu sync.RWMutex 22 | currentState string 23 | mode mode 24 | transitions map[string]map[string]struct{} 25 | } 26 | 27 | var twoPhaseTransitions = map[string]map[string]struct{}{ 28 | proposeStage: { 29 | proposeStage: struct{}{}, 30 | commitStage: struct{}{}, 31 | }, 32 | commitStage: { 33 | proposeStage: struct{}{}, 34 | }, 35 | } 36 | 37 | var threePhaseTransitions = map[string]map[string]struct{}{ 38 | proposeStage: { 39 | proposeStage: struct{}{}, 40 | precommitStage: struct{}{}, 41 | }, 42 | precommitStage: { 43 | commitStage: struct{}{}, 44 | }, 45 | commitStage: { 46 | proposeStage: struct{}{}, 47 | }, 48 | } 49 | 50 | func newStateMachine(mode mode) *stateMachine { 51 | tr := twoPhaseTransitions 52 | if mode == threephase { 53 | tr = threePhaseTransitions 54 | } 55 | 56 | return &stateMachine{ 57 | currentState: proposeStage, 58 | mode: mode, 59 | transitions: tr, 60 | } 61 | } 62 | 63 | func (sm *stateMachine) Transition(nextState string) error { 64 | sm.mu.Lock() 65 | defer sm.mu.Unlock() 66 | 67 | if allowedStates, ok := sm.transitions[sm.currentState]; ok { 68 | if _, ok = allowedStates[nextState]; ok { 69 | sm.currentState = nextState 70 | return nil 71 | } 72 | } 73 | 74 | return errors.New("invalid state transition") 75 | } 76 | 77 | func (sm *stateMachine) getCurrentState() string { 78 | sm.mu.RLock() 79 | defer sm.mu.RUnlock() 80 | return sm.currentState 81 | } 82 | 83 | func (sm *stateMachine) GetMode() mode { 84 | sm.mu.RLock() 85 | defer sm.mu.RUnlock() 86 | return sm.mode 87 | } 88 | -------------------------------------------------------------------------------- /io/gateway/grpc/client/internal_client.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | 6 | log "github.com/sirupsen/logrus" 7 | "github.com/vadiminshakov/committer/core/dto" 8 | "github.com/vadiminshakov/committer/io/gateway/grpc/proto" 9 | ) 10 | 11 | // InternalCommitClient provides access to the internal API of the atomic commit protocol 12 | type InternalCommitClient struct { 13 | Connection proto.InternalCommitAPIClient 14 | } 15 | 16 | // NewInternalClient creates an instance of the internal API client. 
17 | // 'addr' - the network address of the node (host + port). 18 | func NewInternalClient(addr string) (*InternalCommitClient, error) { 19 | conn, err := createConnection(addr) 20 | if err != nil { 21 | return nil, err 22 | } 23 | return &InternalCommitClient{Connection: proto.NewInternalCommitAPIClient(conn)}, nil 24 | } 25 | 26 | // Propose sends a propose request to the internal commit API 27 | func (client *InternalCommitClient) Propose(ctx context.Context, req *proto.ProposeRequest) (*proto.Response, error) { 28 | return client.Connection.Propose(ctx, req) 29 | } 30 | 31 | // Precommit sends a precommit request to the internal commit API 32 | func (client *InternalCommitClient) Precommit(ctx context.Context, req *proto.PrecommitRequest) (*proto.Response, error) { 33 | return client.Connection.Precommit(ctx, req) 34 | } 35 | 36 | // Commit sends a commit request to the internal commit API 37 | func (client *InternalCommitClient) Commit(ctx context.Context, req *proto.CommitRequest) (*proto.Response, error) { 38 | return client.Connection.Commit(ctx, req) 39 | } 40 | 41 | // Abort sends an abort request to the internal commit API 42 | func (client *InternalCommitClient) Abort(ctx context.Context, req *dto.AbortRequest) (*proto.Response, error) { 43 | protoReq := &proto.AbortRequest{ 44 | Height: req.Height, 45 | Reason: req.Reason, 46 | } 47 | 48 | log.Infof("Sending abort request for height %d with reason: %s", req.Height, req.Reason) 49 | return client.Connection.Abort(ctx, protoReq) 50 | } 51 | -------------------------------------------------------------------------------- /mocks/mock_state_store.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: github.com/vadiminshakov/committer/core/coordinator (interfaces: StateStore) 3 | // 4 | // Generated by this command: 5 | // 6 | // mockgen -destination=../../mocks/mock_state_store.go -package=mocks . StateStore 7 | // 8 | 9 | // Package mocks is a generated GoMock package. 10 | package mocks 11 | 12 | import ( 13 | reflect "reflect" 14 | 15 | gomock "go.uber.org/mock/gomock" 16 | ) 17 | 18 | // MockStateStore is a mock of StateStore interface. 19 | type MockStateStore struct { 20 | ctrl *gomock.Controller 21 | recorder *MockStateStoreMockRecorder 22 | isgomock struct{} 23 | } 24 | 25 | // MockStateStoreMockRecorder is the mock recorder for MockStateStore. 26 | type MockStateStoreMockRecorder struct { 27 | mock *MockStateStore 28 | } 29 | 30 | // NewMockStateStore creates a new mock instance. 31 | func NewMockStateStore(ctrl *gomock.Controller) *MockStateStore { 32 | mock := &MockStateStore{ctrl: ctrl} 33 | mock.recorder = &MockStateStoreMockRecorder{mock} 34 | return mock 35 | } 36 | 37 | // EXPECT returns an object that allows the caller to indicate expected use. 38 | func (m *MockStateStore) EXPECT() *MockStateStoreMockRecorder { 39 | return m.recorder 40 | } 41 | 42 | // Close mocks base method. 43 | func (m *MockStateStore) Close() error { 44 | m.ctrl.T.Helper() 45 | ret := m.ctrl.Call(m, "Close") 46 | ret0, _ := ret[0].(error) 47 | return ret0 48 | } 49 | 50 | // Close indicates an expected call of Close. 51 | func (mr *MockStateStoreMockRecorder) Close() *gomock.Call { 52 | mr.mock.ctrl.T.Helper() 53 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStateStore)(nil).Close)) 54 | } 55 | 56 | // Put mocks base method. 
57 | func (m *MockStateStore) Put(key string, value []byte) error { 58 | m.ctrl.T.Helper() 59 | ret := m.ctrl.Call(m, "Put", key, value) 60 | ret0, _ := ret[0].(error) 61 | return ret0 62 | } 63 | 64 | // Put indicates an expected call of Put. 65 | func (mr *MockStateStoreMockRecorder) Put(key, value any) *gomock.Call { 66 | mr.mock.ctrl.T.Helper() 67 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockStateStore)(nil).Put), key, value) 68 | } 69 | -------------------------------------------------------------------------------- /mocks/mock_commitalgo_state_store.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: github.com/vadiminshakov/committer/core/cohort/commitalgo (interfaces: StateStore) 3 | // 4 | // Generated by this command: 5 | // 6 | // mockgen -destination=../../../mocks/mock_commitalgo_state_store.go -package=mocks -mock_names=StateStore=MockCommitalgoStateStore . StateStore 7 | // 8 | 9 | // Package mocks is a generated GoMock package. 10 | package mocks 11 | 12 | import ( 13 | reflect "reflect" 14 | 15 | gomock "go.uber.org/mock/gomock" 16 | ) 17 | 18 | // MockCommitalgoStateStore is a mock of StateStore interface. 19 | type MockCommitalgoStateStore struct { 20 | ctrl *gomock.Controller 21 | recorder *MockCommitalgoStateStoreMockRecorder 22 | isgomock struct{} 23 | } 24 | 25 | // MockCommitalgoStateStoreMockRecorder is the mock recorder for MockCommitalgoStateStore. 26 | type MockCommitalgoStateStoreMockRecorder struct { 27 | mock *MockCommitalgoStateStore 28 | } 29 | 30 | // NewMockCommitalgoStateStore creates a new mock instance. 31 | func NewMockCommitalgoStateStore(ctrl *gomock.Controller) *MockCommitalgoStateStore { 32 | mock := &MockCommitalgoStateStore{ctrl: ctrl} 33 | mock.recorder = &MockCommitalgoStateStoreMockRecorder{mock} 34 | return mock 35 | } 36 | 37 | // EXPECT returns an object that allows the caller to indicate expected use. 38 | func (m *MockCommitalgoStateStore) EXPECT() *MockCommitalgoStateStoreMockRecorder { 39 | return m.recorder 40 | } 41 | 42 | // Close mocks base method. 43 | func (m *MockCommitalgoStateStore) Close() error { 44 | m.ctrl.T.Helper() 45 | ret := m.ctrl.Call(m, "Close") 46 | ret0, _ := ret[0].(error) 47 | return ret0 48 | } 49 | 50 | // Close indicates an expected call of Close. 51 | func (mr *MockCommitalgoStateStoreMockRecorder) Close() *gomock.Call { 52 | mr.mock.ctrl.T.Helper() 53 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCommitalgoStateStore)(nil).Close)) 54 | } 55 | 56 | // Put mocks base method. 57 | func (m *MockCommitalgoStateStore) Put(key string, value []byte) error { 58 | m.ctrl.T.Helper() 59 | ret := m.ctrl.Call(m, "Put", key, value) 60 | ret0, _ := ret[0].(error) 61 | return ret0 62 | } 63 | 64 | // Put indicates an expected call of Put. 
65 | func (mr *MockCommitalgoStateStoreMockRecorder) Put(key, value any) *gomock.Call { 66 | mr.mock.ctrl.T.Helper() 67 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockCommitalgoStateStore)(nil).Put), key, value) 68 | } 69 | -------------------------------------------------------------------------------- /core/walrecord/walrecord.go: -------------------------------------------------------------------------------- 1 | package walrecord 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | ) 7 | 8 | const ( 9 | // system keys for protocol phases 10 | KeyPrepared = "__tx:prepared" 11 | KeyPrecommit = "__tx:precommit" 12 | KeyCommit = "__tx:commit" 13 | KeyAbort = "__tx:abort" 14 | 15 | // stride is the multiplier for height to determine the starting WAL index 16 | Stride = 4 17 | ) 18 | 19 | // WalTx represents the payload stored in the WAL. 20 | type WalTx struct { 21 | Key string 22 | Value []byte 23 | } 24 | 25 | // PreparedSlot returns the WAL index for the PREPARED phase at a given height. 26 | func PreparedSlot(height uint64) uint64 { 27 | return height*Stride + 0 28 | } 29 | 30 | // PrecommitSlot returns the WAL index for the PRECOMMIT phase at a given height. 31 | func PrecommitSlot(height uint64) uint64 { 32 | return height*Stride + 1 33 | } 34 | 35 | // CommitSlot returns the WAL index for the COMMIT phase at a given height. 36 | func CommitSlot(height uint64) uint64 { 37 | return height*Stride + 2 38 | } 39 | 40 | // AbortSlot returns the WAL index for the ABORT phase at a given height. 41 | func AbortSlot(height uint64) uint64 { 42 | return height*Stride + 3 43 | } 44 | 45 | // Encode serializes a WalTx into bytes. 46 | // Format: [KeyLen(4 bytes)] [KeyBytes] [ValueLen(4 bytes)] [ValueBytes] 47 | func Encode(tx WalTx) ([]byte, error) { 48 | keyLen := uint32(len(tx.Key)) 49 | valLen := uint32(len(tx.Value)) 50 | 51 | // 4 bytes for key len + key bytes + 4 bytes for val len + val bytes 52 | buf := make([]byte, 4+keyLen+4+valLen) 53 | 54 | binary.BigEndian.PutUint32(buf[0:4], keyLen) 55 | copy(buf[4:4+keyLen], tx.Key) 56 | 57 | binary.BigEndian.PutUint32(buf[4+keyLen:4+keyLen+4], valLen) 58 | copy(buf[4+keyLen+4:], tx.Value) 59 | 60 | return buf, nil 61 | } 62 | 63 | // Decode deserializes bytes into a WalTx. 64 | func Decode(data []byte) (WalTx, error) { 65 | if len(data) < 4 { 66 | return WalTx{}, fmt.Errorf("data too short for key length") 67 | } 68 | 69 | keyLen := binary.BigEndian.Uint32(data[0:4]) 70 | if uint32(len(data)) < 4+keyLen+4 { 71 | return WalTx{}, fmt.Errorf("data too short for key and value length") 72 | } 73 | 74 | key := string(data[4 : 4+keyLen]) 75 | 76 | valLen := binary.BigEndian.Uint32(data[4+keyLen : 4+keyLen+4]) 77 | if uint32(len(data)) < 4+keyLen+4+valLen { 78 | return WalTx{}, fmt.Errorf("data too short for value body") 79 | } 80 | 81 | value := make([]byte, valLen) 82 | copy(value, data[4+keyLen+4:4+keyLen+4+valLen]) 83 | 84 | return WalTx{ 85 | Key: key, 86 | Value: value, 87 | }, nil 88 | } 89 | -------------------------------------------------------------------------------- /core/cohort/cohort.go: -------------------------------------------------------------------------------- 1 | // Package cohort implements the cohort role in distributed atomic commit protocols. 2 | // A cohort is a participant in the atomic commit protocol that receives 3 | // proposals from the coordinator and responds with either commit or abort. 
4 | // 5 | // Cohorts participate in 2PC and 3PC transactions by responding to coordinator 6 | // requests and maintaining local transaction state. 7 | package cohort 8 | 9 | import ( 10 | "context" 11 | "errors" 12 | 13 | "github.com/vadiminshakov/committer/core/cohort/commitalgo/hooks" 14 | "github.com/vadiminshakov/committer/core/dto" 15 | ) 16 | 17 | // Mode represents the commit protocol mode. 18 | type Mode string 19 | 20 | // THREE_PHASE represents the three-phase commit protocol mode. 21 | const THREE_PHASE Mode = "three-phase" 22 | 23 | // Committer defines the interface for commit algorithms. 24 | // 25 | //go:generate mockgen -destination=../../../mocks/mock_committer.go -package=mocks . Committer 26 | type Committer interface { 27 | Height() uint64 28 | Propose(ctx context.Context, req *dto.ProposeRequest) (*dto.CohortResponse, error) 29 | Precommit(ctx context.Context, index uint64) (*dto.CohortResponse, error) 30 | Commit(ctx context.Context, req *dto.CommitRequest) (*dto.CohortResponse, error) 31 | Abort(ctx context.Context, req *dto.AbortRequest) (*dto.CohortResponse, error) 32 | RegisterHook(hook hooks.Hook) 33 | } 34 | 35 | // CohortImpl implements the cohort node functionality. 36 | type CohortImpl struct { 37 | committer Committer 38 | commitType Mode 39 | } 40 | 41 | // NewCohort creates a new cohort instance. 42 | func NewCohort( 43 | committer Committer, 44 | commitType Mode) *CohortImpl { 45 | return &CohortImpl{ 46 | committer: committer, 47 | commitType: commitType, 48 | } 49 | } 50 | 51 | func (c *CohortImpl) Height() uint64 { 52 | return c.committer.Height() 53 | } 54 | 55 | func (c *CohortImpl) Propose(ctx context.Context, req *dto.ProposeRequest) (*dto.CohortResponse, error) { 56 | return c.committer.Propose(ctx, req) 57 | } 58 | 59 | func (s *CohortImpl) Precommit(ctx context.Context, index uint64) (*dto.CohortResponse, error) { 60 | if s.commitType != THREE_PHASE { 61 | return nil, errors.New("precommit is allowed for 3PC mode only") 62 | } 63 | 64 | return s.committer.Precommit(ctx, index) 65 | } 66 | 67 | func (c *CohortImpl) Commit(ctx context.Context, in *dto.CommitRequest) (resp *dto.CohortResponse, err error) { 68 | return c.committer.Commit(ctx, in) 69 | } 70 | 71 | func (c *CohortImpl) Abort(ctx context.Context, req *dto.AbortRequest) (*dto.CohortResponse, error) { 72 | return c.committer.Abort(ctx, req) 73 | } 74 | -------------------------------------------------------------------------------- /mocks/mock_wal.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: github.com/vadiminshakov/committer/core/coordinator (interfaces: wal) 3 | // 4 | // Generated by this command: 5 | // 6 | // mockgen -destination=../../mocks/mock_wal.go -package=mocks . wal 7 | // 8 | 9 | // Package mocks is a generated GoMock package. 10 | package mocks 11 | 12 | import ( 13 | reflect "reflect" 14 | 15 | gomock "go.uber.org/mock/gomock" 16 | ) 17 | 18 | // Mockwal is a mock of wal interface. 19 | type Mockwal struct { 20 | ctrl *gomock.Controller 21 | recorder *MockwalMockRecorder 22 | isgomock struct{} 23 | } 24 | 25 | // MockwalMockRecorder is the mock recorder for Mockwal. 26 | type MockwalMockRecorder struct { 27 | mock *Mockwal 28 | } 29 | 30 | // NewMockwal creates a new mock instance. 
31 | func NewMockwal(ctrl *gomock.Controller) *Mockwal { 32 | mock := &Mockwal{ctrl: ctrl} 33 | mock.recorder = &MockwalMockRecorder{mock} 34 | return mock 35 | } 36 | 37 | // EXPECT returns an object that allows the caller to indicate expected use. 38 | func (m *Mockwal) EXPECT() *MockwalMockRecorder { 39 | return m.recorder 40 | } 41 | 42 | // Close mocks base method. 43 | func (m *Mockwal) Close() error { 44 | m.ctrl.T.Helper() 45 | ret := m.ctrl.Call(m, "Close") 46 | ret0, _ := ret[0].(error) 47 | return ret0 48 | } 49 | 50 | // Close indicates an expected call of Close. 51 | func (mr *MockwalMockRecorder) Close() *gomock.Call { 52 | mr.mock.ctrl.T.Helper() 53 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*Mockwal)(nil).Close)) 54 | } 55 | 56 | // Get mocks base method. 57 | func (m *Mockwal) Get(index uint64) (string, []byte, error) { 58 | m.ctrl.T.Helper() 59 | ret := m.ctrl.Call(m, "Get", index) 60 | ret0, _ := ret[0].(string) 61 | ret1, _ := ret[1].([]byte) 62 | ret2, _ := ret[2].(error) 63 | return ret0, ret1, ret2 64 | } 65 | 66 | // Get indicates an expected call of Get. 67 | func (mr *MockwalMockRecorder) Get(index any) *gomock.Call { 68 | mr.mock.ctrl.T.Helper() 69 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*Mockwal)(nil).Get), index) 70 | } 71 | 72 | // Write mocks base method. 73 | func (m *Mockwal) Write(index uint64, key string, value []byte) error { 74 | m.ctrl.T.Helper() 75 | ret := m.ctrl.Call(m, "Write", index, key, value) 76 | ret0, _ := ret[0].(error) 77 | return ret0 78 | } 79 | 80 | // Write indicates an expected call of Write. 81 | func (mr *MockwalMockRecorder) Write(index, key, value any) *gomock.Call { 82 | mr.mock.ctrl.T.Helper() 83 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*Mockwal)(nil).Write), index, key, value) 84 | } 85 | 86 | // WriteTombstone mocks base method. 87 | func (m *Mockwal) WriteTombstone(index uint64) error { 88 | m.ctrl.T.Helper() 89 | ret := m.ctrl.Call(m, "WriteTombstone", index) 90 | ret0, _ := ret[0].(error) 91 | return ret0 92 | } 93 | 94 | // WriteTombstone indicates an expected call of WriteTombstone. 
95 | func (mr *MockwalMockRecorder) WriteTombstone(index any) *gomock.Call { 96 | mr.mock.ctrl.T.Helper() 97 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTombstone", reflect.TypeOf((*Mockwal)(nil).WriteTombstone), index) 98 | } 99 | -------------------------------------------------------------------------------- /core/cohort/commitalgo/hooks/registry_test.go: -------------------------------------------------------------------------------- 1 | package hooks 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | "github.com/vadiminshakov/committer/core/dto" 8 | ) 9 | 10 | // TestHook is a simple test hook implementation 11 | type TestHook struct { 12 | proposeResult bool 13 | commitResult bool 14 | proposeCalled bool 15 | commitCalled bool 16 | } 17 | 18 | func (t *TestHook) OnPropose(req *dto.ProposeRequest) bool { 19 | t.proposeCalled = true 20 | return t.proposeResult 21 | } 22 | 23 | func (t *TestHook) OnCommit(req *dto.CommitRequest) bool { 24 | t.commitCalled = true 25 | return t.commitResult 26 | } 27 | 28 | func TestRegistry_Register(t *testing.T) { 29 | registry := NewRegistry() 30 | hook := &TestHook{} 31 | 32 | require.Equal(t, 0, registry.Count(), "Expected 0 hooks initially") 33 | 34 | registry.Register(hook) 35 | 36 | require.Equal(t, 1, registry.Count(), "Expected 1 hook after registration") 37 | } 38 | 39 | func TestRegistry_ExecutePropose(t *testing.T) { 40 | registry := NewRegistry() 41 | 42 | hook1 := &TestHook{proposeResult: true} 43 | registry.Register(hook1) 44 | 45 | req := &dto.ProposeRequest{Height: 1, Key: "test", Value: []byte("value")} 46 | result := registry.ExecutePropose(req) 47 | 48 | require.True(t, result, "Expected propose to succeed") 49 | require.True(t, hook1.proposeCalled, "Expected hook to be called") 50 | 51 | hook2 := &TestHook{proposeResult: false} 52 | registry.Register(hook2) 53 | 54 | result = registry.ExecutePropose(req) 55 | 56 | require.False(t, result, "Expected propose to fail when hook returns false") 57 | } 58 | 59 | func TestRegistry_ExecuteCommit(t *testing.T) { 60 | registry := NewRegistry() 61 | 62 | hook1 := &TestHook{commitResult: true} 63 | registry.Register(hook1) 64 | 65 | req := &dto.CommitRequest{Height: 1} 66 | result := registry.ExecuteCommit(req) 67 | 68 | require.True(t, result, "Expected commit to succeed") 69 | require.True(t, hook1.commitCalled, "Expected hook to be called") 70 | 71 | hook2 := &TestHook{commitResult: false} 72 | registry.Register(hook2) 73 | 74 | result = registry.ExecuteCommit(req) 75 | 76 | require.False(t, result, "Expected commit to fail when hook returns false") 77 | } 78 | 79 | func TestRegistry_MultipleHooks(t *testing.T) { 80 | registry := NewRegistry() 81 | 82 | hook1 := &TestHook{proposeResult: true, commitResult: true} 83 | hook2 := &TestHook{proposeResult: true, commitResult: true} 84 | hook3 := &TestHook{proposeResult: false, commitResult: true} // this one fails 85 | 86 | registry.Register(hook1) 87 | registry.Register(hook2) 88 | registry.Register(hook3) 89 | 90 | proposeReq := &dto.ProposeRequest{Height: 1, Key: "test", Value: []byte("value")} 91 | commitReq := &dto.CommitRequest{Height: 1} 92 | 93 | // propose should fail because hook3 returns false 94 | require.False(t, registry.ExecutePropose(proposeReq), "Expected propose to fail") 95 | 96 | // commit should succeed because all hooks return true for commit 97 | require.True(t, registry.ExecuteCommit(commitReq), "Expected commit to succeed") 98 | 99 | // check that all hooks were called 100 | 
require.True(t, hook1.proposeCalled, "Expected hook1 to be called for propose") 101 | require.True(t, hook2.proposeCalled, "Expected hook2 to be called for propose") 102 | require.True(t, hook3.proposeCalled, "Expected hook3 to be called for propose") 103 | 104 | require.True(t, hook1.commitCalled, "Expected hook1 to be called for commit") 105 | require.True(t, hook2.commitCalled, "Expected hook2 to be called for commit") 106 | require.True(t, hook3.commitCalled, "Expected hook3 to be called for commit") 107 | } 108 | -------------------------------------------------------------------------------- /config/config.go: -------------------------------------------------------------------------------- 1 | // Package config provides configuration management for the committer application. 2 | // 3 | // This package handles command-line flag parsing and configuration validation 4 | // for both coordinator and cohort nodes in the distributed consensus system. 5 | package config 6 | 7 | import ( 8 | "flag" 9 | "log" 10 | "strings" 11 | ) 12 | 13 | const ( 14 | // DefaultWalDir is the default directory for write-ahead logs. 15 | DefaultWalDir string = "wal" 16 | // DefaultWalSegmentPrefix is the default prefix for WAL segment files. 17 | DefaultWalSegmentPrefix string = "msgs_" 18 | // DefaultWalSegmentThreshold is the default number of entries per WAL segment. 19 | DefaultWalSegmentThreshold int = 10000 20 | // DefaultWalMaxSegments is the default maximum number of WAL segments to retain. 21 | DefaultWalMaxSegments int = 100 22 | // DefaultWalIsInSyncDiskMode enables synchronous disk writes for WAL by default. 23 | DefaultWalIsInSyncDiskMode bool = true 24 | ) 25 | 26 | // Config holds the configuration settings for the committer application. 27 | type Config struct { 28 | Role string // Node role: "coordinator" or "cohort" 29 | Nodeaddr string // Address of this node 30 | Coordinator string // Address of the coordinator (for cohorts) 31 | CommitType string // Commit protocol: "two-phase" or "three-phase" 32 | DBPath string // Path to the database directory 33 | Cohorts []string // List of cohort addresses (for coordinators) 34 | Whitelist []string // Whitelist of allowed node addresses 35 | Timeout uint64 // Timeout in milliseconds for 3PC operations 36 | } 37 | 38 | // Get creates the configuration from command-line flags. 
39 | func Get() *Config { 40 | // command-line flags 41 | role := flag.String("role", "cohort", "role (coordinator or cohort)") 42 | nodeaddr := flag.String("nodeaddr", "localhost:3050", "node address") 43 | coordinator := flag.String("coordinator", "", "coordinator address") 44 | committype := flag.String("committype", "two-phase", "two-phase or three-phase commit mode") 45 | timeout := flag.Uint64("timeout", 1000, "ms, timeout after which the message is considered unacknowledged (only for three-phase mode, because two-phase is blocking by design)") 46 | dbpath := flag.String("dbpath", "./badger", "database path on filesystem") 47 | cohorts := flag.String("cohorts", "", "cohort addresses") 48 | whitelist := flag.String("whitelist", "127.0.0.1", "allowed hosts") 49 | flag.Parse() 50 | 51 | cohortsArray := filterEmpty(strings.Split(*cohorts, ",")) 52 | if *role != "coordinator" { 53 | if !includes(cohortsArray, *nodeaddr) { 54 | cohortsArray = append(cohortsArray, *nodeaddr) 55 | } 56 | } 57 | whitelistArray := filterEmpty(strings.Split(*whitelist, ",")) 58 | 59 | if *role == "coordinator" && len(cohortsArray) == 0 { 60 | log.Fatalf("coordinator role requires at least one cohort address") 61 | } 62 | 63 | return &Config{ 64 | Role: *role, 65 | Nodeaddr: *nodeaddr, 66 | Coordinator: *coordinator, 67 | CommitType: *committype, 68 | DBPath: *dbpath, 69 | Cohorts: cohortsArray, 70 | Whitelist: whitelistArray, 71 | Timeout: *timeout, 72 | } 73 | 74 | } 75 | 76 | // includes checks that the 'arr' includes 'value' 77 | func includes(arr []string, value string) bool { 78 | for i := range arr { 79 | if arr[i] == value { 80 | return true 81 | } 82 | } 83 | return false 84 | } 85 | 86 | // filterEmpty trims and removes empty entries from a slice of strings 87 | func filterEmpty(values []string) []string { 88 | result := make([]string, 0, len(values)) 89 | for _, v := range values { 90 | if trimmed := strings.TrimSpace(v); trimmed != "" { 91 | result = append(result, trimmed) 92 | } 93 | } 94 | return result 95 | } 96 | -------------------------------------------------------------------------------- /core/cohort/commitalgo/hooks/examples.go: -------------------------------------------------------------------------------- 1 | package hooks 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | log "github.com/sirupsen/logrus" 8 | "github.com/vadiminshakov/committer/core/dto" 9 | ) 10 | 11 | // MetricsHook collects metrics about propose and commit operations. 12 | type MetricsHook struct { 13 | proposeCount uint64 14 | commitCount uint64 15 | startTime time.Time 16 | } 17 | 18 | // NewMetricsHook creates a new metrics hook. 19 | func NewMetricsHook() *MetricsHook { 20 | return &MetricsHook{ 21 | startTime: time.Now(), 22 | } 23 | } 24 | 25 | // OnPropose increments propose counter and logs metrics. 
26 | func (m *MetricsHook) OnPropose(req *dto.ProposeRequest) bool { 27 | m.proposeCount++ 28 | log.WithFields(log.Fields{ 29 | "height": req.Height, 30 | "propose_count": m.proposeCount, 31 | "uptime": time.Since(m.startTime), 32 | }).Info("Metrics: propose operation") 33 | return true 34 | } 35 | 36 | // OnCommit increments commit counter and logs metrics 37 | func (m *MetricsHook) OnCommit(req *dto.CommitRequest) bool { 38 | m.commitCount++ 39 | log.WithFields(log.Fields{ 40 | "height": req.Height, 41 | "commit_count": m.commitCount, 42 | "uptime": time.Since(m.startTime), 43 | }).Info("Metrics: commit operation") 44 | return true 45 | } 46 | 47 | // GetStats returns current statistics 48 | func (m *MetricsHook) GetStats() (uint64, uint64, time.Duration) { 49 | return m.proposeCount, m.commitCount, time.Since(m.startTime) 50 | } 51 | 52 | // ValidationHook validates requests before processing 53 | type ValidationHook struct { 54 | maxKeyLength int 55 | maxValueLength int 56 | } 57 | 58 | // NewValidationHook creates a new validation hook 59 | func NewValidationHook(maxKeyLength, maxValueLength int) *ValidationHook { 60 | return &ValidationHook{ 61 | maxKeyLength: maxKeyLength, 62 | maxValueLength: maxValueLength, 63 | } 64 | } 65 | 66 | // OnPropose validates the propose request 67 | func (v *ValidationHook) OnPropose(req *dto.ProposeRequest) bool { 68 | if len(req.Key) > v.maxKeyLength { 69 | log.Errorf("Key too long: %d > %d", len(req.Key), v.maxKeyLength) 70 | return false 71 | } 72 | 73 | if len(req.Value) > v.maxValueLength { 74 | log.Errorf("Value too long: %d > %d", len(req.Value), v.maxValueLength) 75 | return false 76 | } 77 | 78 | log.Debugf("Validation passed for key: %s", req.Key) 79 | return true 80 | } 81 | 82 | // OnCommit validates the commit request 83 | func (v *ValidationHook) OnCommit(req *dto.CommitRequest) bool { 84 | if req.Height == 0 { 85 | log.Error("Invalid height: cannot be zero") 86 | return false 87 | } 88 | 89 | log.Debugf("Commit validation passed for height: %d", req.Height) 90 | return true 91 | } 92 | 93 | // AuditHook logs all operations for audit purposes 94 | type AuditHook struct { 95 | logFile string 96 | } 97 | 98 | // NewAuditHook creates a new audit hook 99 | func NewAuditHook(logFile string) *AuditHook { 100 | return &AuditHook{ 101 | logFile: logFile, 102 | } 103 | } 104 | 105 | // OnPropose logs propose operations 106 | func (a *AuditHook) OnPropose(req *dto.ProposeRequest) bool { 107 | auditMsg := fmt.Sprintf("[AUDIT] PROPOSE - Height: %d, Key: %s, Value: %s, Time: %s", 108 | req.Height, req.Key, string(req.Value), time.Now().Format(time.RFC3339)) 109 | 110 | log.WithField("audit", true).Info(auditMsg) 111 | // Here you could also write to a file if needed 112 | return true 113 | } 114 | 115 | // OnCommit logs commit operations 116 | func (a *AuditHook) OnCommit(req *dto.CommitRequest) bool { 117 | auditMsg := fmt.Sprintf("[AUDIT] COMMIT - Height: %d, Time: %s", 118 | req.Height, time.Now().Format(time.RFC3339)) 119 | 120 | log.WithField("audit", true).Info(auditMsg) 121 | // Here you could also write to a file if needed 122 | return true 123 | } 124 | -------------------------------------------------------------------------------- /io/store/store_test.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | 
"github.com/vadiminshakov/committer/core/walrecord" 11 | "github.com/vadiminshakov/gowal" 12 | ) 13 | 14 | func TestStore_Recovery_Commit(t *testing.T) { 15 | walDir := filepath.Join(os.TempDir(), "wal_commit") 16 | dbDir := filepath.Join(os.TempDir(), "db_commit") 17 | defer os.RemoveAll(walDir) 18 | defer os.RemoveAll(dbDir) 19 | 20 | // 1. setup WAL 21 | w, err := gowal.NewWAL(gowal.Config{ 22 | Dir: walDir, 23 | Prefix: "wal_", 24 | SegmentThreshold: 1024 * 1024, 25 | MaxSegments: 10, 26 | }) 27 | require.NoError(t, err) 28 | 29 | // 2. write prepared (should require commit to apply) 30 | height := uint64(10) 31 | prepIdx := walrecord.PreparedSlot(height) 32 | tx := walrecord.WalTx{Key: "key1", Value: []byte("value1")} 33 | encoded, _ := walrecord.Encode(tx) 34 | 35 | err = w.Write(prepIdx, walrecord.KeyPrepared, encoded) 36 | require.NoError(t, err) 37 | 38 | // 3. write commit 39 | commitIdx := walrecord.CommitSlot(height) 40 | err = w.Write(commitIdx, walrecord.KeyCommit, encoded) 41 | require.NoError(t, err) 42 | 43 | w.Close() 44 | 45 | w2, err := gowal.NewWAL(gowal.Config{Dir: walDir, Prefix: "wal_", SegmentThreshold: 1024 * 1024, MaxSegments: 10}) 46 | require.NoError(t, err) 47 | defer w2.Close() 48 | 49 | // 4. recover 50 | s, state, err := New(w2, dbDir) 51 | require.NoError(t, err) 52 | defer s.Close() 53 | 54 | // 5. verify 55 | assert.Equal(t, height+1, state.Height) 56 | 57 | val, err := s.Get("key1") 58 | assert.NoError(t, err) 59 | assert.Equal(t, []byte("value1"), val) 60 | } 61 | 62 | func TestStore_Recovery_PreparedOnly(t *testing.T) { 63 | walDir := filepath.Join(os.TempDir(), "wal_prepared") 64 | dbDir := filepath.Join(os.TempDir(), "db_prepared") 65 | defer os.RemoveAll(walDir) 66 | defer os.RemoveAll(dbDir) 67 | 68 | w, err := gowal.NewWAL(gowal.Config{ 69 | Dir: walDir, 70 | Prefix: "wal_", 71 | SegmentThreshold: 1024 * 1024, 72 | MaxSegments: 10, 73 | }) 74 | require.NoError(t, err) 75 | 76 | height := uint64(15) 77 | prepIdx := walrecord.PreparedSlot(height) 78 | tx := walrecord.WalTx{Key: "key2", Value: []byte("value2")} 79 | encoded, _ := walrecord.Encode(tx) 80 | 81 | err = w.Write(prepIdx, walrecord.KeyPrepared, encoded) 82 | require.NoError(t, err) 83 | w.Close() 84 | 85 | w2, err := gowal.NewWAL(gowal.Config{Dir: walDir, Prefix: "wal_", SegmentThreshold: 1024 * 1024, MaxSegments: 10}) 86 | require.NoError(t, err) 87 | defer w2.Close() 88 | 89 | s, state, err := New(w2, dbDir) 90 | require.NoError(t, err) 91 | defer s.Close() 92 | 93 | // should NOT be applied to DB 94 | _, err = s.Get("key2") 95 | assert.Equal(t, ErrNotFound, err) 96 | 97 | // height should resume at 15 (incomplete) 98 | assert.Equal(t, height, state.Height) 99 | } 100 | 101 | func TestStore_Recovery_Abort(t *testing.T) { 102 | walDir := filepath.Join(os.TempDir(), "wal_abort") 103 | dbDir := filepath.Join(os.TempDir(), "db_abort") 104 | defer os.RemoveAll(walDir) 105 | defer os.RemoveAll(dbDir) 106 | 107 | w, err := gowal.NewWAL(gowal.Config{ 108 | Dir: walDir, 109 | Prefix: "wal_", 110 | SegmentThreshold: 1024 * 1024, 111 | MaxSegments: 10, 112 | }) 113 | require.NoError(t, err) 114 | 115 | height := uint64(20) 116 | 117 | // write abort 118 | abortIdx := walrecord.AbortSlot(height) 119 | err = w.Write(abortIdx, walrecord.KeyAbort, nil) 120 | require.NoError(t, err) 121 | w.Close() 122 | 123 | w2, err := gowal.NewWAL(gowal.Config{Dir: walDir, Prefix: "wal_", SegmentThreshold: 1024 * 1024, MaxSegments: 10}) 124 | require.NoError(t, err) 125 | defer w2.Close() 126 | 127 | s, state, err 
:= New(w2, dbDir) 128 | require.NoError(t, err) 129 | defer s.Close() 130 | 131 | // height should be 20+1 because it was resolved (Aborted) 132 | assert.Equal(t, height+1, state.Height) 133 | } 134 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | // Package main provides a distributed consensus system implementing Two-Phase Commit (2PC) 2 | // and Three-Phase Commit (3PC) protocols for distributed transactions. 3 | // 4 | // Committer is a Go implementation of distributed atomic commit protocols that lets you 5 | // achieve data consistency in distributed systems. 6 | // 7 | // The system consists of coordinators that manage transactions and cohorts that 8 | // participate in the consensus process. 9 | // 10 | // Usage: 11 | // 12 | // # Start coordinator 13 | // ./committer -role=coordinator -nodeaddr=localhost:3000 -cohorts=localhost:3001,localhost:3002 14 | // 15 | // # Start cohort 16 | // ./committer -role=cohort -coordinator=localhost:3000 -nodeaddr=localhost:3001 17 | package main 18 | 19 | import ( 20 | "fmt" 21 | "log" 22 | "os" 23 | "os/signal" 24 | "syscall" 25 | 26 | "github.com/vadiminshakov/committer/config" 27 | "github.com/vadiminshakov/committer/core/cohort" 28 | "github.com/vadiminshakov/committer/core/cohort/commitalgo" 29 | "github.com/vadiminshakov/committer/core/coordinator" 30 | "github.com/vadiminshakov/committer/io/gateway/grpc/server" 31 | "github.com/vadiminshakov/committer/io/store" 32 | "github.com/vadiminshakov/gowal" 33 | ) 34 | 35 | func main() { 36 | if err := run(); err != nil { 37 | log.Fatalf("committer failed: %v", err) 38 | } 39 | } 40 | 41 | func run() error { 42 | sigs := make(chan os.Signal, 1) 43 | signal.Notify(sigs, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) 44 | conf := config.Get() 45 | 46 | wal, err := newWAL() 47 | if err != nil { 48 | return err 49 | } 50 | defer wal.Close() 51 | 52 | stateStore, recovery, err := newStore(conf, wal) 53 | if err != nil { 54 | return err 55 | } 56 | 57 | roles, err := buildRoles(conf, stateStore, wal, recovery.Height) 58 | if err != nil { 59 | return err 60 | } 61 | 62 | srv, err := server.New(conf, roles.cohort, roles.coordinator, stateStore) 63 | if err != nil { 64 | return fmt.Errorf("failed to create server: %w", err) 65 | } 66 | 67 | srv.Run(server.WhiteListChecker) 68 | <-sigs 69 | srv.Stop() 70 | 71 | return nil 72 | } 73 | 74 | func newWAL() (*gowal.Wal, error) { 75 | walConfig := gowal.Config{ 76 | Dir: config.DefaultWalDir, 77 | Prefix: config.DefaultWalSegmentPrefix, 78 | SegmentThreshold: config.DefaultWalSegmentThreshold, 79 | MaxSegments: config.DefaultWalMaxSegments, 80 | IsInSyncDiskMode: config.DefaultWalIsInSyncDiskMode, 81 | } 82 | 83 | w, err := gowal.NewWAL(walConfig) 84 | if err != nil { 85 | return nil, fmt.Errorf("failed to create WAL: %w", err) 86 | } 87 | 88 | return w, nil 89 | } 90 | 91 | func newStore(conf *config.Config, wal *gowal.Wal) (*store.Store, *store.RecoveryState, error) { 92 | stateStore, recovery, err := store.New(wal, conf.DBPath) 93 | if err != nil { 94 | return nil, nil, fmt.Errorf("failed to initialize state store: %w", err) 95 | } 96 | 97 | log.Printf("Recovered state from WAL: next height %d, keys %d\n", recovery.Height, stateStore.Size()) 98 | return stateStore, recovery, nil 99 | } 100 | 101 | type roleComponents struct { 
102 | cohort server.Cohort 103 | coordinator server.Coordinator 104 | } 105 | 106 | func buildRoles(conf *config.Config, stateStore *store.Store, wal *gowal.Wal, initialHeight uint64) (*roleComponents, error) { 107 | rc := &roleComponents{} 108 | switch conf.Role { 109 | case "cohort": 110 | committer := commitalgo.NewCommitter(stateStore, conf.CommitType, wal, conf.Timeout) 111 | committer.SetHeight(initialHeight) 112 | rc.cohort = cohort.NewCohort(committer, cohort.Mode(conf.CommitType)) 113 | case "coordinator": 114 | coord, err := coordinator.New(conf, wal, stateStore) 115 | if err != nil { 116 | return nil, fmt.Errorf("failed to create coordinator: %w", err) 117 | } 118 | coord.SetHeight(initialHeight) 119 | rc.coordinator = coord 120 | default: 121 | return nil, fmt.Errorf("unsupported role %q, expected coordinator or cohort", conf.Role) 122 | } 123 | 124 | return rc, nil 125 | } 126 | -------------------------------------------------------------------------------- /mocks/mock_cohort.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: github.com/vadiminshakov/committer/io/gateway/grpc/server (interfaces: Cohort) 3 | // 4 | // Generated by this command: 5 | // 6 | // mockgen -destination=../../../../mocks/mock_cohort.go -package=mocks . Cohort 7 | // 8 | 9 | // Package mocks is a generated GoMock package. 10 | package mocks 11 | 12 | import ( 13 | context "context" 14 | reflect "reflect" 15 | 16 | dto "github.com/vadiminshakov/committer/core/dto" 17 | gomock "go.uber.org/mock/gomock" 18 | ) 19 | 20 | // MockCohort is a mock of Cohort interface. 21 | type MockCohort struct { 22 | ctrl *gomock.Controller 23 | recorder *MockCohortMockRecorder 24 | isgomock struct{} 25 | } 26 | 27 | // MockCohortMockRecorder is the mock recorder for MockCohort. 28 | type MockCohortMockRecorder struct { 29 | mock *MockCohort 30 | } 31 | 32 | // NewMockCohort creates a new mock instance. 33 | func NewMockCohort(ctrl *gomock.Controller) *MockCohort { 34 | mock := &MockCohort{ctrl: ctrl} 35 | mock.recorder = &MockCohortMockRecorder{mock} 36 | return mock 37 | } 38 | 39 | // EXPECT returns an object that allows the caller to indicate expected use. 40 | func (m *MockCohort) EXPECT() *MockCohortMockRecorder { 41 | return m.recorder 42 | } 43 | 44 | // Abort mocks base method. 45 | func (m *MockCohort) Abort(ctx context.Context, req *dto.AbortRequest) (*dto.CohortResponse, error) { 46 | m.ctrl.T.Helper() 47 | ret := m.ctrl.Call(m, "Abort", ctx, req) 48 | ret0, _ := ret[0].(*dto.CohortResponse) 49 | ret1, _ := ret[1].(error) 50 | return ret0, ret1 51 | } 52 | 53 | // Abort indicates an expected call of Abort. 54 | func (mr *MockCohortMockRecorder) Abort(ctx, req any) *gomock.Call { 55 | mr.mock.ctrl.T.Helper() 56 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Abort", reflect.TypeOf((*MockCohort)(nil).Abort), ctx, req) 57 | } 58 | 59 | // Commit mocks base method. 60 | func (m *MockCohort) Commit(ctx context.Context, in *dto.CommitRequest) (*dto.CohortResponse, error) { 61 | m.ctrl.T.Helper() 62 | ret := m.ctrl.Call(m, "Commit", ctx, in) 63 | ret0, _ := ret[0].(*dto.CohortResponse) 64 | ret1, _ := ret[1].(error) 65 | return ret0, ret1 66 | } 67 | 68 | // Commit indicates an expected call of Commit. 
69 | func (mr *MockCohortMockRecorder) Commit(ctx, in any) *gomock.Call { 70 | mr.mock.ctrl.T.Helper() 71 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockCohort)(nil).Commit), ctx, in) 72 | } 73 | 74 | // Height mocks base method. 75 | func (m *MockCohort) Height() uint64 { 76 | m.ctrl.T.Helper() 77 | ret := m.ctrl.Call(m, "Height") 78 | ret0, _ := ret[0].(uint64) 79 | return ret0 80 | } 81 | 82 | // Height indicates an expected call of Height. 83 | func (mr *MockCohortMockRecorder) Height() *gomock.Call { 84 | mr.mock.ctrl.T.Helper() 85 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Height", reflect.TypeOf((*MockCohort)(nil).Height)) 86 | } 87 | 88 | // Precommit mocks base method. 89 | func (m *MockCohort) Precommit(ctx context.Context, index uint64) (*dto.CohortResponse, error) { 90 | m.ctrl.T.Helper() 91 | ret := m.ctrl.Call(m, "Precommit", ctx, index) 92 | ret0, _ := ret[0].(*dto.CohortResponse) 93 | ret1, _ := ret[1].(error) 94 | return ret0, ret1 95 | } 96 | 97 | // Precommit indicates an expected call of Precommit. 98 | func (mr *MockCohortMockRecorder) Precommit(ctx, index any) *gomock.Call { 99 | mr.mock.ctrl.T.Helper() 100 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Precommit", reflect.TypeOf((*MockCohort)(nil).Precommit), ctx, index) 101 | } 102 | 103 | // Propose mocks base method. 104 | func (m *MockCohort) Propose(ctx context.Context, req *dto.ProposeRequest) (*dto.CohortResponse, error) { 105 | m.ctrl.T.Helper() 106 | ret := m.ctrl.Call(m, "Propose", ctx, req) 107 | ret0, _ := ret[0].(*dto.CohortResponse) 108 | ret1, _ := ret[1].(error) 109 | return ret0, ret1 110 | } 111 | 112 | // Propose indicates an expected call of Propose. 113 | func (mr *MockCohortMockRecorder) Propose(ctx, req any) *gomock.Call { 114 | mr.mock.ctrl.T.Helper() 115 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Propose", reflect.TypeOf((*MockCohort)(nil).Propose), ctx, req) 116 | } 117 | -------------------------------------------------------------------------------- /mocks/mock_committer.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: github.com/vadiminshakov/committer/core/cohort (interfaces: Committer) 3 | // 4 | // Generated by this command: 5 | // 6 | // mockgen -destination=../../mocks/mock_committer.go -package=mocks . Committer 7 | // 8 | 9 | // Package mocks is a generated GoMock package. 10 | package mocks 11 | 12 | import ( 13 | context "context" 14 | reflect "reflect" 15 | 16 | hooks "github.com/vadiminshakov/committer/core/cohort/commitalgo/hooks" 17 | dto "github.com/vadiminshakov/committer/core/dto" 18 | gomock "go.uber.org/mock/gomock" 19 | ) 20 | 21 | // MockCommitter is a mock of Committer interface. 22 | type MockCommitter struct { 23 | ctrl *gomock.Controller 24 | recorder *MockCommitterMockRecorder 25 | isgomock struct{} 26 | } 27 | 28 | // MockCommitterMockRecorder is the mock recorder for MockCommitter. 29 | type MockCommitterMockRecorder struct { 30 | mock *MockCommitter 31 | } 32 | 33 | // NewMockCommitter creates a new mock instance. 34 | func NewMockCommitter(ctrl *gomock.Controller) *MockCommitter { 35 | mock := &MockCommitter{ctrl: ctrl} 36 | mock.recorder = &MockCommitterMockRecorder{mock} 37 | return mock 38 | } 39 | 40 | // EXPECT returns an object that allows the caller to indicate expected use. 
41 | func (m *MockCommitter) EXPECT() *MockCommitterMockRecorder { 42 | return m.recorder 43 | } 44 | 45 | // Abort mocks base method. 46 | func (m *MockCommitter) Abort(ctx context.Context, req *dto.AbortRequest) (*dto.CohortResponse, error) { 47 | m.ctrl.T.Helper() 48 | ret := m.ctrl.Call(m, "Abort", ctx, req) 49 | ret0, _ := ret[0].(*dto.CohortResponse) 50 | ret1, _ := ret[1].(error) 51 | return ret0, ret1 52 | } 53 | 54 | // Abort indicates an expected call of Abort. 55 | func (mr *MockCommitterMockRecorder) Abort(ctx, req any) *gomock.Call { 56 | mr.mock.ctrl.T.Helper() 57 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Abort", reflect.TypeOf((*MockCommitter)(nil).Abort), ctx, req) 58 | } 59 | 60 | // Commit mocks base method. 61 | func (m *MockCommitter) Commit(ctx context.Context, req *dto.CommitRequest) (*dto.CohortResponse, error) { 62 | m.ctrl.T.Helper() 63 | ret := m.ctrl.Call(m, "Commit", ctx, req) 64 | ret0, _ := ret[0].(*dto.CohortResponse) 65 | ret1, _ := ret[1].(error) 66 | return ret0, ret1 67 | } 68 | 69 | // Commit indicates an expected call of Commit. 70 | func (mr *MockCommitterMockRecorder) Commit(ctx, req any) *gomock.Call { 71 | mr.mock.ctrl.T.Helper() 72 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockCommitter)(nil).Commit), ctx, req) 73 | } 74 | 75 | // Height mocks base method. 76 | func (m *MockCommitter) Height() uint64 { 77 | m.ctrl.T.Helper() 78 | ret := m.ctrl.Call(m, "Height") 79 | ret0, _ := ret[0].(uint64) 80 | return ret0 81 | } 82 | 83 | // Height indicates an expected call of Height. 84 | func (mr *MockCommitterMockRecorder) Height() *gomock.Call { 85 | mr.mock.ctrl.T.Helper() 86 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Height", reflect.TypeOf((*MockCommitter)(nil).Height)) 87 | } 88 | 89 | // Precommit mocks base method. 90 | func (m *MockCommitter) Precommit(ctx context.Context, index uint64) (*dto.CohortResponse, error) { 91 | m.ctrl.T.Helper() 92 | ret := m.ctrl.Call(m, "Precommit", ctx, index) 93 | ret0, _ := ret[0].(*dto.CohortResponse) 94 | ret1, _ := ret[1].(error) 95 | return ret0, ret1 96 | } 97 | 98 | // Precommit indicates an expected call of Precommit. 99 | func (mr *MockCommitterMockRecorder) Precommit(ctx, index any) *gomock.Call { 100 | mr.mock.ctrl.T.Helper() 101 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Precommit", reflect.TypeOf((*MockCommitter)(nil).Precommit), ctx, index) 102 | } 103 | 104 | // Propose mocks base method. 105 | func (m *MockCommitter) Propose(ctx context.Context, req *dto.ProposeRequest) (*dto.CohortResponse, error) { 106 | m.ctrl.T.Helper() 107 | ret := m.ctrl.Call(m, "Propose", ctx, req) 108 | ret0, _ := ret[0].(*dto.CohortResponse) 109 | ret1, _ := ret[1].(error) 110 | return ret0, ret1 111 | } 112 | 113 | // Propose indicates an expected call of Propose. 114 | func (mr *MockCommitterMockRecorder) Propose(ctx, req any) *gomock.Call { 115 | mr.mock.ctrl.T.Helper() 116 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Propose", reflect.TypeOf((*MockCommitter)(nil).Propose), ctx, req) 117 | } 118 | 119 | // RegisterHook mocks base method. 120 | func (m *MockCommitter) RegisterHook(hook hooks.Hook) { 121 | m.ctrl.T.Helper() 122 | m.ctrl.Call(m, "RegisterHook", hook) 123 | } 124 | 125 | // RegisterHook indicates an expected call of RegisterHook. 
126 | func (mr *MockCommitterMockRecorder) RegisterHook(hook any) *gomock.Call { 127 | mr.mock.ctrl.T.Helper() 128 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterHook", reflect.TypeOf((*MockCommitter)(nil).RegisterHook), hook) 129 | } 130 | -------------------------------------------------------------------------------- /io/store/store.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | stdErrors "errors" 5 | "os" 6 | "strings" 7 | "sync" 8 | 9 | "github.com/dgraph-io/badger/v4" 10 | "github.com/pkg/errors" 11 | "github.com/vadiminshakov/committer/core/walrecord" 12 | "github.com/vadiminshakov/gowal" 13 | ) 14 | 15 | // ErrNotFound returned when key does not exist in the store. 16 | var ErrNotFound = errors.New("key not found") 17 | 18 | // Store persists committed key/value pairs in BadgerDB and reconstructs them from WAL on startup. 19 | type Store struct { 20 | wal *gowal.Wal 21 | db *badger.DB 22 | mu sync.RWMutex 23 | } 24 | 25 | // Snapshot returns a shallow copy of the current state. 26 | func (s *Store) Snapshot() map[string][]byte { 27 | s.mu.RLock() 28 | defer s.mu.RUnlock() 29 | 30 | snapshot := make(map[string][]byte) 31 | _ = s.db.View(func(txn *badger.Txn) error { 32 | it := txn.NewIterator(badger.DefaultIteratorOptions) 33 | defer it.Close() 34 | 35 | for it.Rewind(); it.Valid(); it.Next() { 36 | item := it.Item() 37 | key := item.KeyCopy(nil) 38 | if err := item.Value(func(val []byte) error { 39 | snapshot[string(key)] = cloneBytes(val) 40 | return nil 41 | }); err != nil { 42 | return err 43 | } 44 | } 45 | return nil 46 | }) 47 | 48 | return snapshot 49 | } 50 | 51 | // Size returns current number of keys in the store. 52 | func (s *Store) Size() int { 53 | s.mu.RLock() 54 | defer s.mu.RUnlock() 55 | 56 | count := 0 57 | _ = s.db.View(func(txn *badger.Txn) error { 58 | it := txn.NewIterator(badger.DefaultIteratorOptions) 59 | defer it.Close() 60 | for it.Rewind(); it.Valid(); it.Next() { 61 | count++ 62 | } 63 | return nil 64 | }) 65 | 66 | return count 67 | } 68 | 69 | // RecoveryState contains information extracted from WAL during startup. 70 | type RecoveryState struct { 71 | // Height is the current protocol height (next proposal should use this height). 72 | Height uint64 73 | } 74 | 75 | // New creates a new WAL-backed store and reconstructs state from WAL entries using BadgerDB. 76 | func New(wal *gowal.Wal, dbPath string) (*Store, *RecoveryState, error) { 77 | if wal == nil { 78 | return nil, nil, errors.New("wal is nil") 79 | } 80 | if dbPath == "" { 81 | return nil, nil, errors.New("db path is empty") 82 | } 83 | 84 | if err := os.MkdirAll(dbPath, 0o755); err != nil { 85 | return nil, nil, errors.Wrap(err, "create badger directory") 86 | } 87 | 88 | opts := badger.DefaultOptions(dbPath) 89 | db, err := badger.Open(opts) 90 | if err != nil { 91 | return nil, nil, errors.Wrap(err, "open badger db") 92 | } 93 | 94 | state := &Store{ 95 | wal: wal, 96 | db: db, 97 | } 98 | 99 | recovery, err := state.recover() 100 | if err != nil { 101 | _ = db.Close() 102 | return nil, nil, err 103 | } 104 | 105 | return state, recovery, nil 106 | } 107 | 108 | // Put stores the provided value for the key. 
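// A nil value is treated as a delete: the key is removed if it exists.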
109 | func (s *Store) Put(key string, value []byte) error { 110 | if key == "" { 111 | return errors.New("key cannot be empty") 112 | } 113 | 114 | s.mu.Lock() 115 | defer s.mu.Unlock() 116 | 117 | return s.db.Update(func(txn *badger.Txn) error { 118 | if value == nil { 119 | if err := txn.Delete([]byte(key)); err != nil && !stdErrors.Is(err, badger.ErrKeyNotFound) { 120 | return err 121 | } 122 | return nil 123 | } 124 | return txn.Set([]byte(key), cloneBytes(value)) 125 | }) 126 | } 127 | 128 | // Get retrieves value by key. Returns ErrNotFound if key does not exist. 129 | func (s *Store) Get(key string) ([]byte, error) { 130 | s.mu.RLock() 131 | defer s.mu.RUnlock() 132 | 133 | var result []byte 134 | err := s.db.View(func(txn *badger.Txn) error { 135 | item, err := txn.Get([]byte(key)) 136 | if err != nil { 137 | if stdErrors.Is(err, badger.ErrKeyNotFound) { 138 | return ErrNotFound 139 | } 140 | return err 141 | } 142 | result, err = item.ValueCopy(nil) 143 | return err 144 | }) 145 | if err != nil { 146 | return nil, err 147 | } 148 | 149 | return cloneBytes(result), nil 150 | } 151 | 152 | // Close closes the underlying Badger database. 153 | func (s *Store) Close() error { 154 | return s.db.Close() 155 | } 156 | 157 | func (s *Store) recover() (*RecoveryState, error) { 158 | var ( 159 | hasProto bool 160 | maxProtoHeight uint64 161 | maxProtoStatus string // last seen status for maxProtoHeight 162 | ) 163 | 164 | for msg := range s.wal.Iterator() { 165 | if msg.Key == "" || msg.Key == "skip" { 166 | continue 167 | } 168 | 169 | // check if it is a system/protocol key 170 | if strings.HasPrefix(msg.Key, "__tx:") { 171 | hasProto = true 172 | height := msg.Idx / walrecord.Stride 173 | 174 | // track max proto height and its status 175 | if height > maxProtoHeight { 176 | maxProtoHeight = height 177 | maxProtoStatus = msg.Key 178 | } else if height == maxProtoHeight { 179 | // if multiple messages for same height, commit/abort override prepared/precommit as "final" status 180 | if msg.Key == walrecord.KeyCommit || msg.Key == walrecord.KeyAbort { 181 | maxProtoStatus = msg.Key 182 | } 183 | } 184 | 185 | // apply strictly on commit 186 | if msg.Key == walrecord.KeyCommit { 187 | walTx, err := walrecord.Decode(msg.Value) 188 | if err != nil { 189 | return nil, errors.Wrapf(err, "decode wal tx at idx %d", msg.Idx) 190 | } 191 | if err := s.Put(walTx.Key, walTx.Value); err != nil { 192 | return nil, errors.Wrapf(err, "apply committed tx at idx %d", msg.Idx) 193 | } 194 | } 195 | } 196 | } 197 | 198 | var height uint64 199 | if hasProto { 200 | if maxProtoStatus == walrecord.KeyCommit || maxProtoStatus == walrecord.KeyAbort { 201 | height = maxProtoHeight + 1 202 | } else { 203 | // incomplete transaction at maxProtoHeight 204 | height = maxProtoHeight 205 | } 206 | } 207 | 208 | return &RecoveryState{Height: height}, nil 209 | } 210 | 211 | func cloneBytes(src []byte) []byte { 212 | if src == nil { 213 | return nil 214 | } 215 | 216 | dst := make([]byte, len(src)) 217 | copy(dst, src) 218 | return dst 219 | } 220 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![tests](https://github.com/vadiminshakov/committer/actions/workflows/tests.yml/badge.svg?branch=master) 2 | [![Go Reference](https://pkg.go.dev/badge/github.com/vadiminshakov/committer.svg)](https://pkg.go.dev/github.com/vadiminshakov/committer) 3 | [![Go Report 
Card](https://goreportcard.com/badge/github.com/vadiminshakov/committer)](https://goreportcard.com/report/github.com/vadiminshakov/committer) 4 | [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) 5 | 6 |

7 | ![Committer Logo](committer.png) 8 |

9 | 10 | # **Committer** 11 | 12 | **Committer** is a Go implementation of the **Two-Phase Commit (2PC)** and **Three-Phase Commit (3PC)** protocols for distributed systems. 13 | 14 | ## **Architecture** 15 | 16 | The system consists of two types of nodes: **Coordinator** and **Cohorts**. 17 | The **Coordinator** is responsible for initiating and managing the commit protocol (2PC or 3PC), while the **Cohorts** participate by responding to the coordinator's requests. 18 | Communication between nodes is handled over gRPC, and each node's protocol state is managed by a state machine. 19 | 20 | ## **Atomic Commit Protocols** 21 | 22 | ### **Two-Phase Commit (2PC)** 23 | 24 | The Two-Phase Commit protocol ensures atomicity in distributed transactions through two distinct phases: 25 | 26 | #### **Phase 1: Voting Phase (Propose)** 27 | 1. **Coordinator** sends a `PROPOSE` request to all cohorts with transaction data 28 | 2. Each **Cohort** validates the transaction locally and responds: 29 | - `ACK` (Yes) - if ready to commit 30 | - `NACK` (No) - if unable to commit 31 | 3. **Coordinator** waits for all responses 32 | 33 | #### **Phase 2: Commit Phase** 34 | 1. If **all cohorts** voted `ACK`: 35 | - **Coordinator** sends `COMMIT` to all cohorts 36 | - Each **Cohort** commits the transaction and responds with `ACK` 37 | 2. If **any cohort** voted `NACK`: 38 | - **Coordinator** sends `ABORT` to all cohorts 39 | - Each **Cohort** aborts the transaction 40 | 41 | ### **Three-Phase Commit (3PC)** 42 | 43 | The Three-Phase Commit protocol extends 2PC with an additional phase to reduce blocking scenarios: 44 | 45 | #### **Phase 1: Voting Phase (Propose)** 46 | 1. **Coordinator** sends a `PROPOSE` request to all cohorts 47 | 2. **Cohorts** respond with `ACK`/`NACK` (same as 2PC) 48 | 49 | #### **Phase 2: Preparation Phase (Precommit)** 50 | 1. If all cohorts voted `ACK`: 51 | - **Coordinator** sends `PRECOMMIT` to all cohorts 52 | - **Cohorts** acknowledge they're prepared to commit 53 | - **Timeout mechanism**: if a cohort doesn't receive `COMMIT` within the timeout, it auto-commits 54 | 2. If any cohort voted `NACK`: 55 | - **Coordinator** sends `ABORT` to all cohorts 56 | 57 | #### **Phase 3: Commit Phase** 58 | 1. **Coordinator** sends `COMMIT` to all cohorts 59 | 2. 
**Cohorts** perform the actual commit operation 60 | 61 | ## **Configuration** 62 | 63 | All configuration parameters can be set using command-line flags: 64 | 65 | | **Flag** | **Description** | **Default** | **Example** | 66 | |-----------------|---------------------------------------------------------|---------------------|-------------------------------------| 67 | | `role` | Node role: `coordinator` or `cohort` | `cohort` | `-role=coordinator` | 68 | | `nodeaddr` | Address of the current node | `localhost:3050` | `-nodeaddr=localhost:3051` | 69 | | `coordinator` | Coordinator address (required for cohorts) | `""` | `-coordinator=localhost:3050` | 70 | | `committype` | Commit protocol: `two-phase` or `three-phase` | `three-phase` | `-committype=two-phase` | 71 | | `timeout` | Timeout (ms) for unacknowledged messages (3PC only) | `1000` | `-timeout=500` | 72 | | `dbpath` | Path to the BadgerDB database on the filesystem | `./badger` | `-dbpath=/tmp/badger` | 73 | | `cohorts` | Comma-separated list of cohort addresses | `""` | `-cohorts=localhost:3052,3053` | 74 | | `whitelist` | Comma-separated list of allowed hosts | `127.0.0.1` | `-whitelist=192.168.0.1,192.168.0.2`| 75 | 76 | 77 | ## **Usage** 78 | 79 | ### **Running as a Cohort** 80 | ```bash 81 | ./committer -role=cohort -nodeaddr=localhost:3001 -coordinator=localhost:3000 -committype=three-phase -timeout=1000 -dbpath=/tmp/badger/cohort 82 | ``` 83 | 84 | ### **Running as a Coordinator** 85 | ```bash 86 | ./committer -role=coordinator -nodeaddr=localhost:3000 -cohorts=localhost:3001 -committype=three-phase -timeout=1000 -dbpath=/tmp/badger/coordinator 87 | ``` 88 | 89 | ## **Hooks System** 90 | 91 | The hooks system allows you to add custom validation and business logic during the **Propose** and **Commit** stages without modifying the core code. Hooks are executed in the order they were registered, and if any hook returns `false`, the operation is rejected. 92 | 93 | ```go 94 | // Default usage (with built-in default hook) 95 | committer := commitalgo.NewCommitter(database, "three-phase", wal, timeout) 96 | 97 | // With custom hooks 98 | metricsHook := hooks.NewMetricsHook() 99 | validationHook := hooks.NewValidationHook(100, 1024) 100 | auditHook := hooks.NewAuditHook("audit.log") 101 | 102 | committer := commitalgo.NewCommitter(database, "three-phase", wal, timeout, 103 | metricsHook, 104 | validationHook, 105 | auditHook, 106 | ) 107 | 108 | // Dynamic registration 109 | committer.RegisterHook(myCustomHook) 110 | ``` 111 | 112 | ## **Testing** 113 | 114 | ### **Run Functional Tests** 115 | ```bash 116 | make tests 117 | ``` 118 | 119 | ### **Testing with Example Client** 120 | 1. Compile executables: 121 | 122 | ```bash 123 | make prepare 124 | ``` 125 | 126 | 2. Run the coordinator: 127 | 128 | ```bash 129 | make run-example-coordinator 130 | ``` 131 | 132 | 3. Run a cohort in another terminal: 133 | 134 | ```bash 135 | make run-example-cohort 136 | ``` 137 | 138 | 4. Start the example client: 139 | 140 | ```bash 141 | make run-example-client 142 | ``` 143 | 144 | Or directly: 145 | 146 | ```bash 147 | go run ./examples/client/client.go 148 | ``` 149 | 150 | ## **Contributions** 151 | 152 | Contributions are welcome! Feel free to submit a PR or open an issue if you find bugs or have suggestions for improvement. 153 | 154 | ## **License** 155 | 156 | This project is licensed under the [Apache License](LICENSE). 
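## **Appendix: Writing a Custom Hook**

As a complement to the **Hooks System** section above, here is a minimal sketch of a user-defined hook. It assumes the `hooks.Hook` interface consists of the `OnPropose`/`OnCommit` methods implemented by the built-in `DefaultHook`; the `SizeLimitHook` name and its limit are illustrative, not part of the library.

```go
package main

import "github.com/vadiminshakov/committer/core/dto"

// SizeLimitHook rejects proposals whose value exceeds a configured size.
type SizeLimitHook struct {
	maxValueBytes int
}

// NewSizeLimitHook creates a hook that accepts values up to maxValueBytes.
func NewSizeLimitHook(maxValueBytes int) *SizeLimitHook {
	return &SizeLimitHook{maxValueBytes: maxValueBytes}
}

// OnPropose returns false when the proposed value is too large,
// which causes the proposal to be rejected.
func (h *SizeLimitHook) OnPropose(req *dto.ProposeRequest) bool {
	return len(req.Value) <= h.maxValueBytes
}

// OnCommit performs no extra validation at commit time.
func (h *SizeLimitHook) OnCommit(req *dto.CommitRequest) bool {
	return true
}
```

Registered either via the variadic arguments of `NewCommitter` or via `committer.RegisterHook(NewSizeLimitHook(1024))`, the hook runs in registration order alongside the others; a `false` return from `OnPropose` rejects the operation.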
-------------------------------------------------------------------------------- /toxiproxy_test.go: -------------------------------------------------------------------------------- 1 | //go:build chaos 2 | 3 | package main 4 | 5 | import ( 6 | "bytes" 7 | "encoding/json" 8 | "fmt" 9 | "io" 10 | "net/http" 11 | "strconv" 12 | "strings" 13 | "time" 14 | ) 15 | 16 | type toxiproxyClient struct { 17 | BaseURL string 18 | HTTPClient *http.Client 19 | } 20 | 21 | type proxy struct { 22 | Name string `json:"name"` 23 | Listen string `json:"listen"` 24 | Upstream string `json:"upstream"` 25 | Enabled bool `json:"enabled"` 26 | } 27 | 28 | type toxic struct { 29 | Name string `json:"name"` 30 | Type string `json:"type"` 31 | Stream string `json:"stream"` 32 | Toxicity float32 `json:"toxicity"` 33 | Attributes map[string]interface{} `json:"attributes"` 34 | } 35 | 36 | func newToxiproxyClient(baseURL string) *toxiproxyClient { 37 | return &toxiproxyClient{ 38 | BaseURL: baseURL, 39 | HTTPClient: &http.Client{Timeout: 10 * time.Second}, 40 | } 41 | } 42 | 43 | func (c *toxiproxyClient) createProxy(name, listen, upstream string) (*proxy, error) { 44 | p := &proxy{ 45 | Name: name, 46 | Listen: listen, 47 | Upstream: upstream, 48 | Enabled: true, 49 | } 50 | 51 | data, err := json.Marshal(p) 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | resp, err := c.HTTPClient.Post(c.BaseURL+"/proxies", "application/json", bytes.NewBuffer(data)) 57 | if err != nil { 58 | return nil, err 59 | } 60 | defer resp.Body.Close() 61 | 62 | body, _ := io.ReadAll(resp.Body) 63 | if resp.StatusCode != http.StatusCreated { 64 | return nil, fmt.Errorf("failed to create proxy (status %d): %s", resp.StatusCode, string(body)) 65 | } 66 | 67 | var createdProxy proxy 68 | if err := json.Unmarshal(body, &createdProxy); err != nil { 69 | return nil, fmt.Errorf("failed to parse created proxy response: %v, body: %s", err, string(body)) 70 | } 71 | 72 | return &createdProxy, nil 73 | } 74 | 75 | func (c *toxiproxyClient) addToxic(proxyName string, toxic *toxic) error { 76 | data, err := json.Marshal(toxic) 77 | if err != nil { 78 | return err 79 | } 80 | 81 | url := fmt.Sprintf("%s/proxies/%s/toxics", c.BaseURL, proxyName) 82 | resp, err := c.HTTPClient.Post(url, "application/json", bytes.NewBuffer(data)) 83 | if err != nil { 84 | return err 85 | } 86 | defer resp.Body.Close() 87 | 88 | body, _ := io.ReadAll(resp.Body) 89 | if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { 90 | return fmt.Errorf("failed to add toxic to proxy '%s' (status %d): %s. 
Request: %s", proxyName, resp.StatusCode, string(body), string(data)) 91 | } 92 | 93 | return nil 94 | } 95 | 96 | func (c *toxiproxyClient) deleteProxy(name string) error { 97 | url := fmt.Sprintf("%s/proxies/%s", c.BaseURL, name) 98 | req, err := http.NewRequest("DELETE", url, nil) 99 | if err != nil { 100 | return err 101 | } 102 | 103 | resp, err := c.HTTPClient.Do(req) 104 | if err != nil { 105 | return err 106 | } 107 | defer resp.Body.Close() 108 | 109 | if resp.StatusCode != http.StatusNoContent { 110 | body, _ := io.ReadAll(resp.Body) 111 | return fmt.Errorf("failed to delete proxy: %s", string(body)) 112 | } 113 | 114 | return nil 115 | } 116 | 117 | func (c *toxiproxyClient) resetProxy(name string) error { 118 | url := fmt.Sprintf("%s/proxies/%s/toxics", c.BaseURL, name) 119 | req, err := http.NewRequest("DELETE", url, nil) 120 | if err != nil { 121 | return err 122 | } 123 | 124 | resp, err := c.HTTPClient.Do(req) 125 | if err != nil { 126 | return err 127 | } 128 | defer resp.Body.Close() 129 | 130 | return nil 131 | } 132 | 133 | type chaosTestHelper struct { 134 | client *toxiproxyClient 135 | proxies map[string]*proxy 136 | proxyMapping map[string]string // original -> proxy address 137 | } 138 | 139 | func newChaosTestHelper(toxiproxyURL string) *chaosTestHelper { 140 | return &chaosTestHelper{ 141 | client: newToxiproxyClient(toxiproxyURL), 142 | proxies: make(map[string]*proxy), 143 | proxyMapping: make(map[string]string), 144 | } 145 | } 146 | 147 | func (h *chaosTestHelper) setupProxies(nodeAddresses []string) error { 148 | for i, addr := range nodeAddresses { 149 | proxyName := fmt.Sprintf("node_%d", i) 150 | proxyAddr := h.generateProxyAddress(addr) 151 | 152 | proxy, err := h.client.createProxy(proxyName, proxyAddr, addr) 153 | if err != nil { 154 | return fmt.Errorf("failed to create proxy for %s: %v", addr, err) 155 | } 156 | 157 | h.proxies[proxyName] = proxy 158 | h.proxyMapping[addr] = proxyAddr 159 | } 160 | 161 | return nil 162 | } 163 | 164 | // addResetPeer simulates TCP RESET after optional timeout 165 | func (h *chaosTestHelper) addResetPeer(nodeAddr string, timeout time.Duration) error { 166 | proxyName := h.getProxyName(nodeAddr) 167 | if proxyName == "" { 168 | return fmt.Errorf("proxy not found for address %s", nodeAddr) 169 | } 170 | 171 | toxic := &toxic{ 172 | Name: fmt.Sprintf("reset_peer_%s", proxyName), 173 | Type: "reset_peer", 174 | Stream: "downstream", 175 | Toxicity: 1.0, 176 | Attributes: map[string]interface{}{ 177 | "timeout": int(timeout.Milliseconds()), 178 | }, 179 | } 180 | 181 | return h.client.addToxic(proxyName, toxic) 182 | } 183 | 184 | // addDataLimit closes connection after transmitting specified bytes 185 | func (h *chaosTestHelper) addDataLimit(nodeAddr string, bytes int) error { 186 | proxyName := h.getProxyName(nodeAddr) 187 | if proxyName == "" { 188 | return fmt.Errorf("proxy not found for address %s", nodeAddr) 189 | } 190 | 191 | toxic := &toxic{ 192 | Name: fmt.Sprintf("limit_data_%s", proxyName), 193 | Type: "limit_data", 194 | Stream: "downstream", 195 | Toxicity: 1.0, 196 | Attributes: map[string]interface{}{ 197 | "bytes": bytes, 198 | }, 199 | } 200 | 201 | return h.client.addToxic(proxyName, toxic) 202 | } 203 | 204 | func (h *chaosTestHelper) getProxyAddress(originalAddr string) string { 205 | return h.proxyMapping[originalAddr] 206 | } 207 | 208 | func (h *chaosTestHelper) cleanup() error { 209 | for proxyName := range h.proxies { 210 | if err := h.client.deleteProxy(proxyName); err != nil { 211 | return err 
212 | } 213 | } 214 | 215 | h.proxies = make(map[string]*proxy) 216 | h.proxyMapping = make(map[string]string) 217 | return nil 218 | } 219 | 220 | func (h *chaosTestHelper) generateProxyAddress(originalAddr string) string { 221 | parts := strings.Split(originalAddr, ":") 222 | if len(parts) != 2 { 223 | return originalAddr 224 | } 225 | 226 | port, err := strconv.Atoi(parts[1]) 227 | if err != nil { 228 | return originalAddr 229 | } 230 | 231 | proxyPort := port + 10000 232 | return fmt.Sprintf("%s:%d", parts[0], proxyPort) 233 | } 234 | 235 | func (h *chaosTestHelper) getProxyName(nodeAddr string) string { 236 | for name, proxy := range h.proxies { 237 | if proxy.Upstream == nodeAddr { 238 | return name 239 | } 240 | } 241 | return "" 242 | } 243 | -------------------------------------------------------------------------------- /core/cohort/cohort_test.go: -------------------------------------------------------------------------------- 1 | package cohort 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | "github.com/vadiminshakov/committer/core/dto" 10 | "github.com/vadiminshakov/committer/mocks" 11 | "go.uber.org/mock/gomock" 12 | ) 13 | 14 | func TestNewCohort(t *testing.T) { 15 | ctrl := gomock.NewController(t) 16 | defer ctrl.Finish() 17 | 18 | mockCommitter := mocks.NewMockCommitter(ctrl) 19 | 20 | // test creating 2PC cohort 21 | cohort := NewCohort(mockCommitter, "two-phase") 22 | require.NotNil(t, cohort) 23 | require.Equal(t, Mode("two-phase"), cohort.commitType) 24 | 25 | // test creating 3PC cohort 26 | cohort3PC := NewCohort(mockCommitter, THREE_PHASE) 27 | require.NotNil(t, cohort3PC) 28 | require.Equal(t, THREE_PHASE, cohort3PC.commitType) 29 | } 30 | 31 | func TestCohort_Height(t *testing.T) { 32 | ctrl := gomock.NewController(t) 33 | defer ctrl.Finish() 34 | 35 | mockCommitter := mocks.NewMockCommitter(ctrl) 36 | cohort := NewCohort(mockCommitter, "two-phase") 37 | 38 | // expect Height to be called and return 5 39 | mockCommitter.EXPECT().Height().Return(uint64(5)) 40 | 41 | height := cohort.Height() 42 | require.Equal(t, uint64(5), height) 43 | } 44 | 45 | func TestCohort_Propose(t *testing.T) { 46 | ctrl := gomock.NewController(t) 47 | defer ctrl.Finish() 48 | 49 | mockCommitter := mocks.NewMockCommitter(ctrl) 50 | cohort := NewCohort(mockCommitter, "two-phase") 51 | 52 | ctx := context.Background() 53 | proposeReq := &dto.ProposeRequest{ 54 | Height: 0, 55 | Key: "test-key", 56 | Value: []byte("test-value"), 57 | } 58 | 59 | expectedResp := &dto.CohortResponse{ 60 | ResponseType: dto.ResponseTypeAck, 61 | Height: 0, 62 | } 63 | 64 | // expect Propose to be called and return success 65 | mockCommitter.EXPECT().Propose(ctx, proposeReq).Return(expectedResp, nil) 66 | 67 | resp, err := cohort.Propose(ctx, proposeReq) 68 | require.NoError(t, err) 69 | require.NotNil(t, resp) 70 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType) 71 | require.Equal(t, uint64(0), resp.Height) 72 | } 73 | 74 | func TestCohort_Propose_Error(t *testing.T) { 75 | ctrl := gomock.NewController(t) 76 | defer ctrl.Finish() 77 | 78 | mockCommitter := mocks.NewMockCommitter(ctrl) 79 | cohort := NewCohort(mockCommitter, "two-phase") 80 | 81 | ctx := context.Background() 82 | proposeReq := &dto.ProposeRequest{ 83 | Height: 0, 84 | Key: "test-key", 85 | Value: []byte("test-value"), 86 | } 87 | 88 | expectedErr := fmt.Errorf("propose failed") 89 | 90 | // expect Propose to be called and return error 91 | mockCommitter.EXPECT().Propose(ctx, 
proposeReq).Return(nil, expectedErr) 92 | 93 | resp, err := cohort.Propose(ctx, proposeReq) 94 | require.Error(t, err) 95 | require.Nil(t, resp) 96 | require.Equal(t, expectedErr, err) 97 | } 98 | 99 | func TestCohort_Precommit_TwoPhase(t *testing.T) { 100 | ctrl := gomock.NewController(t) 101 | defer ctrl.Finish() 102 | 103 | mockCommitter := mocks.NewMockCommitter(ctrl) 104 | cohort := NewCohort(mockCommitter, "two-phase") 105 | 106 | ctx := context.Background() 107 | 108 | // test precommit in 2PC mode (should fail) 109 | _, err := cohort.Precommit(ctx, 0) 110 | require.Error(t, err) 111 | require.Contains(t, err.Error(), "precommit is allowed for 3PC mode only") 112 | } 113 | 114 | func TestCohort_Precommit_ThreePhase(t *testing.T) { 115 | ctrl := gomock.NewController(t) 116 | defer ctrl.Finish() 117 | 118 | mockCommitter := mocks.NewMockCommitter(ctrl) 119 | cohort := NewCohort(mockCommitter, THREE_PHASE) 120 | 121 | ctx := context.Background() 122 | 123 | expectedResp := &dto.CohortResponse{ 124 | ResponseType: dto.ResponseTypeAck, 125 | } 126 | 127 | // expect Precommit to be called and return success 128 | mockCommitter.EXPECT().Precommit(ctx, uint64(0)).Return(expectedResp, nil) 129 | 130 | // test precommit in 3PC mode (should succeed) 131 | resp, err := cohort.Precommit(ctx, 0) 132 | require.NoError(t, err) 133 | require.NotNil(t, resp) 134 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType) 135 | } 136 | 137 | func TestCohort_Precommit_ThreePhase_Error(t *testing.T) { 138 | ctrl := gomock.NewController(t) 139 | defer ctrl.Finish() 140 | 141 | mockCommitter := mocks.NewMockCommitter(ctrl) 142 | cohort := NewCohort(mockCommitter, THREE_PHASE) 143 | 144 | ctx := context.Background() 145 | expectedErr := fmt.Errorf("precommit failed") 146 | 147 | // expect Precommit to be called and return error 148 | mockCommitter.EXPECT().Precommit(ctx, uint64(0)).Return(nil, expectedErr) 149 | 150 | resp, err := cohort.Precommit(ctx, 0) 151 | require.Error(t, err) 152 | require.Nil(t, resp) 153 | require.Equal(t, expectedErr, err) 154 | } 155 | 156 | func TestCohort_Commit(t *testing.T) { 157 | ctrl := gomock.NewController(t) 158 | defer ctrl.Finish() 159 | 160 | mockCommitter := mocks.NewMockCommitter(ctrl) 161 | cohort := NewCohort(mockCommitter, "two-phase") 162 | 163 | ctx := context.Background() 164 | commitReq := &dto.CommitRequest{Height: 0} 165 | 166 | expectedResp := &dto.CohortResponse{ 167 | ResponseType: dto.ResponseTypeAck, 168 | } 169 | 170 | // expect Commit to be called and return success 171 | mockCommitter.EXPECT().Commit(ctx, commitReq).Return(expectedResp, nil) 172 | 173 | resp, err := cohort.Commit(ctx, commitReq) 174 | require.NoError(t, err) 175 | require.NotNil(t, resp) 176 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType) 177 | } 178 | 179 | func TestCohort_Commit_Error(t *testing.T) { 180 | ctrl := gomock.NewController(t) 181 | defer ctrl.Finish() 182 | 183 | mockCommitter := mocks.NewMockCommitter(ctrl) 184 | cohort := NewCohort(mockCommitter, "two-phase") 185 | 186 | ctx := context.Background() 187 | commitReq := &dto.CommitRequest{Height: 0} 188 | expectedErr := fmt.Errorf("commit failed") 189 | 190 | // expect Commit to be called and return error 191 | mockCommitter.EXPECT().Commit(ctx, commitReq).Return(nil, expectedErr) 192 | 193 | resp, err := cohort.Commit(ctx, commitReq) 194 | require.Error(t, err) 195 | require.Nil(t, resp) 196 | require.Equal(t, expectedErr, err) 197 | } 198 | 199 | func TestCohort_Commit_Nack(t *testing.T) { 200 | ctrl := 
gomock.NewController(t) 201 | defer ctrl.Finish() 202 | 203 | mockCommitter := mocks.NewMockCommitter(ctrl) 204 | cohort := NewCohort(mockCommitter, "two-phase") 205 | 206 | ctx := context.Background() 207 | commitReq := &dto.CommitRequest{Height: 0} 208 | 209 | expectedResp := &dto.CohortResponse{ 210 | ResponseType: dto.ResponseTypeNack, 211 | } 212 | 213 | // expect Commit to be called and return NACK 214 | mockCommitter.EXPECT().Commit(ctx, commitReq).Return(expectedResp, nil) 215 | 216 | resp, err := cohort.Commit(ctx, commitReq) 217 | require.NoError(t, err) 218 | require.NotNil(t, resp) 219 | require.Equal(t, dto.ResponseTypeNack, resp.ResponseType) 220 | } 221 | 222 | func TestCohort_ModeValidation(t *testing.T) { 223 | ctrl := gomock.NewController(t) 224 | defer ctrl.Finish() 225 | 226 | mockCommitter := mocks.NewMockCommitter(ctrl) 227 | 228 | // test different modes 229 | modes := []Mode{"two-phase", THREE_PHASE, "custom-mode"} 230 | 231 | for _, mode := range modes { 232 | cohort := NewCohort(mockCommitter, mode) 233 | require.NotNil(t, cohort) 234 | require.Equal(t, mode, cohort.commitType) 235 | } 236 | } 237 | -------------------------------------------------------------------------------- /core/coordinator/coordinator_test.go: -------------------------------------------------------------------------------- 1 | package coordinator 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync/atomic" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/require" 11 | "github.com/vadiminshakov/committer/config" 12 | "github.com/vadiminshakov/committer/core/dto" 13 | "github.com/vadiminshakov/committer/core/walrecord" 14 | "github.com/vadiminshakov/committer/io/gateway/grpc/server" 15 | "github.com/vadiminshakov/committer/mocks" 16 | "go.uber.org/mock/gomock" 17 | ) 18 | 19 | func TestCoordinator_New(t *testing.T) { 20 | ctrl := gomock.NewController(t) 21 | defer ctrl.Finish() 22 | 23 | mockWAL := mocks.NewMockwal(ctrl) 24 | mockStore := mocks.NewMockStateStore(ctrl) 25 | 26 | conf := &config.Config{ 27 | Nodeaddr: "localhost:8080", 28 | Role: "coordinator", 29 | Cohorts: []string{"localhost:8081", "localhost:8082"}, 30 | CommitType: server.TWO_PHASE, 31 | } 32 | 33 | coord, err := New(conf, mockWAL, mockStore) 34 | require.NoError(t, err) 35 | require.NotNil(t, coord) 36 | require.Equal(t, uint64(0), coord.Height()) 37 | require.Len(t, coord.cohorts, 2) 38 | } 39 | 40 | func TestCoordinator_Height(t *testing.T) { 41 | ctrl := gomock.NewController(t) 42 | defer ctrl.Finish() 43 | 44 | mockWAL := mocks.NewMockwal(ctrl) 45 | mockStore := mocks.NewMockStateStore(ctrl) 46 | 47 | conf := &config.Config{ 48 | Nodeaddr: "localhost:8080", 49 | Role: "coordinator", 50 | Cohorts: []string{}, 51 | CommitType: server.TWO_PHASE, 52 | } 53 | 54 | coord, err := New(conf, mockWAL, mockStore) 55 | require.NoError(t, err) 56 | 57 | // test initial height 58 | require.Equal(t, uint64(0), coord.Height()) 59 | 60 | // test height increment 61 | atomic.AddUint64(&coord.height, 1) 62 | require.Equal(t, uint64(1), coord.Height()) 63 | } 64 | 65 | func TestCoordinator_SyncHeight(t *testing.T) { 66 | ctrl := gomock.NewController(t) 67 | defer ctrl.Finish() 68 | 69 | mockWAL := mocks.NewMockwal(ctrl) 70 | mockStore := mocks.NewMockStateStore(ctrl) 71 | 72 | conf := &config.Config{ 73 | Nodeaddr: "localhost:8080", 74 | Role: "coordinator", 75 | Cohorts: []string{}, 76 | CommitType: server.TWO_PHASE, 77 | } 78 | 79 | coord, err := New(conf, mockWAL, mockStore) 80 | require.NoError(t, err) 81 | 82 | // test 
sync to higher height 83 | coord.syncHeight(5) 84 | require.Equal(t, uint64(5), coord.Height()) 85 | 86 | // test sync to lower height (should not change) 87 | coord.syncHeight(3) 88 | require.Equal(t, uint64(5), coord.Height()) 89 | 90 | // test sync to same height (should not change) 91 | coord.syncHeight(5) 92 | require.Equal(t, uint64(5), coord.Height()) 93 | } 94 | 95 | func TestCoordinator_PersistMessage(t *testing.T) { 96 | ctrl := gomock.NewController(t) 97 | defer ctrl.Finish() 98 | 99 | mockWAL := mocks.NewMockwal(ctrl) 100 | mockStore := mocks.NewMockStateStore(ctrl) 101 | 102 | conf := &config.Config{ 103 | Nodeaddr: "localhost:8080", 104 | Role: "coordinator", 105 | Cohorts: []string{}, 106 | CommitType: server.TWO_PHASE, 107 | } 108 | 109 | coord, err := New(conf, mockWAL, mockStore) 110 | require.NoError(t, err) 111 | 112 | // test persist message success 113 | testKey := "test-key" 114 | testValue := []byte("test-value") 115 | 116 | // expect WAL.Get to return the test data 117 | // (the value is encoded the same way a prepared tx is stored in the WAL) 118 | ptx := walrecord.WalTx{Key: testKey, Value: testValue} 119 | encoded, _ := walrecord.Encode(ptx) 120 | mockWAL.EXPECT().Get(walrecord.PreparedSlot(0)).Return(walrecord.KeyPrepared, encoded, nil) 121 | 122 | // expect DB.Put to be called with the test data 123 | mockStore.EXPECT().Put(testKey, testValue).Return(nil) 124 | 125 | err = coord.persistMessage() 126 | require.NoError(t, err) 127 | } 128 | 129 | func TestCoordinator_PersistMessage_NoDataInWAL(t *testing.T) { 130 | ctrl := gomock.NewController(t) 131 | defer ctrl.Finish() 132 | 133 | mockWAL := mocks.NewMockwal(ctrl) 134 | mockStore := mocks.NewMockStateStore(ctrl) 135 | 136 | conf := &config.Config{ 137 | Nodeaddr: "localhost:8080", 138 | Role: "coordinator", 139 | Cohorts: []string{}, 140 | CommitType: server.TWO_PHASE, 141 | } 142 | 143 | coord, err := New(conf, mockWAL, mockStore) 144 | require.NoError(t, err) 145 | 146 | // expect WAL.Get to return no data 147 | mockWAL.EXPECT().Get(walrecord.PreparedSlot(0)).Return("", nil, nil) 148 | 149 | // test persist message when no data in WAL 150 | err = coord.persistMessage() 151 | require.Error(t, err) 152 | require.Contains(t, err.Error(), "can't find msg in wal") 153 | } 154 | 155 | func TestCoordinator_Abort(t *testing.T) { 156 | ctrl := gomock.NewController(t) 157 | defer ctrl.Finish() 158 | 159 | mockWAL := mocks.NewMockwal(ctrl) 160 | mockStore := mocks.NewMockStateStore(ctrl) 161 | 162 | conf := &config.Config{ 163 | Nodeaddr: "localhost:8080", 164 | Role: "coordinator", 165 | Cohorts: []string{}, 166 | CommitType: server.TWO_PHASE, 167 | } 168 | 169 | coord, err := New(conf, mockWAL, mockStore) 170 | require.NoError(t, err) 171 | 172 | ctx := context.Background() 173 | 174 | // expect WAL abort write 175 | mockWAL.EXPECT().Write(walrecord.AbortSlot(0), walrecord.KeyAbort, gomock.Any()).Return(nil) 176 | 177 | // test abort 178 | done := make(chan bool, 1) 179 | go func() { 180 | coord.abort(ctx, "test abort reason") 181 | done <- true 182 | }() 183 | 184 | select { 185 | case <-done: 186 | // abort completed successfully 187 | case <-time.After(1 * time.Second): 188 | t.Fatal("abort took too long to complete") 189 | } 190 | } 191 | 192 | func TestCoordinator_SyncHeight_Concurrent(t *testing.T) { 193 | ctrl := gomock.NewController(t) 194 | defer ctrl.Finish() 195 | 196 | mockWAL := mocks.NewMockwal(ctrl) 197 | mockStore := mocks.NewMockStateStore(ctrl) 198 | 199 | conf := &config.Config{ 200 | Nodeaddr: "localhost:8080", 201 | Role: 
"coordinator", 202 | Cohorts: []string{}, 203 | CommitType: server.TWO_PHASE, 204 | } 205 | 206 | coord, err := New(conf, mockWAL, mockStore) 207 | require.NoError(t, err) 208 | 209 | // test concurrent height sync 210 | done := make(chan bool, 10) 211 | 212 | for i := 0; i < 10; i++ { 213 | go func(height uint64) { 214 | coord.syncHeight(height) 215 | done <- true 216 | }(uint64(i + 1)) 217 | } 218 | 219 | // wait for all goroutines to complete 220 | for i := 0; i < 10; i++ { 221 | <-done 222 | } 223 | 224 | // height should be the maximum value 225 | require.Equal(t, uint64(10), coord.Height()) 226 | } 227 | 228 | func TestNackResponse(t *testing.T) { 229 | testErr := fmt.Errorf("test error") 230 | testMsg := "test message" 231 | 232 | resp, err := nackResponse(testErr, testMsg) 233 | 234 | require.NotNil(t, resp) 235 | require.Equal(t, dto.ResponseTypeNack, resp.Type) 236 | require.Error(t, err) 237 | require.Contains(t, err.Error(), testMsg) 238 | require.Contains(t, err.Error(), "test error") 239 | } 240 | 241 | func TestCoordinator_Config_Validation(t *testing.T) { 242 | ctrl := gomock.NewController(t) 243 | defer ctrl.Finish() 244 | 245 | mockWAL := mocks.NewMockwal(ctrl) 246 | mockStore := mocks.NewMockStateStore(ctrl) 247 | 248 | // test with empty cohorts list 249 | conf := &config.Config{ 250 | Nodeaddr: "localhost:8080", 251 | Role: "coordinator", 252 | Cohorts: []string{}, 253 | CommitType: server.TWO_PHASE, 254 | } 255 | 256 | coord, err := New(conf, mockWAL, mockStore) 257 | require.NoError(t, err) 258 | require.NotNil(t, coord) 259 | require.Len(t, coord.cohorts, 0) 260 | 261 | // test with multiple cohorts 262 | conf.Cohorts = []string{"localhost:8081", "localhost:8082", "localhost:8083"} 263 | coord, err = New(conf, mockWAL, mockStore) 264 | require.NoError(t, err) 265 | require.NotNil(t, coord) 266 | require.Len(t, coord.cohorts, 3) 267 | } 268 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= 2 | github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/dgraph-io/badger/v4 v4.8.0 h1:JYph1ChBijCw8SLeybvPINizbDKWZ5n/GYbz2yhN/bs= 6 | github.com/dgraph-io/badger/v4 v4.8.0/go.mod h1:U6on6e8k/RTbUWxqKR0MvugJuVmkxSNc79ap4917h4w= 7 | github.com/dgraph-io/ristretto/v2 v2.2.0 h1:bkY3XzJcXoMuELV8F+vS8kzNgicwQFAaGINAEJdWGOM= 8 | github.com/dgraph-io/ristretto/v2 v2.2.0/go.mod h1:RZrm63UmcBAaYWC1DotLYBmTvgkrs0+XhBd7Npn7/zI= 9 | github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38= 10 | github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= 11 | github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= 12 | github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= 13 | github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 14 | github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= 15 | github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= 16 | github.com/go-logr/stdr v1.2.2 
h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 17 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= 18 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= 19 | github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= 20 | github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= 21 | github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q= 22 | github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= 23 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 24 | github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= 25 | github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= 26 | github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= 27 | github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= 28 | github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= 29 | github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= 30 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 31 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 32 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 33 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 34 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 35 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 36 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 37 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 38 | github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= 39 | github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= 40 | github.com/sirupsen/logrus v1.2.0 h1:juTguoYk5qI21pwyTXY3B3Y5cOTH3ZUyZCg1v/mihuo= 41 | github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= 42 | github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 43 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 44 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 45 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 46 | github.com/vadiminshakov/gowal v0.0.4 h1:99mJRkHL7GFNo7f6TMXkUEkRdeVeNIpmywkbVnPaPMM= 47 | github.com/vadiminshakov/gowal v0.0.4/go.mod h1:NMvNH0xjRPYSrcaIg/JhqB5uTneFdgo45+Fs2sk+Esc= 48 | github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= 49 | github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= 50 | github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= 51 | github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= 52 | go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= 53 | go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= 54 | go.opentelemetry.io/otel 
v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= 55 | go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= 56 | go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= 57 | go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= 58 | go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= 59 | go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= 60 | go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= 61 | go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= 62 | golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= 63 | golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= 64 | golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= 65 | golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= 66 | golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= 67 | golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 68 | golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= 69 | golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 70 | golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= 71 | golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= 72 | golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= 73 | golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= 74 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 75 | google.golang.org/genproto v0.0.0-20221010155953-15ba04fc1c0e h1:halCgTFuLWDRD61piiNSxPsARANGD3Xl16hPrLgLiIg= 76 | google.golang.org/genproto v0.0.0-20221010155953-15ba04fc1c0e/go.mod h1:3526vdqwhZAwq4wsRUaVG555sVgsNmIjRtO7t/JH29U= 77 | google.golang.org/grpc v1.50.0 h1:fPVVDxY9w++VjTZsYvXWqEf9Rqar/e+9zYfxKK+W+YU= 78 | google.golang.org/grpc v1.50.0/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI= 79 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= 80 | google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= 81 | google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= 82 | google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= 83 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 84 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 85 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 86 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 87 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 88 | -------------------------------------------------------------------------------- /io/gateway/grpc/server/server.go: -------------------------------------------------------------------------------- 1 | // Package server provides gRPC server implementation for the committer service. 
2 | // 3 | // This package implements both internal commit API for node-to-node communication 4 | // and client API for external interactions with the distributed consensus system. 5 | package server 6 | 7 | import ( 8 | "context" 9 | "errors" 10 | "net" 11 | "time" 12 | 13 | log "github.com/sirupsen/logrus" 14 | "github.com/vadiminshakov/committer/config" 15 | "github.com/vadiminshakov/committer/core/dto" 16 | "github.com/vadiminshakov/committer/io/gateway/grpc/proto" 17 | "github.com/vadiminshakov/committer/io/store" 18 | "google.golang.org/grpc" 19 | "google.golang.org/grpc/codes" 20 | "google.golang.org/grpc/status" 21 | "google.golang.org/protobuf/types/known/emptypb" 22 | ) 23 | 24 | const ( 25 | // TWO_PHASE represents the two-phase commit protocol. 26 | TWO_PHASE = "two-phase" 27 | // THREE_PHASE represents the three-phase commit protocol. 28 | THREE_PHASE = "three-phase" 29 | ) 30 | 31 | // Coordinator defines the interface for coordinator operations. 32 | type Coordinator interface { 33 | Broadcast(ctx context.Context, req dto.BroadcastRequest) (*dto.BroadcastResponse, error) 34 | Height() uint64 35 | SetHeight(height uint64) 36 | } 37 | 38 | // Cohort defines the interface for cohort operations. 39 | // 40 | //go:generate mockgen -destination=../../../../mocks/mock_cohort.go -package=mocks . Cohort 41 | type Cohort interface { 42 | Propose(ctx context.Context, req *dto.ProposeRequest) (*dto.CohortResponse, error) 43 | Precommit(ctx context.Context, index uint64) (*dto.CohortResponse, error) 44 | Commit(ctx context.Context, in *dto.CommitRequest) (*dto.CohortResponse, error) 45 | Abort(ctx context.Context, req *dto.AbortRequest) (*dto.CohortResponse, error) 46 | Height() uint64 47 | } 48 | 49 | // Server holds server instance, node config and connections to followers (if it's a coordinator node). 
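// Depending on Config.Role, only one of cohort or coordinator is populated; RPCs belonging to the missing role return a FailedPrecondition error.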
50 | type Server struct { 51 | proto.UnimplementedInternalCommitAPIServer 52 | proto.UnimplementedClientAPIServer 53 | 54 | cohort Cohort // Cohort implementation for this node 55 | store *store.Store // Persistent storage 56 | coordinator Coordinator // Coordinator implementation (if this node is a coordinator) 57 | GRPCServer *grpc.Server // gRPC server instance 58 | Config *config.Config // Node configuration 59 | ProposeHook func(req *proto.ProposeRequest) bool // Hook for propose phase 60 | CommitHook func(req *proto.CommitRequest) bool // Hook for commit phase 61 | Addr string // Server address 62 | } 63 | 64 | func (s *Server) Propose(ctx context.Context, req *proto.ProposeRequest) (*proto.Response, error) { 65 | if s.cohort == nil { 66 | return nil, status.Error(codes.FailedPrecondition, "cohort role not enabled on this node") 67 | } 68 | resp, err := s.cohort.Propose(ctx, proposeRequestPbToEntity(req)) 69 | return cohortResponseToProto(resp), err 70 | } 71 | 72 | func (s *Server) Precommit(ctx context.Context, req *proto.PrecommitRequest) (*proto.Response, error) { 73 | if s.cohort == nil { 74 | return nil, status.Error(codes.FailedPrecondition, "cohort role not enabled on this node") 75 | } 76 | resp, err := s.cohort.Precommit(ctx, req.Index) 77 | return cohortResponseToProto(resp), err 78 | } 79 | 80 | func (s *Server) Commit(ctx context.Context, req *proto.CommitRequest) (*proto.Response, error) { 81 | if s.cohort == nil { 82 | return nil, status.Error(codes.FailedPrecondition, "cohort role not enabled on this node") 83 | } 84 | resp, err := s.cohort.Commit(ctx, commitRequestPbToEntity(req)) 85 | return cohortResponseToProto(resp), err 86 | } 87 | 88 | func (s *Server) Abort(ctx context.Context, req *proto.AbortRequest) (*proto.Response, error) { 89 | if s.cohort == nil { 90 | return nil, status.Error(codes.FailedPrecondition, "cohort role not enabled on this node") 91 | } 92 | abortReq := &dto.AbortRequest{ 93 | Height: req.Height, 94 | Reason: req.Reason, 95 | } 96 | resp, err := s.cohort.Abort(ctx, abortReq) 97 | return cohortResponseToProto(resp), err 98 | } 99 | 100 | func (s *Server) Get(ctx context.Context, req *proto.Msg) (*proto.Value, error) { 101 | value, err := s.store.Get(req.Key) 102 | if err != nil { 103 | return nil, err 104 | } 105 | return &proto.Value{Value: value}, nil 106 | } 107 | 108 | // Put initiates a distributed transaction to store a key-value pair. 109 | func (s *Server) Put(ctx context.Context, req *proto.Entry) (*proto.Response, error) { 110 | if s.coordinator == nil { 111 | return nil, status.Error(codes.FailedPrecondition, "coordinator role not enabled on this node") 112 | } 113 | resp, err := s.coordinator.Broadcast(ctx, dto.BroadcastRequest{ 114 | Key: req.Key, 115 | Value: req.Value, 116 | }) 117 | if err != nil { 118 | return nil, err 119 | } 120 | 121 | return &proto.Response{ 122 | Type: proto.Type(resp.Type), 123 | Index: resp.Height, 124 | }, nil 125 | } 126 | 127 | // NodeInfo returns information about the current node. 128 | func (s *Server) NodeInfo(ctx context.Context, req *emptypb.Empty) (*proto.Info, error) { 129 | switch { 130 | case s.cohort != nil: 131 | return &proto.Info{Height: s.cohort.Height()}, nil 132 | case s.coordinator != nil: 133 | return &proto.Info{Height: s.coordinator.Height()}, nil 134 | default: 135 | return nil, status.Error(codes.FailedPrecondition, "node has neither cohort nor coordinator role configured") 136 | } 137 | } 138 | 139 | // New creates a new Server instance with the specified configuration. 
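// It fails if the store is nil or if the implementation required by conf.Role is missing.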
152 | func New(conf *config.Config, cohort Cohort, coordinator Coordinator, stateStore *store.Store) (*Server, error) {
153 | 	log.SetFormatter(&log.TextFormatter{
154 | 		ForceColors:     true, // automatic color detection doesn't seem to work on Windows terminals
155 | 		FullTimestamp:   true,
156 | 		TimestampFormat: time.RFC822,
157 | 	})
158 | 
159 | 	server := &Server{
160 | 		Addr:        conf.Nodeaddr,
161 | 		cohort:      cohort,
162 | 		coordinator: coordinator,
163 | 		store:       stateStore,
164 | 		Config:      conf,
165 | 	}
166 | 
167 | 	if server.Config.CommitType == TWO_PHASE {
168 | 		log.Info("two-phase-commit mode enabled")
169 | 	} else {
170 | 		log.Info("three-phase-commit mode enabled")
171 | 	}
172 | 	return server, checkServerFields(server)
173 | }
174 | 
175 | func checkServerFields(server *Server) error {
176 | 	if server.store == nil {
177 | 		return errors.New("store is not configured")
178 | 	}
179 | 	if server.Config.Role == "cohort" && server.cohort == nil {
180 | 		return errors.New("cohort role selected but cohort implementation is nil")
181 | 	}
182 | 	if server.Config.Role == "coordinator" && server.coordinator == nil {
183 | 		return errors.New("coordinator role selected but coordinator implementation is nil")
184 | 	}
185 | 	return nil
186 | }
187 | 
188 | // Run starts the gRPC server in a non-blocking manner.
189 | func (s *Server) Run(opts ...grpc.UnaryServerInterceptor) {
190 | 	s.GRPCServer = grpc.NewServer(grpc.ChainUnaryInterceptor(opts...))
191 | 	if s.cohort != nil {
192 | 		proto.RegisterInternalCommitAPIServer(s.GRPCServer, s)
193 | 	}
194 | 	proto.RegisterClientAPIServer(s.GRPCServer, s)
195 | 
196 | 	l, err := net.Listen("tcp", s.Addr)
197 | 	if err != nil {
198 | 		log.Fatalf("failed to listen: %v", err)
199 | 	}
200 | 	log.Infof("listening on tcp://%s", s.Addr)
201 | 
202 | 	go func() {
203 | 		if err := s.GRPCServer.Serve(l); err != nil {
204 | 			log.Errorf("gRPC server stopped serving: %v", err)
205 | 		}
206 | 	}()
207 | }
208 | 
209 | // Stop gracefully stops the gRPC server.
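210 | // It also closes the node's persistent store, so the node is fully shut down once it returns.
211 | // A typical pairing (sketch, not enforced by the API): srv.Run(server.WhiteListChecker) at startup, defer srv.Stop() on exit.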
212 | func (s *Server) Stop() {
213 | 	log.Info("stopping server")
214 | 	s.GRPCServer.GracefulStop()
215 | 	if s.store != nil {
216 | 		if err := s.store.Close(); err != nil {
217 | 			log.Errorf("failed to close store: %s", err)
218 | 		}
219 | 	}
220 | 	log.Info("server stopped")
221 | }
222 | 
-------------------------------------------------------------------------------- /main_test.go: --------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"os"
7 | 	"path/filepath"
8 | 	"strconv"
9 | 	"testing"
10 | 	"time"
11 | 
12 | 	log "github.com/sirupsen/logrus"
13 | 	"github.com/stretchr/testify/require"
14 | 	"github.com/vadiminshakov/committer/config"
15 | 	"github.com/vadiminshakov/committer/core/cohort"
16 | 	"github.com/vadiminshakov/committer/core/cohort/commitalgo"
17 | 	"github.com/vadiminshakov/committer/core/coordinator"
18 | 	"github.com/vadiminshakov/committer/io/gateway/grpc/client"
19 | 	pb "github.com/vadiminshakov/committer/io/gateway/grpc/proto"
20 | 	"github.com/vadiminshakov/committer/io/gateway/grpc/server"
21 | 	"github.com/vadiminshakov/committer/io/store"
22 | 	"github.com/vadiminshakov/gowal"
23 | )
24 | 
25 | const (
26 | 	COORDINATOR_TYPE = "coordinator"
27 | 	COHORT_TYPE      = "cohort"
28 | 	BADGER_DIR       = "/tmp/badger"
29 | )
30 | 
31 | var (
32 | 	whitelist = []string{"127.0.0.1"}
33 | 	nodes     = map[string][]*config.Config{
34 | 		COORDINATOR_TYPE: {
35 | 			{Nodeaddr: "localhost:2938", Role: "coordinator",
36 | 				Cohorts: []string{"localhost:2345", "localhost:2384", "localhost:7532", "localhost:5743", "localhost:4991"},
37 | 				Whitelist: whitelist, CommitType: "two-phase", Timeout: 100},
38 | 			{Nodeaddr: "localhost:5002", Role: "coordinator",
39 | 				Cohorts: []string{"localhost:2345", "localhost:2384", "localhost:7532", "localhost:5743", "localhost:4991"},
40 | 				Whitelist: whitelist, CommitType: "three-phase", Timeout: 100},
41 | 		},
42 | 		COHORT_TYPE: {
43 | 			&config.Config{Nodeaddr: "localhost:2345", Role: "cohort", Coordinator: "localhost:2938", Whitelist: whitelist, Timeout: 800, CommitType: "three-phase"},
44 | 			&config.Config{Nodeaddr: "localhost:2384", Role: "cohort", Coordinator: "localhost:2938", Whitelist: whitelist, Timeout: 800, CommitType: "three-phase"},
45 | 			&config.Config{Nodeaddr: "localhost:7532", Role: "cohort", Coordinator: "localhost:2938", Whitelist: whitelist, Timeout: 800, CommitType: "three-phase"},
46 | 			&config.Config{Nodeaddr: "localhost:5743", Role: "cohort", Coordinator: "localhost:2938", Whitelist: whitelist, Timeout: 800, CommitType: "three-phase"},
47 | 			&config.Config{Nodeaddr: "localhost:4991", Role: "cohort", Coordinator: "localhost:2938", Whitelist: whitelist, Timeout: 800, CommitType: "three-phase"},
48 | 		},
49 | 	}
50 | )
51 | 
52 | var testtable = map[string][]byte{
53 | 	"key1": []byte("value1"),
54 | 	"key2": []byte("value2"),
55 | 	"key3": []byte("value3"),
56 | }
57 | 
58 | func TestHappyPath(t *testing.T) {
59 | 	log.SetLevel(log.InfoLevel)
60 | 
61 | 	var canceller func() error
62 | 
63 | 	var height uint64 = 0
64 | 	coordConfig := nodes[COORDINATOR_TYPE][0]
65 | 	if coordConfig.CommitType == "two-phase" {
66 | 		canceller = startnodes(pb.CommitType_TWO_PHASE_COMMIT)
67 | 		log.Println("***\nTEST IN TWO-PHASE MODE\n***")
68 | 	} else {
69 | 		canceller = startnodes(pb.CommitType_THREE_PHASE_COMMIT)
70 | 		log.Println("***\nTEST IN THREE-PHASE MODE\n***")
71 | 	}
72 | 
73 | 	defer canceller()
74 | 
75 | 	c, err := client.NewClientAPI(coordConfig.Nodeaddr)
76 | 	if err != nil {
77 | 		t.Fatal(err)
78 | 	}
79 | 
80 | 	for key, val := range testtable {
81 | 		resp, err := c.Put(context.Background(), key, val)
82 | 		if err != nil {
83 | 			t.Fatal(err)
84 | 		}
85 | 
86 | 		if resp.Type != pb.Type_ACK {
87 | 			t.Error("msg is not acknowledged")
88 | 		}
89 | 		// value is committed, so increment the expected height
90 | 		height++
91 | 	}
92 | 
93 | 	// give cohorts time to apply the committed values
94 | 	time.Sleep(1 * time.Second)
95 | 
96 | 	// connect to cohorts and check that they stored the key-value pairs
97 | 	for _, node := range nodes[COHORT_TYPE] {
98 | 		cli, err := client.NewClientAPI(node.Nodeaddr)
99 | 		require.NoError(t, err, "err not nil")
100 | 
101 | 		for key, val := range testtable {
102 | 			// check values added by nodes
103 | 			resp, err := cli.Get(context.Background(), key)
104 | 			require.NoError(t, err, "err not nil")
105 | 			require.Equal(t, resp.Value, val)
106 | 
107 | 			// check height of node
108 | 			nodeInfo, err := cli.NodeInfo(context.Background())
109 | 			require.NoError(t, err, "err not nil")
110 | 			require.Equal(t, height, nodeInfo.Height, "node %s has unexpected height %d (expected %d)", node.Nodeaddr, nodeInfo.Height, height)
111 | 		}
112 | 	}
113 | 
114 | 	require.NoError(t, canceller())
115 | }
116 | 
117 | func startnodes(commitType pb.CommitType) func() error {
118 | 	COORDINATOR_BADGER := fmt.Sprintf("%s%s%d", BADGER_DIR, "coordinator", time.Now().UnixNano())
119 | 	COHORT_BADGER := fmt.Sprintf("%s%s%d", BADGER_DIR, "cohort", time.Now().UnixNano())
120 | 
121 | 	// remove stale data dirs left over from previous runs
122 | 	if _, err := os.Stat(COORDINATOR_BADGER); !os.IsNotExist(err) {
123 | 		// del dir
124 | 		err := os.RemoveAll(COORDINATOR_BADGER)
125 | 		failfast(err)
126 | 	}
127 | 	if _, err := os.Stat(COHORT_BADGER); !os.IsNotExist(err) {
128 | 		// del dir
129 | 		failfast(os.RemoveAll(COHORT_BADGER))
130 | 	}
131 | 	if _, err := os.Stat("./tmp"); !os.IsNotExist(err) {
132 | 		// del dir
133 | 		failfast(os.RemoveAll("./tmp"))
134 | 
135 | 	}
136 | 
137 | 	{
138 | 		// nolint:errcheck
139 | 		failfast(os.Mkdir(COORDINATOR_BADGER, os.FileMode(0777)))
140 | 		failfast(os.Mkdir(COHORT_BADGER, os.FileMode(0777)))
141 | 		failfast(os.Mkdir("./tmp", os.FileMode(0777)))
142 | 		failfast(os.Mkdir("./tmp/cohort", os.FileMode(0777)))
143 | 		failfast(os.Mkdir("./tmp/coord", os.FileMode(0777)))
144 | 	}
145 | 
146 | 	// start cohorts
147 | 	stopfuncs := make([]func(), 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE]))
148 | 	for i, node := range nodes[COHORT_TYPE] {
149 | 		if commitType == pb.CommitType_THREE_PHASE_COMMIT {
150 | 			node.Coordinator = nodes[COORDINATOR_TYPE][1].Nodeaddr
151 | 		}
152 | 		node.DBPath = filepath.Join(COHORT_BADGER, strconv.Itoa(i))
153 | 
154 | 		failfast(os.MkdirAll(node.DBPath, os.FileMode(0o777)))
155 | 		walConfig := gowal.Config{
156 | 			Dir:              "./tmp/cohort/" + strconv.Itoa(i),
157 | 			Prefix:           "msgs_",
158 | 			SegmentThreshold: 100,
159 | 			MaxSegments:      100,
160 | 			IsInSyncDiskMode: false,
161 | 		}
162 | 		w, err := gowal.NewWAL(walConfig)
163 | 		failfast(err)
164 | 
165 | 		stateStore, recovery, err := store.New(w, node.DBPath)
166 | 		failfast(err)
167 | 
168 | 		ct := server.TWO_PHASE
169 | 		if commitType == pb.CommitType_THREE_PHASE_COMMIT {
170 | 			ct = server.THREE_PHASE
171 | 		}
172 | 
173 | 		committer := commitalgo.NewCommitter(stateStore, ct, w, node.Timeout)
174 | 		committer.SetHeight(recovery.Height)
175 | 		cohortImpl := cohort.NewCohort(committer, cohort.Mode(node.CommitType))
176 | 
177 | 		cohortServer, err := server.New(node, cohortImpl, nil, stateStore)
178 | 		failfast(err)
179 | 
180 | 		go cohortServer.Run(server.WhiteListChecker)
181 | 
182 | 		stopfuncs = append(stopfuncs, cohortServer.Stop)
183 | 
} 184 | 185 | // start coordinators (in two- and three-phase modes) 186 | for i, coordConfig := range nodes[COORDINATOR_TYPE] { 187 | coordConfig.DBPath = filepath.Join(COORDINATOR_BADGER, strconv.Itoa(i)) 188 | failfast(os.MkdirAll(coordConfig.DBPath, os.FileMode(0o777))) 189 | walConfig := gowal.Config{ 190 | Dir: "./tmp/coord/msgs" + strconv.Itoa(i), 191 | Prefix: "msgs", 192 | SegmentThreshold: 100, 193 | MaxSegments: 100, 194 | IsInSyncDiskMode: false, 195 | } 196 | 197 | c, err := gowal.NewWAL(walConfig) 198 | failfast(err) 199 | 200 | stateStore, recovery, err := store.New(c, coordConfig.DBPath) 201 | failfast(err) 202 | 203 | coord, err := coordinator.New(coordConfig, c, stateStore) 204 | failfast(err) 205 | coord.SetHeight(recovery.Height) 206 | 207 | coordServer, err := server.New(coordConfig, nil, coord, stateStore) 208 | failfast(err) 209 | 210 | go coordServer.Run(server.WhiteListChecker) 211 | time.Sleep(100 * time.Millisecond) 212 | stopfuncs = append(stopfuncs, coordServer.Stop) 213 | } 214 | 215 | return func() error { 216 | for _, f := range stopfuncs { 217 | f() 218 | } 219 | failfast(os.RemoveAll("./tmp")) 220 | return os.RemoveAll(BADGER_DIR) 221 | } 222 | } 223 | 224 | func failfast(err error) { 225 | if err != nil { 226 | panic(err) 227 | } 228 | } 229 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS -------------------------------------------------------------------------------- /core/coordinator/coordinator.go: -------------------------------------------------------------------------------- 1 | // Package coordinator implements the coordinator role in distributed atomic commit protocols. 2 | // The coordinator is responsible for managing the atomic commit process by 3 | // sending prepare requests to cohorts and collecting their responses. 4 | package coordinator 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "strings" 10 | "sync" 11 | "sync/atomic" 12 | 13 | "github.com/pkg/errors" 14 | log "github.com/sirupsen/logrus" 15 | "github.com/vadiminshakov/committer/config" 16 | "github.com/vadiminshakov/committer/core/dto" 17 | "github.com/vadiminshakov/committer/core/walrecord" 18 | "github.com/vadiminshakov/committer/io/gateway/grpc/client" 19 | pb "github.com/vadiminshakov/committer/io/gateway/grpc/proto" 20 | "github.com/vadiminshakov/committer/io/gateway/grpc/server" 21 | "google.golang.org/grpc/codes" 22 | "google.golang.org/grpc/status" 23 | ) 24 | 25 | //go:generate mockgen -destination=../../mocks/mock_wal.go -package=mocks . wal 26 | type wal interface { 27 | Write(index uint64, key string, value []byte) error 28 | WriteTombstone(index uint64) error 29 | Get(index uint64) (string, []byte, error) 30 | Close() error 31 | } 32 | 33 | // StateStore defines the interface for persistent state storage. 34 | // 35 | //go:generate mockgen -destination=../../mocks/mock_state_store.go -package=mocks . 
StateStore 36 | type StateStore interface { 37 | Put(key string, value []byte) error 38 | Close() error 39 | } 40 | 41 | type coordinator struct { 42 | wal wal 43 | store StateStore 44 | cohorts map[string]*client.InternalCommitClient 45 | config *config.Config 46 | commitType pb.CommitType 47 | threePhase bool 48 | height uint64 49 | mu sync.Mutex // serialize broadcast to keep height/WAL alignment 50 | } 51 | 52 | // New creates a new coordinator instance with the specified configuration. 53 | func New(conf *config.Config, wal wal, store StateStore) (*coordinator, error) { 54 | cohorts := make(map[string]*client.InternalCommitClient, len(conf.Cohorts)) 55 | for _, f := range conf.Cohorts { 56 | cl, err := client.NewInternalClient(f) 57 | if err != nil { 58 | return nil, err 59 | } 60 | 61 | cohorts[f] = cl 62 | } 63 | 64 | threePhase := conf.CommitType == server.THREE_PHASE 65 | commitType := pb.CommitType_TWO_PHASE_COMMIT 66 | if threePhase { 67 | commitType = pb.CommitType_THREE_PHASE_COMMIT 68 | } 69 | 70 | return &coordinator{ 71 | wal: wal, 72 | store: store, 73 | cohorts: cohorts, 74 | config: conf, 75 | commitType: commitType, 76 | threePhase: threePhase, 77 | }, nil 78 | } 79 | 80 | // Broadcast executes the complete distributed consensus algorithm (2PC or 3PC) for a transaction. 81 | // It runs through all phases: propose, precommit (if 3PC), commit, and persistence. 82 | func (c *coordinator) Broadcast(ctx context.Context, req dto.BroadcastRequest) (*dto.BroadcastResponse, error) { 83 | c.mu.Lock() 84 | defer c.mu.Unlock() 85 | 86 | log.Infof("Proposing key %s", req.Key) 87 | if err := c.propose(ctx, req); err != nil { 88 | return nackResponse(err, "failed to send propose") 89 | } 90 | 91 | if c.threePhase { 92 | log.Infof("Precommitting key %s", req.Key) 93 | if err := c.preCommit(ctx); err != nil { 94 | return nackResponse(err, "failed to send precommit") 95 | } 96 | } 97 | 98 | log.Infof("Committing key %s", req.Key) 99 | if err := c.commit(ctx); err != nil { 100 | s, ok := status.FromError(err) 101 | if !ok { 102 | return &dto.BroadcastResponse{Type: dto.ResponseTypeNack}, fmt.Errorf("failed to extract grpc status code from err: %s", err) 103 | } 104 | if s.Code() == codes.AlreadyExists { 105 | return &dto.BroadcastResponse{Type: dto.ResponseTypeNack}, nil 106 | } 107 | return nackResponse(err, "failed to send commit") 108 | } 109 | 110 | log.Infof("coordinator committed key %s", req.Key) 111 | 112 | newHeight := atomic.AddUint64(&c.height, 1) 113 | return &dto.BroadcastResponse{Type: dto.ResponseTypeAck, Height: newHeight}, nil 114 | } 115 | 116 | func (c *coordinator) propose(ctx context.Context, req dto.BroadcastRequest) error { 117 | for name, cohort := range c.cohorts { 118 | if err := c.sendProposal(ctx, cohort, name, req, c.commitType); err != nil { 119 | return err 120 | } 121 | } 122 | 123 | currentHeight := atomic.LoadUint64(&c.height) 124 | 125 | // write Prepared to WAL to persist payload 126 | ptx := walrecord.WalTx{Key: req.Key, Value: req.Value} 127 | pbBytes, err := walrecord.Encode(ptx) 128 | if err != nil { 129 | return err 130 | } 131 | return c.wal.Write(walrecord.PreparedSlot(currentHeight), walrecord.KeyPrepared, pbBytes) 132 | } 133 | 134 | func (c *coordinator) sendProposal(ctx context.Context, cohort *client.InternalCommitClient, name string, req dto.BroadcastRequest, commitType pb.CommitType) error { 135 | var ( 136 | resp *pb.Response 137 | err error 138 | ) 139 | 140 | for { 141 | currentHeight := atomic.LoadUint64(&c.height) 142 | resp, err = 
cohort.Propose(ctx, &pb.ProposeRequest{ 143 | Key: req.Key, 144 | Value: req.Value, 145 | CommitType: commitType, 146 | Index: currentHeight, 147 | }) 148 | 149 | if err == nil && resp != nil && resp.Type == pb.Type_ACK { 150 | break // success 151 | } 152 | 153 | // if cohort has bigger height, update coordinator's height and retry 154 | if resp != nil && resp.Index > currentHeight { 155 | c.syncHeight(resp.Index) 156 | continue 157 | } 158 | if err != nil { 159 | // send abort to all cohorts on error 160 | c.abort(ctx, fmt.Sprintf("node %s rejected proposed msg: %v", name, err)) 161 | return fmt.Errorf("node %s rejected proposed msg: %w", name, err) 162 | } 163 | 164 | // send abort to all cohorts on NACK 165 | c.abort(ctx, fmt.Sprintf("cohort %s sent NACK for propose", name)) 166 | return fmt.Errorf("cohort %s not acknowledged msg %v", name, req) 167 | } 168 | return nil 169 | } 170 | 171 | func (c *coordinator) preCommit(ctx context.Context) error { 172 | currentHeight := atomic.LoadUint64(&c.height) 173 | for name, cohort := range c.cohorts { 174 | resp, err := cohort.Precommit(ctx, &pb.PrecommitRequest{Index: currentHeight}) 175 | if err != nil { 176 | c.abort(ctx, fmt.Sprintf("cohort %s precommit error: %v", name, err)) 177 | return status.Error(codes.FailedPrecondition, "cohort not acknowledged msg") 178 | } 179 | if !isAck(resp) { 180 | c.abort(ctx, fmt.Sprintf("cohort %s sent NACK for precommit", name)) 181 | return status.Error(codes.FailedPrecondition, "cohort not acknowledged msg") 182 | } 183 | } 184 | 185 | return nil 186 | } 187 | 188 | func (c *coordinator) commit(ctx context.Context) error { 189 | currentHeight := atomic.LoadUint64(&c.height) 190 | var errs []string 191 | for name, cohort := range c.cohorts { 192 | resp, err := cohort.Commit(ctx, &pb.CommitRequest{Index: currentHeight}) 193 | if err != nil { 194 | errs = append(errs, fmt.Sprintf("%s: %v", name, err)) 195 | continue 196 | } 197 | if !isAck(resp) { 198 | errs = append(errs, fmt.Sprintf("%s: NACK", name)) 199 | } 200 | } 201 | 202 | if len(errs) > 0 { 203 | return status.Error(codes.FailedPrecondition, "cohort not acknowledged msg: "+strings.Join(errs, "; ")) 204 | } 205 | 206 | // Persist Decision (Commit) and Apply 207 | // 1. read Payload from Prepared 208 | k, v, err := c.wal.Get(walrecord.PreparedSlot(currentHeight)) 209 | if err != nil { 210 | return status.Errorf(codes.Internal, "failed to read prepared tx: %v", err) 211 | } 212 | if k != walrecord.KeyPrepared { 213 | return status.Errorf(codes.Internal, "expected prepared tx, got %s", k) 214 | } 215 | 216 | // 2. write commit 217 | if err := c.wal.Write(walrecord.CommitSlot(currentHeight), walrecord.KeyCommit, v); err != nil { 218 | return status.Errorf(codes.Internal, "failed to write commit val: %v", err) 219 | } 220 | 221 | // 3. apply 222 | walTx, err := walrecord.Decode(v) 223 | if err != nil { 224 | return status.Errorf(codes.Internal, "failed to decode tx: %v", err) 225 | } 226 | if err := c.store.Put(walTx.Key, walTx.Value); err != nil { 227 | log.Errorf("failed to apply to store: %v", err) 228 | return err // committed but not applied? 
critical error 229 | } 230 | 231 | return nil 232 | } 233 | 234 | func isAck(resp *pb.Response) bool { 235 | return resp != nil && resp.Type == pb.Type_ACK 236 | } 237 | 238 | // syncHeight atomically updates coordinator height to match cohort height if needed 239 | func (c *coordinator) syncHeight(cohortHeight uint64) { 240 | for { 241 | currentHeight := atomic.LoadUint64(&c.height) 242 | if cohortHeight <= currentHeight { 243 | return // height is already up to date 244 | } 245 | 246 | if atomic.CompareAndSwapUint64(&c.height, currentHeight, cohortHeight) { 247 | log.Warnf("Updating coordinator height: %d -> %d", currentHeight, cohortHeight) 248 | return 249 | } 250 | } 251 | } 252 | 253 | // Height returns the current transaction height. 254 | func (c *coordinator) Height() uint64 { 255 | return atomic.LoadUint64(&c.height) 256 | } 257 | 258 | // SetHeight initializes coordinator height during recovery. 259 | func (c *coordinator) SetHeight(height uint64) { 260 | atomic.StoreUint64(&c.height, height) 261 | } 262 | 263 | // abort sends abort requests to all cohorts in a fire-and-forget manner 264 | func (c *coordinator) abort(ctx context.Context, reason string) { 265 | currentHeight := atomic.LoadUint64(&c.height) 266 | log.Warnf("Aborting transaction at height %d: %s", currentHeight, reason) 267 | 268 | if err := c.wal.Write(walrecord.AbortSlot(currentHeight), walrecord.KeyAbort, nil); err != nil { 269 | log.Errorf("Failed to write abort to WAL: %v", err) 270 | } 271 | 272 | for name, cohort := range c.cohorts { 273 | go func(name string, cohort *client.InternalCommitClient) { 274 | if _, err := cohort.Abort(ctx, &dto.AbortRequest{Height: currentHeight, Reason: reason}); err != nil { 275 | log.Errorf("Failed to send abort to cohort %s: %v", name, err) 276 | } 277 | }(name, cohort) 278 | } 279 | } 280 | 281 | func nackResponse(err error, msg string) (*dto.BroadcastResponse, error) { 282 | return &dto.BroadcastResponse{Type: dto.ResponseTypeNack}, errors.Wrap(err, msg) 283 | } 284 | 285 | func (c *coordinator) persistMessage() error { 286 | currentHeight := atomic.LoadUint64(&c.height) 287 | 288 | // read prepared 289 | k, v, err := c.wal.Get(walrecord.PreparedSlot(currentHeight)) 290 | if err != nil { 291 | return status.Error(codes.Internal, fmt.Sprintf("failed to read msg at height %d from wal: %v", currentHeight, err)) 292 | } 293 | if v == nil { 294 | return status.Error(codes.Internal, "can't find msg in wal") 295 | } 296 | if k != walrecord.KeyPrepared { 297 | return status.Error(codes.Internal, fmt.Sprintf("expected prepared key, got %s", k)) 298 | } 299 | 300 | walTx, err := walrecord.Decode(v) 301 | if err != nil { 302 | return status.Error(codes.Internal, "failed to decode tx") 303 | } 304 | 305 | // save 306 | return c.store.Put(walTx.Key, walTx.Value) 307 | } 308 | -------------------------------------------------------------------------------- /io/gateway/grpc/proto/schema_grpc.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 2 | 3 | package proto 4 | 5 | import ( 6 | context "context" 7 | grpc "google.golang.org/grpc" 8 | codes "google.golang.org/grpc/codes" 9 | status "google.golang.org/grpc/status" 10 | emptypb "google.golang.org/protobuf/types/known/emptypb" 11 | ) 12 | 13 | // This is a compile-time assertion to ensure that this generated file 14 | // is compatible with the grpc package it is being compiled against. 15 | // Requires gRPC-Go v1.32.0 or later. 
16 | const _ = grpc.SupportPackageIsVersion7 17 | 18 | // InternalCommitAPIClient is the client API for InternalCommitAPI service. 19 | // 20 | // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 21 | type InternalCommitAPIClient interface { 22 | Propose(ctx context.Context, in *ProposeRequest, opts ...grpc.CallOption) (*Response, error) 23 | Precommit(ctx context.Context, in *PrecommitRequest, opts ...grpc.CallOption) (*Response, error) 24 | Commit(ctx context.Context, in *CommitRequest, opts ...grpc.CallOption) (*Response, error) 25 | Abort(ctx context.Context, in *AbortRequest, opts ...grpc.CallOption) (*Response, error) 26 | } 27 | 28 | type internalCommitAPIClient struct { 29 | cc grpc.ClientConnInterface 30 | } 31 | 32 | func NewInternalCommitAPIClient(cc grpc.ClientConnInterface) InternalCommitAPIClient { 33 | return &internalCommitAPIClient{cc} 34 | } 35 | 36 | func (c *internalCommitAPIClient) Propose(ctx context.Context, in *ProposeRequest, opts ...grpc.CallOption) (*Response, error) { 37 | out := new(Response) 38 | err := c.cc.Invoke(ctx, "/schema.InternalCommitAPI/Propose", in, out, opts...) 39 | if err != nil { 40 | return nil, err 41 | } 42 | return out, nil 43 | } 44 | 45 | func (c *internalCommitAPIClient) Precommit(ctx context.Context, in *PrecommitRequest, opts ...grpc.CallOption) (*Response, error) { 46 | out := new(Response) 47 | err := c.cc.Invoke(ctx, "/schema.InternalCommitAPI/Precommit", in, out, opts...) 48 | if err != nil { 49 | return nil, err 50 | } 51 | return out, nil 52 | } 53 | 54 | func (c *internalCommitAPIClient) Commit(ctx context.Context, in *CommitRequest, opts ...grpc.CallOption) (*Response, error) { 55 | out := new(Response) 56 | err := c.cc.Invoke(ctx, "/schema.InternalCommitAPI/Commit", in, out, opts...) 57 | if err != nil { 58 | return nil, err 59 | } 60 | return out, nil 61 | } 62 | 63 | func (c *internalCommitAPIClient) Abort(ctx context.Context, in *AbortRequest, opts ...grpc.CallOption) (*Response, error) { 64 | out := new(Response) 65 | err := c.cc.Invoke(ctx, "/schema.InternalCommitAPI/Abort", in, out, opts...) 66 | if err != nil { 67 | return nil, err 68 | } 69 | return out, nil 70 | } 71 | 72 | // InternalCommitAPIServer is the server API for InternalCommitAPI service. 73 | // All implementations must embed UnimplementedInternalCommitAPIServer 74 | // for forward compatibility 75 | type InternalCommitAPIServer interface { 76 | Propose(context.Context, *ProposeRequest) (*Response, error) 77 | Precommit(context.Context, *PrecommitRequest) (*Response, error) 78 | Commit(context.Context, *CommitRequest) (*Response, error) 79 | Abort(context.Context, *AbortRequest) (*Response, error) 80 | mustEmbedUnimplementedInternalCommitAPIServer() 81 | } 82 | 83 | // UnimplementedInternalCommitAPIServer must be embedded to have forward compatible implementations. 
84 | type UnimplementedInternalCommitAPIServer struct { 85 | } 86 | 87 | func (UnimplementedInternalCommitAPIServer) Propose(context.Context, *ProposeRequest) (*Response, error) { 88 | return nil, status.Errorf(codes.Unimplemented, "method Propose not implemented") 89 | } 90 | func (UnimplementedInternalCommitAPIServer) Precommit(context.Context, *PrecommitRequest) (*Response, error) { 91 | return nil, status.Errorf(codes.Unimplemented, "method Precommit not implemented") 92 | } 93 | func (UnimplementedInternalCommitAPIServer) Commit(context.Context, *CommitRequest) (*Response, error) { 94 | return nil, status.Errorf(codes.Unimplemented, "method Commit not implemented") 95 | } 96 | func (UnimplementedInternalCommitAPIServer) Abort(context.Context, *AbortRequest) (*Response, error) { 97 | return nil, status.Errorf(codes.Unimplemented, "method Abort not implemented") 98 | } 99 | func (UnimplementedInternalCommitAPIServer) mustEmbedUnimplementedInternalCommitAPIServer() {} 100 | 101 | // UnsafeInternalCommitAPIServer may be embedded to opt out of forward compatibility for this service. 102 | // Use of this interface is not recommended, as added methods to InternalCommitAPIServer will 103 | // result in compilation errors. 104 | type UnsafeInternalCommitAPIServer interface { 105 | mustEmbedUnimplementedInternalCommitAPIServer() 106 | } 107 | 108 | func RegisterInternalCommitAPIServer(s grpc.ServiceRegistrar, srv InternalCommitAPIServer) { 109 | s.RegisterService(&InternalCommitAPI_ServiceDesc, srv) 110 | } 111 | 112 | func _InternalCommitAPI_Propose_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { 113 | in := new(ProposeRequest) 114 | if err := dec(in); err != nil { 115 | return nil, err 116 | } 117 | if interceptor == nil { 118 | return srv.(InternalCommitAPIServer).Propose(ctx, in) 119 | } 120 | info := &grpc.UnaryServerInfo{ 121 | Server: srv, 122 | FullMethod: "/schema.InternalCommitAPI/Propose", 123 | } 124 | handler := func(ctx context.Context, req interface{}) (interface{}, error) { 125 | return srv.(InternalCommitAPIServer).Propose(ctx, req.(*ProposeRequest)) 126 | } 127 | return interceptor(ctx, in, info, handler) 128 | } 129 | 130 | func _InternalCommitAPI_Precommit_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { 131 | in := new(PrecommitRequest) 132 | if err := dec(in); err != nil { 133 | return nil, err 134 | } 135 | if interceptor == nil { 136 | return srv.(InternalCommitAPIServer).Precommit(ctx, in) 137 | } 138 | info := &grpc.UnaryServerInfo{ 139 | Server: srv, 140 | FullMethod: "/schema.InternalCommitAPI/Precommit", 141 | } 142 | handler := func(ctx context.Context, req interface{}) (interface{}, error) { 143 | return srv.(InternalCommitAPIServer).Precommit(ctx, req.(*PrecommitRequest)) 144 | } 145 | return interceptor(ctx, in, info, handler) 146 | } 147 | 148 | func _InternalCommitAPI_Commit_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { 149 | in := new(CommitRequest) 150 | if err := dec(in); err != nil { 151 | return nil, err 152 | } 153 | if interceptor == nil { 154 | return srv.(InternalCommitAPIServer).Commit(ctx, in) 155 | } 156 | info := &grpc.UnaryServerInfo{ 157 | Server: srv, 158 | FullMethod: "/schema.InternalCommitAPI/Commit", 159 | } 160 | handler := func(ctx context.Context, req interface{}) 
(interface{}, error) { 161 | return srv.(InternalCommitAPIServer).Commit(ctx, req.(*CommitRequest)) 162 | } 163 | return interceptor(ctx, in, info, handler) 164 | } 165 | 166 | func _InternalCommitAPI_Abort_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { 167 | in := new(AbortRequest) 168 | if err := dec(in); err != nil { 169 | return nil, err 170 | } 171 | if interceptor == nil { 172 | return srv.(InternalCommitAPIServer).Abort(ctx, in) 173 | } 174 | info := &grpc.UnaryServerInfo{ 175 | Server: srv, 176 | FullMethod: "/schema.InternalCommitAPI/Abort", 177 | } 178 | handler := func(ctx context.Context, req interface{}) (interface{}, error) { 179 | return srv.(InternalCommitAPIServer).Abort(ctx, req.(*AbortRequest)) 180 | } 181 | return interceptor(ctx, in, info, handler) 182 | } 183 | 184 | // InternalCommitAPI_ServiceDesc is the grpc.ServiceDesc for InternalCommitAPI service. 185 | // It's only intended for direct use with grpc.RegisterService, 186 | // and not to be introspected or modified (even as a copy) 187 | var InternalCommitAPI_ServiceDesc = grpc.ServiceDesc{ 188 | ServiceName: "schema.InternalCommitAPI", 189 | HandlerType: (*InternalCommitAPIServer)(nil), 190 | Methods: []grpc.MethodDesc{ 191 | { 192 | MethodName: "Propose", 193 | Handler: _InternalCommitAPI_Propose_Handler, 194 | }, 195 | { 196 | MethodName: "Precommit", 197 | Handler: _InternalCommitAPI_Precommit_Handler, 198 | }, 199 | { 200 | MethodName: "Commit", 201 | Handler: _InternalCommitAPI_Commit_Handler, 202 | }, 203 | { 204 | MethodName: "Abort", 205 | Handler: _InternalCommitAPI_Abort_Handler, 206 | }, 207 | }, 208 | Streams: []grpc.StreamDesc{}, 209 | Metadata: "schema.proto", 210 | } 211 | 212 | // ClientAPIClient is the client API for ClientAPI service. 213 | // 214 | // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 215 | type ClientAPIClient interface { 216 | Put(ctx context.Context, in *Entry, opts ...grpc.CallOption) (*Response, error) 217 | Get(ctx context.Context, in *Msg, opts ...grpc.CallOption) (*Value, error) 218 | NodeInfo(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*Info, error) 219 | } 220 | 221 | type clientAPIClient struct { 222 | cc grpc.ClientConnInterface 223 | } 224 | 225 | func NewClientAPIClient(cc grpc.ClientConnInterface) ClientAPIClient { 226 | return &clientAPIClient{cc} 227 | } 228 | 229 | func (c *clientAPIClient) Put(ctx context.Context, in *Entry, opts ...grpc.CallOption) (*Response, error) { 230 | out := new(Response) 231 | err := c.cc.Invoke(ctx, "/schema.ClientAPI/Put", in, out, opts...) 232 | if err != nil { 233 | return nil, err 234 | } 235 | return out, nil 236 | } 237 | 238 | func (c *clientAPIClient) Get(ctx context.Context, in *Msg, opts ...grpc.CallOption) (*Value, error) { 239 | out := new(Value) 240 | err := c.cc.Invoke(ctx, "/schema.ClientAPI/Get", in, out, opts...) 241 | if err != nil { 242 | return nil, err 243 | } 244 | return out, nil 245 | } 246 | 247 | func (c *clientAPIClient) NodeInfo(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*Info, error) { 248 | out := new(Info) 249 | err := c.cc.Invoke(ctx, "/schema.ClientAPI/NodeInfo", in, out, opts...) 250 | if err != nil { 251 | return nil, err 252 | } 253 | return out, nil 254 | } 255 | 256 | // ClientAPIServer is the server API for ClientAPI service. 
257 | // All implementations must embed UnimplementedClientAPIServer 258 | // for forward compatibility 259 | type ClientAPIServer interface { 260 | Put(context.Context, *Entry) (*Response, error) 261 | Get(context.Context, *Msg) (*Value, error) 262 | NodeInfo(context.Context, *emptypb.Empty) (*Info, error) 263 | mustEmbedUnimplementedClientAPIServer() 264 | } 265 | 266 | // UnimplementedClientAPIServer must be embedded to have forward compatible implementations. 267 | type UnimplementedClientAPIServer struct { 268 | } 269 | 270 | func (UnimplementedClientAPIServer) Put(context.Context, *Entry) (*Response, error) { 271 | return nil, status.Errorf(codes.Unimplemented, "method Put not implemented") 272 | } 273 | func (UnimplementedClientAPIServer) Get(context.Context, *Msg) (*Value, error) { 274 | return nil, status.Errorf(codes.Unimplemented, "method Get not implemented") 275 | } 276 | func (UnimplementedClientAPIServer) NodeInfo(context.Context, *emptypb.Empty) (*Info, error) { 277 | return nil, status.Errorf(codes.Unimplemented, "method NodeInfo not implemented") 278 | } 279 | func (UnimplementedClientAPIServer) mustEmbedUnimplementedClientAPIServer() {} 280 | 281 | // UnsafeClientAPIServer may be embedded to opt out of forward compatibility for this service. 282 | // Use of this interface is not recommended, as added methods to ClientAPIServer will 283 | // result in compilation errors. 284 | type UnsafeClientAPIServer interface { 285 | mustEmbedUnimplementedClientAPIServer() 286 | } 287 | 288 | func RegisterClientAPIServer(s grpc.ServiceRegistrar, srv ClientAPIServer) { 289 | s.RegisterService(&ClientAPI_ServiceDesc, srv) 290 | } 291 | 292 | func _ClientAPI_Put_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { 293 | in := new(Entry) 294 | if err := dec(in); err != nil { 295 | return nil, err 296 | } 297 | if interceptor == nil { 298 | return srv.(ClientAPIServer).Put(ctx, in) 299 | } 300 | info := &grpc.UnaryServerInfo{ 301 | Server: srv, 302 | FullMethod: "/schema.ClientAPI/Put", 303 | } 304 | handler := func(ctx context.Context, req interface{}) (interface{}, error) { 305 | return srv.(ClientAPIServer).Put(ctx, req.(*Entry)) 306 | } 307 | return interceptor(ctx, in, info, handler) 308 | } 309 | 310 | func _ClientAPI_Get_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { 311 | in := new(Msg) 312 | if err := dec(in); err != nil { 313 | return nil, err 314 | } 315 | if interceptor == nil { 316 | return srv.(ClientAPIServer).Get(ctx, in) 317 | } 318 | info := &grpc.UnaryServerInfo{ 319 | Server: srv, 320 | FullMethod: "/schema.ClientAPI/Get", 321 | } 322 | handler := func(ctx context.Context, req interface{}) (interface{}, error) { 323 | return srv.(ClientAPIServer).Get(ctx, req.(*Msg)) 324 | } 325 | return interceptor(ctx, in, info, handler) 326 | } 327 | 328 | func _ClientAPI_NodeInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { 329 | in := new(emptypb.Empty) 330 | if err := dec(in); err != nil { 331 | return nil, err 332 | } 333 | if interceptor == nil { 334 | return srv.(ClientAPIServer).NodeInfo(ctx, in) 335 | } 336 | info := &grpc.UnaryServerInfo{ 337 | Server: srv, 338 | FullMethod: "/schema.ClientAPI/NodeInfo", 339 | } 340 | handler := func(ctx context.Context, req interface{}) (interface{}, error) { 341 | return 
srv.(ClientAPIServer).NodeInfo(ctx, req.(*emptypb.Empty)) 342 | } 343 | return interceptor(ctx, in, info, handler) 344 | } 345 | 346 | // ClientAPI_ServiceDesc is the grpc.ServiceDesc for ClientAPI service. 347 | // It's only intended for direct use with grpc.RegisterService, 348 | // and not to be introspected or modified (even as a copy) 349 | var ClientAPI_ServiceDesc = grpc.ServiceDesc{ 350 | ServiceName: "schema.ClientAPI", 351 | HandlerType: (*ClientAPIServer)(nil), 352 | Methods: []grpc.MethodDesc{ 353 | { 354 | MethodName: "Put", 355 | Handler: _ClientAPI_Put_Handler, 356 | }, 357 | { 358 | MethodName: "Get", 359 | Handler: _ClientAPI_Get_Handler, 360 | }, 361 | { 362 | MethodName: "NodeInfo", 363 | Handler: _ClientAPI_NodeInfo_Handler, 364 | }, 365 | }, 366 | Streams: []grpc.StreamDesc{}, 367 | Metadata: "schema.proto", 368 | } 369 | -------------------------------------------------------------------------------- /core/cohort/commitalgo/commitalgo.go: -------------------------------------------------------------------------------- 1 | // Package commitalgo implements the core commit algorithms for 2PC and 3PC protocols. 2 | // 3 | // This package provides the finite state machine logic and transaction handling 4 | // for cohort nodes participating in distributed consensus. 5 | package commitalgo 6 | 7 | import ( 8 | "context" 9 | "sync" 10 | "sync/atomic" 11 | "time" 12 | 13 | log "github.com/sirupsen/logrus" 14 | "github.com/vadiminshakov/committer/core/cohort/commitalgo/hooks" 15 | "github.com/vadiminshakov/committer/core/dto" 16 | "github.com/vadiminshakov/committer/core/walrecord" 17 | "google.golang.org/grpc/codes" 18 | "google.golang.org/grpc/status" 19 | ) 20 | 21 | // wal defines the interface for write-ahead log operations. 22 | type wal interface { 23 | Write(index uint64, key string, value []byte) error 24 | WriteTombstone(index uint64) error 25 | Get(index uint64) (string, []byte, error) 26 | Close() error 27 | } 28 | 29 | // StateStore defines the interface for state storage. 30 | // 31 | //go:generate mockgen -destination=../../../mocks/mock_commitalgo_state_store.go -package=mocks -mock_names=StateStore=MockCommitalgoStateStore . StateStore 32 | type StateStore interface { 33 | Put(key string, value []byte) error 34 | Close() error 35 | } 36 | 37 | // CommitterImpl implements the commit algorithm with state machine and hooks. 38 | type CommitterImpl struct { 39 | store StateStore 40 | wal wal 41 | hookRegistry *hooks.Registry 42 | state *stateMachine 43 | height uint64 44 | timeout uint64 45 | mu sync.Mutex 46 | } 47 | 48 | // NewCommitter creates a new committer instance with the specified configuration. 49 | func NewCommitter(store StateStore, commitType string, wal wal, timeout uint64, customHooks ...hooks.Hook) *CommitterImpl { 50 | registry := hooks.NewRegistry() 51 | 52 | for _, hook := range customHooks { 53 | registry.Register(hook) 54 | } 55 | 56 | if len(customHooks) == 0 { 57 | registry.Register(hooks.NewDefaultHook()) 58 | } 59 | 60 | return &CommitterImpl{ 61 | hookRegistry: registry, 62 | store: store, 63 | wal: wal, 64 | timeout: timeout, 65 | state: newStateMachine(mode(commitType)), 66 | } 67 | } 68 | 69 | // RegisterHook adds a new hook to the committer. 70 | func (c *CommitterImpl) RegisterHook(hook hooks.Hook) { 71 | c.hookRegistry.Register(hook) 72 | } 73 | 74 | // Height returns the current transaction height. 
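// The height is the index of the transaction slot currently being processed; commit()
// advances it by one only after the commit record is written to the WAL and the value
// is applied to the store.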
75 | func (c *CommitterImpl) Height() uint64 { 76 | return atomic.LoadUint64(&c.height) 77 | } 78 | 79 | // SetHeight initializes committer height from recovered WAL state. 80 | func (c *CommitterImpl) SetHeight(height uint64) { 81 | atomic.StoreUint64(&c.height, height) 82 | } 83 | 84 | func (c *CommitterImpl) getCurrentState() string { 85 | return c.state.getCurrentState() 86 | } 87 | 88 | // Propose handles the propose phase of the commit protocol. 89 | func (c *CommitterImpl) Propose(ctx context.Context, req *dto.ProposeRequest) (*dto.CohortResponse, error) { 90 | c.mu.Lock() 91 | defer c.mu.Unlock() 92 | 93 | currentHeight := atomic.LoadUint64(&c.height) 94 | if currentHeight > req.Height { 95 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeNack, Height: currentHeight}, nil 96 | } 97 | 98 | if !c.hookRegistry.ExecutePropose(req) { 99 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeNack, Height: req.Height}, nil 100 | } 101 | 102 | if err := c.state.Transition(proposeStage); err != nil { 103 | return nil, err 104 | } 105 | 106 | // prepare payload 107 | ptx := walrecord.WalTx{Key: req.Key, Value: req.Value} 108 | pb, err := walrecord.Encode(ptx) 109 | if err != nil { 110 | return nil, status.Errorf(codes.Internal, "encode error: %v", err) 111 | } 112 | 113 | // save 114 | if err := c.wal.Write(walrecord.PreparedSlot(req.Height), walrecord.KeyPrepared, pb); err != nil { 115 | if terr := c.state.Transition(proposeStage); terr != nil { 116 | log.Errorf("failed to reset state to propose after WAL error: %v", terr) 117 | } 118 | return nil, status.Errorf(codes.Internal, "failed to write wal on index %d: %v", req.Height, err) 119 | } 120 | 121 | if c.state.mode == twophase { 122 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck, Height: req.Height}, nil 123 | } 124 | 125 | go c.handleProposeTimeout(req.Height) 126 | 127 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck, Height: req.Height}, nil 128 | } 129 | 130 | func (c *CommitterImpl) handleProposeTimeout(height uint64) { 131 | timer := time.NewTimer(time.Duration(c.timeout) * time.Millisecond) 132 | defer timer.Stop() 133 | <-timer.C 134 | 135 | c.mu.Lock() 136 | defer c.mu.Unlock() 137 | 138 | currentState := c.state.getCurrentState() 139 | currentHeight := atomic.LoadUint64(&c.height) 140 | 141 | log.Debugf("propose timeout handler: state=%s, height=%d, index=%d", currentState, currentHeight, height) 142 | 143 | if currentState != proposeStage || currentHeight != height { 144 | log.Debugf("skipping propose timeout handling for height %d: state=%s, currentHeight=%d", height, currentState, currentHeight) 145 | return 146 | } 147 | 148 | if err := c.wal.Write(walrecord.AbortSlot(height), walrecord.KeyAbort, nil); err != nil { 149 | log.Errorf("failed to write skip record for height %d: %v", height, err) 150 | } else { 151 | log.Warnf("skip proposed message after timeout for height %d", height) 152 | } 153 | } 154 | 155 | // Precommit handles the precommit phase of the three-phase commit protocol. 
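// As implemented below, the request must carry the current height and arrive in the
// propose state; an existing abort record for the slot rejects it. Otherwise the prepared
// payload is copied into the precommit slot and a timer is armed that autocommits the
// transaction if the coordinator's final commit never arrives (the 3PC liveness rule).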
156 | func (c *CommitterImpl) Precommit(ctx context.Context, index uint64) (*dto.CohortResponse, error) { 157 | c.mu.Lock() 158 | defer c.mu.Unlock() 159 | 160 | currentHeight := atomic.LoadUint64(&c.height) 161 | if index != currentHeight { 162 | return nil, status.Errorf(codes.FailedPrecondition, "invalid precommit height: expected %d, got %d", currentHeight, index) 163 | } 164 | 165 | if c.state.getCurrentState() != proposeStage { 166 | return nil, status.Errorf(codes.FailedPrecondition, "precommit allowed only from propose state, current: %s", c.state.getCurrentState()) 167 | } 168 | 169 | // check if already aborted 170 | if k, _, _ := c.wal.Get(walrecord.AbortSlot(index)); k == walrecord.KeyAbort { 171 | return nil, status.Errorf(codes.Aborted, "transaction %d was aborted", index) 172 | } 173 | 174 | // read prepared to ensure we have data 175 | pSlot := walrecord.PreparedSlot(index) 176 | k, val, err := c.wal.Get(pSlot) 177 | if err != nil { 178 | return nil, status.Errorf(codes.Internal, "failed to read prepared slot %d: %v", pSlot, err) 179 | } 180 | if k != walrecord.KeyPrepared { 181 | return nil, status.Errorf(codes.FailedPrecondition, "not prepared (key=%s)", k) 182 | } 183 | 184 | if err := c.wal.Write(walrecord.PrecommitSlot(index), walrecord.KeyPrecommit, val); err != nil { 185 | return nil, status.Errorf(codes.Internal, "failed to write precommit: %v", err) 186 | } 187 | 188 | c.state.Transition(precommitStage) 189 | 190 | go c.handlePrecommitTimeout(index) 191 | 192 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck, Height: index}, nil 193 | } 194 | 195 | func (c *CommitterImpl) handlePrecommitTimeout(height uint64) { 196 | timer := time.NewTimer(time.Duration(c.timeout) * time.Millisecond) 197 | defer timer.Stop() 198 | <-timer.C 199 | 200 | c.mu.Lock() 201 | defer c.mu.Unlock() 202 | 203 | currentState := c.state.getCurrentState() 204 | currentHeight := atomic.LoadUint64(&c.height) 205 | 206 | log.Debugf("precommit timeout handler: state=%s, height=%d, index=%d", currentState, currentHeight, height) 207 | 208 | if currentState != precommitStage || currentHeight != height { 209 | log.Debugf("skipping autocommit for height %d: state=%s, currentHeight=%d", height, currentState, currentHeight) 210 | return 211 | } 212 | 213 | // check abort 214 | if k, _, _ := c.wal.Get(walrecord.AbortSlot(height)); k == walrecord.KeyAbort { 215 | log.Infof("found abort record for height %d during precommit timeout", height) 216 | c.resetToPropose(height, "abort record found") 217 | return 218 | } 219 | 220 | // check data (prepared/precommit) 221 | key, value, err := c.wal.Get(walrecord.PreparedSlot(height)) 222 | if err != nil { 223 | log.Errorf("failed to read WAL for height %d during precommit timeout: %v", height, err) 224 | c.resetToPropose(height, "WAL read error") 225 | return 226 | } 227 | if key != walrecord.KeyPrepared { 228 | log.Warnf("unexpected key at prepared slot: %s", key) 229 | c.resetToPropose(height, "invalid prepared key") 230 | return 231 | } 232 | if value == nil { 233 | log.Errorf("no data found in WAL for height %d during precommit timeout", height) 234 | c.resetToPropose(height, "no WAL data") 235 | return 236 | } 237 | 238 | log.Warnf("performing autocommit after precommit timeout for height %d", height) 239 | 240 | response, err := c.commit(height) 241 | if err != nil { 242 | log.Errorf("autocommit failed for height %d: %v", height, err) 243 | c.resetToPropose(height, "autocommit failed") 244 | return 245 | } 246 | 247 | if response != nil && 
response.ResponseType == dto.ResponseTypeNack { 248 | log.Warnf("autocommit returned NACK for height %d", height) 249 | c.resetToPropose(height, "autocommit NACK") 250 | return 251 | } 252 | 253 | log.Infof("successfully autocommitted height %d after precommit timeout", height) 254 | } 255 | 256 | func (c *CommitterImpl) resetToPropose(height uint64, reason string) { 257 | log.Debugf("resetting state to propose for height %d: %s", height, reason) 258 | 259 | currentState := c.state.getCurrentState() 260 | 261 | if c.state.GetMode() == threephase && currentState == precommitStage { 262 | if err := c.state.Transition(commitStage); err != nil { 263 | log.Errorf("failed to transition to commit state during reset for height %d: %v", height, err) 264 | return 265 | } 266 | } 267 | 268 | if err := c.state.Transition(proposeStage); err != nil { 269 | log.Errorf("failed to reset to propose state for height %d: %v (current: %s)", height, err, c.state.getCurrentState()) 270 | } 271 | } 272 | 273 | // Commit handles the commit phase of the atomic commit protocol. 274 | func (c *CommitterImpl) Commit(ctx context.Context, req *dto.CommitRequest) (*dto.CohortResponse, error) { 275 | c.mu.Lock() 276 | defer c.mu.Unlock() 277 | return c.commit(req.Height) 278 | } 279 | 280 | func (c *CommitterImpl) commit(height uint64) (*dto.CohortResponse, error) { 281 | currentHeight := atomic.LoadUint64(&c.height) 282 | 283 | // idempotent commit: if already applied, return ACK without error 284 | if height < currentHeight { 285 | log.Debugf("commit for height %d already applied (current height: %d)", height, currentHeight) 286 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck, Height: height}, nil 287 | } 288 | 289 | // future height: reject 290 | if height > currentHeight { 291 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeNack, Height: currentHeight}, nil 292 | } 293 | 294 | // height == currentHeight: process the commit 295 | currentState := c.state.getCurrentState() 296 | expectedState := c.getExpectedCommitState() 297 | 298 | if currentState != expectedState { 299 | return nil, status.Errorf(codes.FailedPrecondition, 300 | "invalid state for commit: expected %s for %s mode, but current state is %s", 301 | expectedState, c.state.GetMode(), currentState) 302 | } 303 | 304 | // check if already Aborted 305 | if k, _, _ := c.wal.Get(walrecord.AbortSlot(height)); k == walrecord.KeyAbort { 306 | log.Warnf("rejecting commit for aborted height %d", height) 307 | c.resetToPropose(height, "wal aborted") 308 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeNack}, nil 309 | } 310 | 311 | if err := c.state.Transition(commitStage); err != nil { 312 | return nil, status.Errorf(codes.FailedPrecondition, "invalid state transition to commit: %v", err) 313 | } 314 | 315 | if !c.hookRegistry.ExecuteCommit(&dto.CommitRequest{Height: height}) { 316 | if terr := c.state.Transition(proposeStage); terr != nil { 317 | log.Errorf("failed to reset state after hook failure: %v", terr) 318 | } 319 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeNack}, nil 320 | } 321 | 322 | // retrieve payload (try Precommit, then Prepared) 323 | var payload []byte 324 | if k, v, err := c.wal.Get(walrecord.PrecommitSlot(height)); err == nil && k == walrecord.KeyPrecommit { 325 | payload = v 326 | } else if k, v, err := c.wal.Get(walrecord.PreparedSlot(height)); err == nil && k == walrecord.KeyPrepared { 327 | payload = v 328 | } else { 329 | c.resetToPropose(height, "no prepared/precommit data found") 330 | 
return nil, status.Errorf(codes.FailedPrecondition, "no prepared/precommit data found") 331 | } 332 | 333 | // decode 334 | walTx, err := walrecord.Decode(payload) 335 | if err != nil { 336 | return nil, status.Errorf(codes.Internal, "decode error: %v", err) 337 | } 338 | 339 | // 1. write commit record to WAL 340 | if err := c.wal.Write(walrecord.CommitSlot(height), walrecord.KeyCommit, payload); err != nil { 341 | c.resetToPropose(height, "wal write failed") 342 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeNack}, err 343 | } 344 | 345 | // 2. apply to storage 346 | if err := c.store.Put(walTx.Key, walTx.Value); err != nil { 347 | log.Errorf("CRITICAL: failed to apply committed tx to store: %v", err) 348 | return nil, err 349 | } 350 | 351 | atomic.StoreUint64(&c.height, currentHeight+1) 352 | 353 | if terr := c.state.Transition(proposeStage); terr != nil { 354 | log.Errorf("failed to transition back to propose state after successful commit: %v", terr) 355 | } 356 | 357 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck}, nil 358 | } 359 | 360 | func (c *CommitterImpl) getExpectedCommitState() string { 361 | if c.state.GetMode() == twophase { 362 | return proposeStage 363 | } 364 | return precommitStage 365 | } 366 | 367 | // Abort handles abort requests from coordinator. 368 | func (c *CommitterImpl) Abort(ctx context.Context, req *dto.AbortRequest) (*dto.CohortResponse, error) { 369 | log.Warnf("received abort request for height %d: %s", req.Height, req.Reason) 370 | 371 | c.mu.Lock() 372 | defer c.mu.Unlock() 373 | 374 | currentHeight := atomic.LoadUint64(&c.height) 375 | 376 | if req.Height > currentHeight { 377 | log.Debugf("ignoring abort for future height %d (current: %d)", req.Height, currentHeight) 378 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck}, nil 379 | } 380 | 381 | if req.Height < currentHeight { 382 | log.Debugf("ignoring abort for past height %d (current: %d)", req.Height, currentHeight) 383 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck}, nil 384 | } 385 | 386 | log.Infof("processing abort for current height %d", req.Height) 387 | 388 | if err := c.wal.Write(walrecord.AbortSlot(req.Height), walrecord.KeyAbort, nil); err != nil { 389 | log.Errorf("failed to write abort record for aborted transaction at height %d: %v", req.Height, err) 390 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeNack}, err 391 | } 392 | 393 | log.Infof("successfully wrote tombstone record for height %d", req.Height) 394 | 395 | c.resetToPropose(req.Height, "abort request") 396 | 397 | log.Infof("successfully processed abort for height %d", req.Height) 398 | return &dto.CohortResponse{ResponseType: dto.ResponseTypeAck}, nil 399 | } 400 | -------------------------------------------------------------------------------- /core/cohort/commitalgo/committer_test.go: -------------------------------------------------------------------------------- 1 | package commitalgo 2 | 3 | import ( 4 | "context" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | "github.com/vadiminshakov/committer/core/cohort/commitalgo/hooks" 10 | "github.com/vadiminshakov/committer/core/dto" 11 | "github.com/vadiminshakov/committer/core/walrecord" 12 | "github.com/vadiminshakov/committer/io/store" 13 | "github.com/vadiminshakov/gowal" 14 | ) 15 | 16 | // testHook is a configurable stub whose propose/commit results are set per test 17 | type testHook struct { 18 | proposeResult bool 19 | commitResult bool 20 | proposeCalled bool 21 | commitCalled bool 22 | } 23 | 24 | func (t *testHook) 
OnPropose(req *dto.ProposeRequest) bool { 25 | t.proposeCalled = true 26 | return t.proposeResult 27 | } 28 | 29 | func (t *testHook) OnCommit(req *dto.CommitRequest) bool { 30 | t.commitCalled = true 31 | return t.commitResult 32 | } 33 | 34 | func openTestWAL(t *testing.T, walPath string) *gowal.Wal { 35 | walConfig := gowal.Config{ 36 | Dir: walPath, 37 | Prefix: "test", 38 | SegmentThreshold: 1024, 39 | MaxSegments: 10, 40 | IsInSyncDiskMode: false, 41 | } 42 | 43 | wal, err := gowal.NewWAL(walConfig) 44 | require.NoError(t, err) 45 | t.Cleanup(func() { wal.Close() }) 46 | 47 | return wal 48 | } 49 | 50 | func newStateStore(t *testing.T, wal *gowal.Wal) (*store.Store, *store.RecoveryState) { 51 | dbPath := filepath.Join(t.TempDir(), "badger") 52 | stateStore, recovery, err := store.New(wal, dbPath) 53 | require.NoError(t, err) 54 | t.Cleanup(func() { stateStore.Close() }) 55 | return stateStore, recovery 56 | } 57 | 58 | func prepareCommitter(t *testing.T, walPath, commitType string, timeout uint64, hooks ...hooks.Hook) (*CommitterImpl, *store.Store, *gowal.Wal, *store.RecoveryState) { 59 | wal := openTestWAL(t, walPath) 60 | stateStore, recovery := newStateStore(t, wal) 61 | 62 | committer := NewCommitter(stateStore, commitType, wal, timeout, hooks...) 63 | committer.SetHeight(recovery.Height) 64 | 65 | return committer, stateStore, wal, recovery 66 | } 67 | 68 | func TestNewCommitter_DefaultHook(t *testing.T) { 69 | tempDir := t.TempDir() 70 | committer, _, _, _ := prepareCommitter(t, filepath.Join(tempDir, "wal"), "3pc", 5000) 71 | 72 | require.Equal(t, 1, committer.hookRegistry.Count(), "Expected 1 default hook") 73 | } 74 | 75 | func TestNewCommitter_CustomHooks(t *testing.T) { 76 | tempDir := t.TempDir() 77 | // test: create committer with custom hooks 78 | testHook1 := &testHook{proposeResult: true, commitResult: true} 79 | testHook2 := &testHook{proposeResult: true, commitResult: true} 80 | 81 | committer, _, _, _ := prepareCommitter(t, filepath.Join(tempDir, "wal"), "3pc", 5000, testHook1, testHook2) 82 | 83 | require.Equal(t, 2, committer.hookRegistry.Count(), "Expected 2 custom hooks") 84 | } 85 | 86 | func TestNewCommitter_DynamicHookRegistration(t *testing.T) { 87 | tempDir := t.TempDir() 88 | // test: create committer and add hooks dynamically 89 | committer, _, _, _ := prepareCommitter(t, filepath.Join(tempDir, "wal"), "3pc", 5000) 90 | 91 | require.Equal(t, 1, committer.hookRegistry.Count(), "Expected 1 default hook initially") 92 | 93 | // add hooks dynamically 94 | testHook := &testHook{proposeResult: true, commitResult: true} 95 | committer.RegisterHook(testHook) 96 | 97 | require.Equal(t, 2, committer.hookRegistry.Count(), "Expected 2 hooks after registration") 98 | } 99 | 100 | func TestNewCommitter_BuiltinHooks(t *testing.T) { 101 | tempDir := t.TempDir() 102 | 103 | metricsHook := hooks.NewMetricsHook() 104 | validationHook := hooks.NewValidationHook(100, 1024) 105 | auditHook := hooks.NewAuditHook("test_audit.log") 106 | 107 | committer, _, _, _ := prepareCommitter(t, filepath.Join(tempDir, "wal"), "3pc", 5000, 108 | metricsHook, 109 | validationHook, 110 | auditHook, 111 | ) 112 | 113 | require.Equal(t, 3, committer.hookRegistry.Count(), "Expected 3 built-in hooks") 114 | 115 | proposeCount, commitCount, _ := metricsHook.GetStats() 116 | require.Equal(t, uint64(0), proposeCount, "Expected initial propose count to be 0") 117 | require.Equal(t, uint64(0), commitCount, "Expected initial commit count to be 0") 118 | } 119 | func TestCommit_StateValidation_2PC(t 
*testing.T) { 120 | tempDir := t.TempDir() 121 | walPath := filepath.Join(tempDir, "wal") 122 | wal := openTestWAL(t, walPath) 123 | 124 | stateStore, recovery := newStateStore(t, wal) 125 | 126 | // create 2PC committer 127 | committer := NewCommitter(stateStore, "two-phase", wal, 5000) 128 | committer.SetHeight(recovery.Height) 129 | 130 | require.Equal(t, "propose", committer.getCurrentState()) 131 | 132 | // first, propose a transaction 133 | proposeReq := &dto.ProposeRequest{ 134 | Height: 0, 135 | Key: "test-key", 136 | Value: []byte("test-value"), 137 | } 138 | 139 | _, err := committer.Propose(context.Background(), proposeReq) 140 | require.NoError(t, err) 141 | 142 | // should still be in propose state for 2PC 143 | require.Equal(t, "propose", committer.getCurrentState()) 144 | 145 | // commit should work from propose state 146 | commitReq := &dto.CommitRequest{Height: 0} 147 | resp, err := committer.Commit(context.Background(), commitReq) 148 | require.NoError(t, err) 149 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType) 150 | 151 | // should return to propose state after commit 152 | require.Equal(t, "propose", committer.getCurrentState()) 153 | } 154 | 155 | func TestCommit_StateValidation_3PC(t *testing.T) { 156 | tempDir := t.TempDir() 157 | walPath := filepath.Join(tempDir, "wal") 158 | 159 | wal := openTestWAL(t, walPath) 160 | stateStore, recovery := newStateStore(t, wal) 161 | 162 | // create 3PC committer 163 | committer := NewCommitter(stateStore, "three-phase", wal, 5000) 164 | committer.SetHeight(recovery.Height) 165 | 166 | require.Equal(t, "propose", committer.getCurrentState()) 167 | 168 | // first, propose a transaction 169 | proposeReq := &dto.ProposeRequest{ 170 | Height: 0, 171 | Key: "test-key", 172 | Value: []byte("test-value"), 173 | } 174 | 175 | _, err := committer.Propose(context.Background(), proposeReq) 176 | require.NoError(t, err) 177 | 178 | // should still be in propose state 179 | require.Equal(t, "propose", committer.getCurrentState()) 180 | 181 | // commit should fail from propose state in 3PC mode 182 | commitReq := &dto.CommitRequest{Height: 0} 183 | _, err = committer.Commit(context.Background(), commitReq) 184 | require.Error(t, err) 185 | require.Contains(t, err.Error(), "invalid state for commit: expected precommit for three-phase mode, but current state is propose") 186 | 187 | // state should remain in propose after failed commit 188 | require.Equal(t, "propose", committer.getCurrentState()) 189 | 190 | // now go through proper 3PC flow: propose -> precommit -> commit 191 | _, err = committer.Precommit(context.Background(), 0) 192 | require.NoError(t, err) 193 | require.Equal(t, "precommit", committer.getCurrentState()) 194 | 195 | // now commit should work from precommit state 196 | resp, err := committer.Commit(context.Background(), commitReq) 197 | require.NoError(t, err) 198 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType) 199 | 200 | // should return to propose state after successful commit 201 | require.Equal(t, "propose", committer.getCurrentState()) 202 | } 203 | 204 | func TestCommit_StateRestoration_OnErrors(t *testing.T) { 205 | t.Run("Hook failure", func(t *testing.T) { 206 | tempDir := t.TempDir() 207 | walPath := filepath.Join(tempDir, "wal") 208 | wal := openTestWAL(t, walPath) 209 | 210 | stateStore, recovery := newStateStore(t, wal) 211 | 212 | // create 3PC committer with a hook that will fail commit 213 | failingHook := &testHook{proposeResult: true, commitResult: false} 214 | committer := 
NewCommitter(stateStore, "three-phase", wal, 5000, failingHook) 215 | committer.SetHeight(recovery.Height) 216 | 217 | // go through proper 3PC flow: propose -> precommit 218 | proposeReq := &dto.ProposeRequest{ 219 | Height: 0, 220 | Key: "test-key", 221 | Value: []byte("test-value"), 222 | } 223 | 224 | _, err := committer.Propose(context.Background(), proposeReq) 225 | require.NoError(t, err) 226 | 227 | _, err = committer.Precommit(context.Background(), 0) 228 | require.NoError(t, err) 229 | require.Equal(t, "precommit", committer.getCurrentState()) 230 | 231 | // commit should fail due to hook, but state should be restored 232 | commitReq := &dto.CommitRequest{Height: 0} 233 | resp, err := committer.Commit(context.Background(), commitReq) 234 | require.NoError(t, err) 235 | require.Equal(t, dto.ResponseTypeNack, resp.ResponseType) 236 | 237 | // state should be restored to propose after failed commit 238 | require.Equal(t, "propose", committer.getCurrentState()) 239 | // height should not be incremented on hook failure 240 | require.Equal(t, uint64(0), committer.Height()) 241 | }) 242 | 243 | t.Run("Skip record (tombstone)", func(t *testing.T) { 244 | tempDir := t.TempDir() 245 | walPath := filepath.Join(tempDir, "wal") 246 | wal := openTestWAL(t, walPath) 247 | 248 | stateStore, recovery := newStateStore(t, wal) 249 | 250 | // create 3PC committer 251 | committer := NewCommitter(stateStore, "three-phase", wal, 5000) 252 | committer.SetHeight(recovery.Height) 253 | 254 | // first propose a normal transaction 255 | proposeReq := &dto.ProposeRequest{ 256 | Height: 0, 257 | Key: "test-key", 258 | Value: []byte("test-value"), 259 | } 260 | 261 | _, err := committer.Propose(context.Background(), proposeReq) 262 | require.NoError(t, err) 263 | 264 | _, err = committer.Precommit(context.Background(), 0) 265 | require.NoError(t, err) 266 | require.Equal(t, "precommit", committer.getCurrentState()) 267 | 268 | // simulate abort by writing abort record to WAL (this would normally be done by Abort method) 269 | err = wal.Write(walrecord.AbortSlot(0), walrecord.KeyAbort, nil) 270 | require.NoError(t, err) 271 | 272 | // commit should detect skip record and restore state 273 | commitReq := &dto.CommitRequest{Height: 0} 274 | resp, err := committer.Commit(context.Background(), commitReq) 275 | require.NoError(t, err) 276 | require.Equal(t, dto.ResponseTypeNack, resp.ResponseType) 277 | 278 | // state should be restored to propose after detecting skip record 279 | require.Equal(t, "propose", committer.getCurrentState()) 280 | // height should not be incremented for skip records 281 | require.Equal(t, uint64(0), committer.Height()) 282 | 283 | // data should not be in database (commit was skipped) 284 | value, err := stateStore.Get("test-key") 285 | require.Error(t, err) // should not exist 286 | require.Nil(t, value) 287 | }) 288 | } 289 | 290 | func TestGetExpectedCommitState(t *testing.T) { 291 | tempDir := t.TempDir() 292 | walPath := filepath.Join(tempDir, "wal") 293 | wal := openTestWAL(t, walPath) 294 | 295 | stateStore, _ := newStateStore(t, wal) 296 | 297 | // test 2PC mode 298 | committer2PC := NewCommitter(stateStore, "two-phase", wal, 5000) 299 | require.Equal(t, "propose", committer2PC.getExpectedCommitState()) 300 | 301 | // test 3PC mode 302 | committer3PC := NewCommitter(stateStore, "three-phase", wal, 5000) 303 | require.Equal(t, "precommit", committer3PC.getExpectedCommitState()) 304 | } 305 | func TestPrecommitTimeout_StateValidation(t *testing.T) { 306 | tempDir := t.TempDir() 307 
| walPath := filepath.Join(tempDir, "wal") 308 | wal := openTestWAL(t, walPath) 309 | 310 | stateStore, recovery := newStateStore(t, wal) 311 | 312 | // create 3PC committer with short timeout for testing 313 | committer := NewCommitter(stateStore, "three-phase", wal, 50) // 50ms timeout 314 | committer.SetHeight(recovery.Height) 315 | 316 | // test case 1: should skip autocommit when in commit state 317 | // first go to precommit, then to commit 318 | committer.state.Transition(precommitStage) 319 | committer.state.Transition(commitStage) 320 | committer.handlePrecommitTimeout(0) 321 | require.Equal(t, "commit", committer.getCurrentState()) // state unchanged 322 | 323 | // test case 2: should skip autocommit when in propose state 324 | // transition back to propose (commit -> propose is valid) 325 | committer.state.Transition(proposeStage) 326 | committer.handlePrecommitTimeout(0) 327 | require.Equal(t, "propose", committer.getCurrentState()) // state unchanged 328 | 329 | // test case 3: should skip autocommit when height doesn't match 330 | committer.state.Transition(precommitStage) 331 | committer.handlePrecommitTimeout(999) // wrong height 332 | require.Equal(t, "precommit", committer.getCurrentState()) // state unchanged 333 | } 334 | 335 | func TestPrecommitTimeout_AutocommitSuccess(t *testing.T) { 336 | tempDir := t.TempDir() 337 | walPath := filepath.Join(tempDir, "wal") 338 | wal := openTestWAL(t, walPath) 339 | 340 | stateStore, recovery := newStateStore(t, wal) 341 | 342 | // create 3PC committer 343 | committer := NewCommitter(stateStore, "three-phase", wal, 50) 344 | committer.SetHeight(recovery.Height) 345 | 346 | ctx := context.Background() 347 | 348 | // first propose to get data in WAL 349 | proposeReq := &dto.ProposeRequest{ 350 | Height: 0, 351 | Key: "test-key", 352 | Value: []byte("test-value"), 353 | } 354 | _, err := committer.Propose(ctx, proposeReq) 355 | require.NoError(t, err) 356 | 357 | // move to precommit state 358 | _, err = committer.Precommit(ctx, 0) 359 | require.NoError(t, err) 360 | require.Equal(t, "precommit", committer.getCurrentState()) 361 | 362 | // test successful autocommit 363 | committer.handlePrecommitTimeout(0) 364 | 365 | // should be back in propose state after successful autocommit 366 | require.Equal(t, "propose", committer.getCurrentState()) 367 | require.Equal(t, uint64(1), committer.Height()) // height should be incremented 368 | } 369 | 370 | func TestPrecommitTimeout_AutocommitWithSkipRecord(t *testing.T) { 371 | tempDir := t.TempDir() 372 | walPath := filepath.Join(tempDir, "wal") 373 | wal := openTestWAL(t, walPath) 374 | 375 | stateStore, recovery := newStateStore(t, wal) 376 | 377 | // create 3PC committer 378 | committer := NewCommitter(stateStore, "three-phase", wal, 50) 379 | committer.SetHeight(recovery.Height) 380 | 381 | // write abort record directly to WAL 382 | err := wal.Write(walrecord.AbortSlot(0), walrecord.KeyAbort, nil) 383 | require.NoError(t, err) 384 | 385 | // move to precommit state 386 | committer.state.Transition(precommitStage) 387 | require.Equal(t, "precommit", committer.getCurrentState()) 388 | 389 | // test autocommit with skip record - should recover to propose 390 | committer.handlePrecommitTimeout(0) 391 | 392 | // should be back in propose state after recovery 393 | require.Equal(t, "propose", committer.getCurrentState()) 394 | require.Equal(t, uint64(0), committer.Height()) // height should not be incremented for skip records 395 | } 396 | 397 | func TestPrecommitTimeout_AutocommitFailure(t *testing.T) 
{ 398 | tempDir := t.TempDir() 399 | walPath := filepath.Join(tempDir, "wal") 400 | wal := openTestWAL(t, walPath) 401 | 402 | stateStore, recovery := newStateStore(t, wal) 403 | 404 | // create 3PC committer with failing hook 405 | failingHook := &testHook{proposeResult: true, commitResult: false} 406 | committer := NewCommitter(stateStore, "three-phase", wal, 50, failingHook) 407 | committer.SetHeight(recovery.Height) 408 | 409 | ctx := context.Background() 410 | 411 | // first propose to get data in WAL 412 | proposeReq := &dto.ProposeRequest{ 413 | Height: 0, 414 | Key: "test-key", 415 | Value: []byte("test-value"), 416 | } 417 | _, err := committer.Propose(ctx, proposeReq) 418 | require.NoError(t, err) 419 | 420 | // move to precommit state 421 | _, err = committer.Precommit(ctx, 0) 422 | require.NoError(t, err) 423 | require.Equal(t, "precommit", committer.getCurrentState()) 424 | 425 | // test autocommit failure - should recover to propose 426 | committer.handlePrecommitTimeout(0) 427 | 428 | // should be back in propose state after recovery from failed autocommit 429 | require.Equal(t, "propose", committer.getCurrentState()) 430 | require.Equal(t, uint64(0), committer.Height()) // height should not be incremented on failure 431 | } 432 | 433 | func TestPrecommitTimeout_NoDataInWAL(t *testing.T) { 434 | tempDir := t.TempDir() 435 | walPath := filepath.Join(tempDir, "wal") 436 | wal := openTestWAL(t, walPath) 437 | 438 | stateStore, recovery := newStateStore(t, wal) 439 | 440 | // create 3PC committer 441 | committer := NewCommitter(stateStore, "three-phase", wal, 50) 442 | committer.SetHeight(recovery.Height) 443 | 444 | // set up state without data in WAL 445 | 446 | // move to precommit state without proposing first 447 | committer.state.Transition(precommitStage) 448 | require.Equal(t, "precommit", committer.getCurrentState()) 449 | 450 | // test autocommit with no data in WAL - should recover to propose 451 | committer.handlePrecommitTimeout(0) 452 | 453 | // should be back in propose state after recovery 454 | require.Equal(t, "propose", committer.getCurrentState()) 455 | require.Equal(t, uint64(0), committer.Height()) // height should not be incremented 456 | } 457 | 458 | func TestRecoverToPropose(t *testing.T) { 459 | tempDir := t.TempDir() 460 | walPath := filepath.Join(tempDir, "wal") 461 | wal := openTestWAL(t, walPath) 462 | 463 | stateStore, recovery := newStateStore(t, wal) 464 | 465 | // create 3PC committer 466 | committer := NewCommitter(stateStore, "three-phase", wal, 50) 467 | committer.SetHeight(recovery.Height) 468 | 469 | // test recovery from precommit state 470 | committer.state.Transition(precommitStage) 471 | require.Equal(t, "precommit", committer.getCurrentState()) 472 | 473 | committer.resetToPropose(0, "test") 474 | require.Equal(t, "propose", committer.getCurrentState()) 475 | 476 | // test recovery from commit state (should transition to propose) 477 | // first to precommit, then to commit 478 | committer.state.Transition(precommitStage) 479 | committer.state.Transition(commitStage) 480 | require.Equal(t, "commit", committer.getCurrentState()) 481 | 482 | committer.resetToPropose(0, "test") 483 | // should transition to propose since commit -> propose is valid in FSM 484 | require.Equal(t, "propose", committer.getCurrentState()) 485 | } 486 | 487 | func TestAbort_CurrentHeight(t *testing.T) { 488 | tempDir := t.TempDir() 489 | walPath := filepath.Join(tempDir, "wal") 490 | wal := openTestWAL(t, walPath) 491 | 492 | stateStore, recovery := 
newStateStore(t, wal) 493 | 494 | // create 3PC committer 495 | committer := NewCommitter(stateStore, "three-phase", wal, 5000) 496 | committer.SetHeight(recovery.Height) 497 | 498 | // set up a transaction at current height 499 | ctx := context.Background() 500 | proposeReq := &dto.ProposeRequest{ 501 | Height: 0, 502 | Key: "test-key", 503 | Value: []byte("test-value"), 504 | } 505 | 506 | _, err := committer.Propose(ctx, proposeReq) 507 | require.NoError(t, err) 508 | 509 | // move to precommit state 510 | _, err = committer.Precommit(ctx, 0) 511 | require.NoError(t, err) 512 | require.Equal(t, "precommit", committer.getCurrentState()) 513 | 514 | // test abort for current height 515 | abortReq := &dto.AbortRequest{ 516 | Height: 0, 517 | Reason: "Test abort", 518 | } 519 | 520 | resp, err := committer.Abort(ctx, abortReq) 521 | require.NoError(t, err) 522 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType) 523 | 524 | // should be back in propose state after abort 525 | require.Equal(t, "propose", committer.getCurrentState()) 526 | 527 | // should have the original data in WAL (Prepared) 528 | k, val, err := wal.Get(walrecord.PreparedSlot(0)) 529 | require.NoError(t, err) 530 | require.Equal(t, walrecord.KeyPrepared, k) 531 | 532 | // verify payload 533 | walTx, err := walrecord.Decode(val) 534 | require.NoError(t, err) 535 | require.Equal(t, "test-key", walTx.Key) 536 | 537 | // verify Abort 538 | k, _, err = wal.Get(walrecord.AbortSlot(0)) 539 | require.NoError(t, err) 540 | require.Equal(t, walrecord.KeyAbort, k) 541 | 542 | value, err := stateStore.Get("test-key") 543 | require.Error(t, err, "Original value should not exist in database after abort") 544 | require.Nil(t, value) 545 | } 546 | 547 | func TestAbort_FutureHeight(t *testing.T) { 548 | tempDir := t.TempDir() 549 | walPath := filepath.Join(tempDir, "wal") 550 | wal := openTestWAL(t, walPath) 551 | 552 | stateStore, recovery := newStateStore(t, wal) 553 | 554 | // create committer 555 | committer := NewCommitter(stateStore, "three-phase", wal, 5000) 556 | committer.SetHeight(recovery.Height) 557 | 558 | // test abort for future height (should be ignored) 559 | ctx := context.Background() 560 | abortReq := &dto.AbortRequest{ 561 | Height: 10, // future height 562 | Reason: "Test abort future", 563 | } 564 | 565 | resp, err := committer.Abort(ctx, abortReq) 566 | require.NoError(t, err) 567 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType) 568 | 569 | // state should remain unchanged 570 | require.Equal(t, "propose", committer.getCurrentState()) 571 | require.Equal(t, uint64(0), committer.Height()) 572 | 573 | key, val, err := wal.Get(0) 574 | require.NoError(t, err) 575 | require.Equal(t, "", key) 576 | require.Nil(t, val, "WAL should not have entry for current height without propose") 577 | } 578 | 579 | func TestAbort_PastHeight(t *testing.T) { 580 | tempDir := t.TempDir() 581 | walPath := filepath.Join(tempDir, "wal") 582 | wal := openTestWAL(t, walPath) 583 | 584 | stateStore, recovery := newStateStore(t, wal) 585 | 586 | // create committer and advance height 587 | committer := NewCommitter(stateStore, "two-phase", wal, 5000) 588 | committer.SetHeight(recovery.Height) 589 | 590 | // complete a transaction to advance height 591 | ctx := context.Background() 592 | proposeReq := &dto.ProposeRequest{ 593 | Height: 0, 594 | Key: "test-key", 595 | Value: []byte("test-value"), 596 | } 597 | 598 | _, err := committer.Propose(ctx, proposeReq) 599 | require.NoError(t, err) 600 | 601 | commitReq := 
&dto.CommitRequest{Height: 0} 602 | _, err = committer.Commit(ctx, commitReq) 603 | require.NoError(t, err) 604 | 605 | // height should now be 1 606 | require.Equal(t, uint64(1), committer.Height()) 607 | 608 | // test abort for past height (should be ignored) 609 | abortReq := &dto.AbortRequest{ 610 | Height: 0, // Past height 611 | Reason: "Test abort past", 612 | } 613 | 614 | resp, err := committer.Abort(ctx, abortReq) 615 | require.NoError(t, err) 616 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType) 617 | 618 | // state should remain unchanged 619 | require.Equal(t, "propose", committer.getCurrentState()) 620 | require.Equal(t, uint64(1), committer.Height()) 621 | 622 | // check wal unchanged 623 | // check Prepared 624 | k, _, err := wal.Get(walrecord.PreparedSlot(0)) 625 | require.NoError(t, err) 626 | require.Equal(t, walrecord.KeyPrepared, k) 627 | 628 | // check Commit 629 | k, _, err = wal.Get(walrecord.CommitSlot(0)) 630 | require.NoError(t, err) 631 | require.Equal(t, walrecord.KeyCommit, k) 632 | 633 | // check normal data is in db 634 | value, err := stateStore.Get("test-key") 635 | require.NoError(t, err, "Data should exist in database for past committed transaction") 636 | require.Equal(t, "test-value", string(value)) 637 | } 638 | 639 | func TestAbort_StateRecovery_3PC(t *testing.T) { 640 | tempDir := t.TempDir() 641 | walPath := filepath.Join(tempDir, "wal") 642 | wal := openTestWAL(t, walPath) 643 | 644 | stateStore, recovery := newStateStore(t, wal) 645 | 646 | // create 3PC committer 647 | committer := NewCommitter(stateStore, "three-phase", wal, 5000) 648 | committer.SetHeight(recovery.Height) 649 | 650 | // set up transaction and move to precommit state 651 | ctx := context.Background() 652 | proposeReq := &dto.ProposeRequest{ 653 | Height: 0, 654 | Key: "test-key", 655 | Value: []byte("test-value"), 656 | } 657 | 658 | _, err := committer.Propose(ctx, proposeReq) 659 | require.NoError(t, err) 660 | 661 | _, err = committer.Precommit(ctx, 0) 662 | require.NoError(t, err) 663 | require.Equal(t, "precommit", committer.getCurrentState()) 664 | 665 | // test abort from precommit state 666 | abortReq := &dto.AbortRequest{ 667 | Height: 0, 668 | Reason: "Test 3PC abort", 669 | } 670 | 671 | resp, err := committer.Abort(ctx, abortReq) 672 | require.NoError(t, err) 673 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType) 674 | 675 | // should be back in propose state 676 | require.Equal(t, "propose", committer.getCurrentState()) 677 | 678 | // check wal 679 | k, _, err := wal.Get(walrecord.AbortSlot(0)) 680 | require.NoError(t, err) 681 | require.Equal(t, walrecord.KeyAbort, k) 682 | } 683 | 684 | func TestAbort_StateRecovery_2PC(t *testing.T) { 685 | tempDir := t.TempDir() 686 | walPath := filepath.Join(tempDir, "wal") 687 | wal := openTestWAL(t, walPath) 688 | 689 | stateStore, recovery := newStateStore(t, wal) 690 | 691 | // create 2PC committer 692 | committer := NewCommitter(stateStore, "two-phase", wal, 5000) 693 | committer.SetHeight(recovery.Height) 694 | 695 | // set up transaction (in 2PC, we stay in propose state) 696 | ctx := context.Background() 697 | proposeReq := &dto.ProposeRequest{ 698 | Height: 0, 699 | Key: "test-key", 700 | Value: []byte("test-value"), 701 | } 702 | 703 | _, err := committer.Propose(ctx, proposeReq) 704 | require.NoError(t, err) 705 | require.Equal(t, "propose", committer.getCurrentState()) 706 | 707 | // test abort from propose state in 2PC 708 | abortReq := &dto.AbortRequest{ 709 | Height: 0, 710 | Reason: "Test 2PC 
abort", 711 | } 712 | 713 | resp, err := committer.Abort(ctx, abortReq) 714 | require.NoError(t, err) 715 | require.Equal(t, dto.ResponseTypeAck, resp.ResponseType) 716 | 717 | // should remain in propose state 718 | require.Equal(t, "propose", committer.getCurrentState()) 719 | 720 | // check wal 721 | k, _, err := wal.Get(walrecord.AbortSlot(0)) 722 | require.NoError(t, err) 723 | require.Equal(t, walrecord.KeyAbort, k) 724 | } 725 | -------------------------------------------------------------------------------- /chaos_test.go: -------------------------------------------------------------------------------- 1 | //go:build chaos 2 | 3 | // To run these tests, you need to install toxiproxy: 4 | // 5 | // # macOS (for Linux, download the linux-amd64 release asset instead) 6 | // curl -L -o toxiproxy-server https://github.com/Shopify/toxiproxy/releases/download/v2.12.0/toxiproxy-server-darwin-amd64 7 | // chmod +x toxiproxy-server 8 | // mv toxiproxy-server ~/go/bin/ 9 | // 10 | // And then run `make test-chaos` 11 | package main 12 | 13 | import ( 14 | "context" 15 | "fmt" 16 | "os" 17 | "path/filepath" 18 | "strconv" 19 | "testing" 20 | "time" 21 | 22 | log "github.com/sirupsen/logrus" 23 | "github.com/stretchr/testify/require" 24 | "github.com/vadiminshakov/committer/core/cohort" 25 | "github.com/vadiminshakov/committer/core/cohort/commitalgo" 26 | "github.com/vadiminshakov/committer/core/coordinator" 27 | "github.com/vadiminshakov/committer/io/gateway/grpc/client" 28 | pb "github.com/vadiminshakov/committer/io/gateway/grpc/proto" 29 | "github.com/vadiminshakov/committer/io/gateway/grpc/server" 30 | "github.com/vadiminshakov/committer/io/store" 31 | "github.com/vadiminshakov/gowal" 32 | ) 33 | 34 | const TOXIPROXY_URL = "http://localhost:8474" 35 | 36 | func TestChaosFollowerFailure(t *testing.T) { 37 | log.SetLevel(log.InfoLevel) 38 | 39 | // immediate connection reset 40 | t.Run("immediate_reset", func(t *testing.T) { 41 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL) 42 | defer chaosHelper.cleanup() 43 | 44 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE])) 45 | for _, node := range nodes[COHORT_TYPE] { 46 | allAddresses = append(allAddresses, node.Nodeaddr) 47 | } 48 | for _, node := range nodes[COORDINATOR_TYPE] { 49 | allAddresses = append(allAddresses, node.Nodeaddr) 50 | } 51 | 52 | require.NoError(t, chaosHelper.setupProxies(allAddresses)) 53 | require.NoError(t, chaosHelper.addResetPeer(nodes[COHORT_TYPE][0].Nodeaddr, 0)) 54 | 55 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT) 56 | defer canceller() 57 | 58 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr 59 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" { 60 | coordAddr = proxyAddr 61 | } 62 | 63 | c, err := client.NewClientAPI(coordAddr) 64 | require.NoError(t, err) 65 | 66 | _, err = c.Put(context.Background(), "reset_test", []byte("value")) 67 | require.Error(t, err) 68 | require.Contains(t, err.Error(), "failed to send propose") 69 | }) 70 | 71 | // connection drops after 10 bytes of data 72 | t.Run("cohort failure after 10 bytes", func(t *testing.T) { 73 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL) 74 | defer chaosHelper.cleanup() 75 | 76 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE])) 77 | for _, node := range nodes[COHORT_TYPE] { 78 | allAddresses = append(allAddresses, node.Nodeaddr) 79 | } 80 | for _, node := range nodes[COORDINATOR_TYPE] { 81 | allAddresses = append(allAddresses, node.Nodeaddr) 82 | } 83 | 84 | 
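// route all node traffic through toxiproxy, then cut cohort 0's link after 10 bytes
// so the connection dies while the propose message is still in flight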
require.NoError(t, chaosHelper.setupProxies(allAddresses)) 85 | require.NoError(t, chaosHelper.addDataLimit(nodes[COHORT_TYPE][0].Nodeaddr, 10)) // 10 bytes 86 | 87 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT) 88 | defer canceller() 89 | 90 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr 91 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" { 92 | coordAddr = proxyAddr 93 | } 94 | 95 | c, err := client.NewClientAPI(coordAddr) 96 | require.NoError(t, err) 97 | 98 | _, err = c.Put(context.Background(), "early_fail_test", []byte("value")) 99 | require.Error(t, err) 100 | require.Contains(t, err.Error(), "failed to send propose") 101 | }) 102 | 103 | // connection drops after 50 bytes of data 104 | t.Run("cohort failure after 50 bytes", func(t *testing.T) { 105 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL) 106 | defer chaosHelper.cleanup() 107 | 108 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE])) 109 | for _, node := range nodes[COHORT_TYPE] { 110 | allAddresses = append(allAddresses, node.Nodeaddr) 111 | } 112 | for _, node := range nodes[COORDINATOR_TYPE] { 113 | allAddresses = append(allAddresses, node.Nodeaddr) 114 | } 115 | 116 | require.NoError(t, chaosHelper.setupProxies(allAddresses)) 117 | require.NoError(t, chaosHelper.addDataLimit(nodes[COHORT_TYPE][0].Nodeaddr, 50)) // 50 bytes 118 | 119 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT) 120 | defer canceller() 121 | 122 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr 123 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" { 124 | coordAddr = proxyAddr 125 | } 126 | 127 | c, err := client.NewClientAPI(coordAddr) 128 | require.NoError(t, err) 129 | 130 | _, err = c.Put(context.Background(), "early_fail_test", []byte("value")) 131 | require.Error(t, err) 132 | require.Contains(t, err.Error(), "failed to send propose") 133 | }) 134 | 135 | // connection drops after 100 bytes of data 136 | t.Run("cohort failure after 100 bytes", func(t *testing.T) { 137 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL) 138 | defer chaosHelper.cleanup() 139 | 140 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE])) 141 | for _, node := range nodes[COHORT_TYPE] { 142 | allAddresses = append(allAddresses, node.Nodeaddr) 143 | } 144 | for _, node := range nodes[COORDINATOR_TYPE] { 145 | allAddresses = append(allAddresses, node.Nodeaddr) 146 | } 147 | 148 | require.NoError(t, chaosHelper.setupProxies(allAddresses)) 149 | require.NoError(t, chaosHelper.addDataLimit(nodes[COHORT_TYPE][0].Nodeaddr, 100)) // 100 bytes 150 | 151 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT) 152 | defer canceller() 153 | 154 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr 155 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" { 156 | coordAddr = proxyAddr 157 | } 158 | 159 | c, err := client.NewClientAPI(coordAddr) 160 | require.NoError(t, err) 161 | 162 | _, err = c.Put(context.Background(), "early_fail_test", []byte("value")) 163 | require.Error(t, err) 164 | require.Contains(t, err.Error(), "failed to send propose") 165 | }) 166 | 167 | // connection drops after 150 bytes of data 168 | t.Run("cohort failure after 150 bytes", func(t *testing.T) { 169 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL) 170 | defer chaosHelper.cleanup() 171 | 172 | allAddresses := make([]string, 0, 
len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE])) 173 | for _, node := range nodes[COHORT_TYPE] { 174 | allAddresses = append(allAddresses, node.Nodeaddr) 175 | } 176 | for _, node := range nodes[COORDINATOR_TYPE] { 177 | allAddresses = append(allAddresses, node.Nodeaddr) 178 | } 179 | 180 | require.NoError(t, chaosHelper.setupProxies(allAddresses)) 181 | require.NoError(t, chaosHelper.addDataLimit(nodes[COHORT_TYPE][0].Nodeaddr, 150)) // 150 bytes 182 | 183 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT) 184 | defer canceller() 185 | 186 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr 187 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" { 188 | coordAddr = proxyAddr 189 | } 190 | 191 | c, err := client.NewClientAPI(coordAddr) 192 | require.NoError(t, err) 193 | 194 | _, err = c.Put(context.Background(), "early_fail_test", []byte("value")) 195 | require.Error(t, err) 196 | require.Contains(t, err.Error(), "failed to send precommit") 197 | }) 198 | 199 | // connection drops after 200 bytes of data 200 | t.Run("cohort failure after 200 bytes", func(t *testing.T) { 201 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL) 202 | defer chaosHelper.cleanup() 203 | 204 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE])) 205 | for _, node := range nodes[COHORT_TYPE] { 206 | allAddresses = append(allAddresses, node.Nodeaddr) 207 | } 208 | for _, node := range nodes[COORDINATOR_TYPE] { 209 | allAddresses = append(allAddresses, node.Nodeaddr) 210 | } 211 | 212 | require.NoError(t, chaosHelper.setupProxies(allAddresses)) 213 | require.NoError(t, chaosHelper.addDataLimit(nodes[COHORT_TYPE][0].Nodeaddr, 200)) // 200 bytes 214 | 215 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT) 216 | defer canceller() 217 | 218 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr 219 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" { 220 | coordAddr = proxyAddr 221 | } 222 | 223 | c, err := client.NewClientAPI(coordAddr) 224 | require.NoError(t, err) 225 | 226 | _, err = c.Put(context.Background(), "early_fail_test", []byte("value")) 227 | require.Error(t, err) 228 | require.Contains(t, err.Error(), "failed to send precommit") 229 | }) 230 | 231 | // connection drops after 250 bytes of data 232 | t.Run("cohort failure after 250 bytes", func(t *testing.T) { 233 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL) 234 | defer chaosHelper.cleanup() 235 | 236 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE])) 237 | for _, node := range nodes[COHORT_TYPE] { 238 | allAddresses = append(allAddresses, node.Nodeaddr) 239 | } 240 | for _, node := range nodes[COORDINATOR_TYPE] { 241 | allAddresses = append(allAddresses, node.Nodeaddr) 242 | } 243 | 244 | require.NoError(t, chaosHelper.setupProxies(allAddresses)) 245 | require.NoError(t, chaosHelper.addDataLimit(nodes[COHORT_TYPE][0].Nodeaddr, 250)) // 250 bytes 246 | 247 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT) 248 | defer canceller() 249 | 250 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr 251 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" { 252 | coordAddr = proxyAddr 253 | } 254 | 255 | c, err := client.NewClientAPI(coordAddr) 256 | require.NoError(t, err) 257 | 258 | _, err = c.Put(context.Background(), "commit_fail_test", []byte("test_value_250")) 259 | if err != nil { 260 | require.Contains(t, err.Error(), "failed to send 
commit") 261 | // if operation failed, check that value was committed on healthy nodes 262 | if checkValueOnCoordinator(t, "commit_fail_test", []byte("test_value_250")) { 263 | checkValueOnCohorts(t, "commit_fail_test", []byte("test_value_250"), 0) // skip failed cohort (index 0) 264 | checkValueNotOnNode(t, nodes[COHORT_TYPE][0].Nodeaddr, "commit_fail_test") 265 | } 266 | } else { 267 | // if operation succeeded despite limits, all nodes should have the value 268 | t.Log("operation succeeded despite network limits (250 bytes was sufficient)") 269 | checkValueOnCoordinator(t, "commit_fail_test", []byte("test_value_250")) 270 | checkValueOnAllCohorts(t, "commit_fail_test", []byte("test_value_250")) 271 | } 272 | }) 273 | 274 | // connection drops after 500 bytes of data 275 | t.Run("cohort failure after 500 bytes", func(t *testing.T) { 276 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL) 277 | defer chaosHelper.cleanup() 278 | 279 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE])) 280 | for _, node := range nodes[COHORT_TYPE] { 281 | allAddresses = append(allAddresses, node.Nodeaddr) 282 | } 283 | for _, node := range nodes[COORDINATOR_TYPE] { 284 | allAddresses = append(allAddresses, node.Nodeaddr) 285 | } 286 | 287 | require.NoError(t, chaosHelper.setupProxies(allAddresses)) 288 | require.NoError(t, chaosHelper.addDataLimit(nodes[COHORT_TYPE][0].Nodeaddr, 500)) // 500 bytes 289 | 290 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT) 291 | defer canceller() 292 | 293 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr 294 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" { 295 | coordAddr = proxyAddr 296 | } 297 | 298 | c, err := client.NewClientAPI(coordAddr) 299 | require.NoError(t, err) 300 | 301 | _, err = c.Put(context.Background(), "success_test", []byte("test_value_500")) 302 | require.NoError(t, err) 303 | 304 | // check if value was committed on all nodes 305 | checkValueOnCoordinator(t, "success_test", []byte("test_value_500")) 306 | checkValueOnAllCohorts(t, "success_test", []byte("test_value_500")) 307 | }) 308 | } 309 | 310 | func TestChaosCoordinatorFailure(t *testing.T) { 311 | log.SetLevel(log.InfoLevel) 312 | 313 | // immediate connection reset 314 | t.Run("coordinator_immediate_reset", func(t *testing.T) { 315 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL) 316 | defer chaosHelper.cleanup() 317 | 318 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE])) 319 | for _, node := range nodes[COHORT_TYPE] { 320 | allAddresses = append(allAddresses, node.Nodeaddr) 321 | } 322 | for _, node := range nodes[COORDINATOR_TYPE] { 323 | allAddresses = append(allAddresses, node.Nodeaddr) 324 | } 325 | 326 | require.NoError(t, chaosHelper.setupProxies(allAddresses)) 327 | require.NoError(t, chaosHelper.addResetPeer(nodes[COORDINATOR_TYPE][1].Nodeaddr, 0)) 328 | 329 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT) 330 | defer canceller() 331 | 332 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr 333 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" { 334 | coordAddr = proxyAddr 335 | } 336 | 337 | c, err := client.NewClientAPI(coordAddr) 338 | require.NoError(t, err) 339 | 340 | _, err = c.Put(context.Background(), "coord_reset_test", []byte("value")) 341 | require.Error(t, err) 342 | require.Contains(t, err.Error(), "connection closed before server preface received") 343 | }) 344 | 345 | // coordinator 
fails after 50 bytes 346 | t.Run("coordinator_failure_after_50_bytes", func(t *testing.T) { 347 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL) 348 | defer chaosHelper.cleanup() 349 | 350 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE])) 351 | for _, node := range nodes[COHORT_TYPE] { 352 | allAddresses = append(allAddresses, node.Nodeaddr) 353 | } 354 | for _, node := range nodes[COORDINATOR_TYPE] { 355 | allAddresses = append(allAddresses, node.Nodeaddr) 356 | } 357 | 358 | require.NoError(t, chaosHelper.setupProxies(allAddresses)) 359 | require.NoError(t, chaosHelper.addDataLimit(nodes[COORDINATOR_TYPE][1].Nodeaddr, 50)) 360 | 361 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT) 362 | defer canceller() 363 | 364 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr 365 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" { 366 | coordAddr = proxyAddr 367 | } 368 | 369 | c, err := client.NewClientAPI(coordAddr) 370 | require.NoError(t, err) 371 | 372 | _, err = c.Put(context.Background(), "coord_50_test", []byte("value")) 373 | require.Error(t, err) 374 | }) 375 | 376 | // coordinator fails after 100 bytes 377 | t.Run("coordinator_failure_after_100_bytes", func(t *testing.T) { 378 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL) 379 | defer chaosHelper.cleanup() 380 | 381 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE])) 382 | for _, node := range nodes[COHORT_TYPE] { 383 | allAddresses = append(allAddresses, node.Nodeaddr) 384 | } 385 | for _, node := range nodes[COORDINATOR_TYPE] { 386 | allAddresses = append(allAddresses, node.Nodeaddr) 387 | } 388 | 389 | require.NoError(t, chaosHelper.setupProxies(allAddresses)) 390 | require.NoError(t, chaosHelper.addDataLimit(nodes[COORDINATOR_TYPE][1].Nodeaddr, 100)) 391 | 392 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT) 393 | defer canceller() 394 | 395 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr 396 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" { 397 | coordAddr = proxyAddr 398 | } 399 | 400 | c, err := client.NewClientAPI(coordAddr) 401 | require.NoError(t, err) 402 | 403 | _, err = c.Put(context.Background(), "coord_100_test", []byte("value")) 404 | require.Error(t, err) 405 | }) 406 | 407 | // coordinator fails after 200 bytes 408 | t.Run("coordinator_failure_after_200_bytes", func(t *testing.T) { 409 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL) 410 | defer chaosHelper.cleanup() 411 | 412 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE])) 413 | for _, node := range nodes[COHORT_TYPE] { 414 | allAddresses = append(allAddresses, node.Nodeaddr) 415 | } 416 | for _, node := range nodes[COORDINATOR_TYPE] { 417 | allAddresses = append(allAddresses, node.Nodeaddr) 418 | } 419 | 420 | require.NoError(t, chaosHelper.setupProxies(allAddresses)) 421 | require.NoError(t, chaosHelper.addDataLimit(nodes[COORDINATOR_TYPE][1].Nodeaddr, 200)) 422 | 423 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT) 424 | defer canceller() 425 | 426 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr 427 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" { 428 | coordAddr = proxyAddr 429 | } 430 | 431 | c, err := client.NewClientAPI(coordAddr) 432 | require.NoError(t, err) 433 | 434 | _, err = c.Put(context.Background(), "coord_200_test", []byte("value")) 435 | // may succeed or fail 
depending on when exactly coordinator fails 436 | // the main point is to check cohort consistency afterwards 437 | if err != nil { 438 | t.Logf("coordinator operation failed as expected: %v", err) 439 | } else { 440 | t.Log("coordinator operation succeeded despite limits") 441 | } 442 | 443 | t.Log("checking cohort states after coordinator failure") 444 | checkFollowerStatesAfterCoordinatorFailure(t, "coord_200_test", []byte("value")) 445 | }) 446 | 447 | // coordinator fails during commit phase (after 300 bytes) 448 | t.Run("coordinator_failure_during_commit", func(t *testing.T) { 449 | chaosHelper := newChaosTestHelper(TOXIPROXY_URL) 450 | defer chaosHelper.cleanup() 451 | 452 | allAddresses := make([]string, 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE])) 453 | for _, node := range nodes[COHORT_TYPE] { 454 | allAddresses = append(allAddresses, node.Nodeaddr) 455 | } 456 | for _, node := range nodes[COORDINATOR_TYPE] { 457 | allAddresses = append(allAddresses, node.Nodeaddr) 458 | } 459 | 460 | require.NoError(t, chaosHelper.setupProxies(allAddresses)) 461 | require.NoError(t, chaosHelper.addDataLimit(nodes[COORDINATOR_TYPE][1].Nodeaddr, 300)) 462 | 463 | canceller := startnodesChaos(chaosHelper, pb.CommitType_THREE_PHASE_COMMIT) 464 | defer canceller() 465 | 466 | coordAddr := nodes[COORDINATOR_TYPE][1].Nodeaddr 467 | if proxyAddr := chaosHelper.getProxyAddress(coordAddr); proxyAddr != "" { 468 | coordAddr = proxyAddr 469 | } 470 | 471 | c, err := client.NewClientAPI(coordAddr) 472 | require.NoError(t, err) 473 | 474 | _, err = c.Put(context.Background(), "coord_commit_test", []byte("commit_value")) 475 | // may succeed or fail depending on timing 476 | if err != nil { 477 | t.Logf("coordinator operation failed: %v", err) 478 | } else { 479 | t.Log("coordinator operation completed successfully") 480 | } 481 | 482 | t.Log("checking cohort states after coordinator failure during commit") 483 | checkFollowerStatesAfterCoordinatorFailure(t, "coord_commit_test", []byte("commit_value")) 484 | }) 485 | } 486 | 487 | // startnodesChaos starts nodes with Toxiproxy support 488 | func startnodesChaos(helper *chaosTestHelper, commitType pb.CommitType) func() error { 489 | COORDINATOR_BADGER := fmt.Sprintf("%s%s%d", BADGER_DIR, "coordinator", time.Now().UnixNano()) 490 | COHORT_BADGER := fmt.Sprintf("%s%s%d", BADGER_DIR, "cohort", time.Now().UnixNano()) 491 | 492 | // cleanup dirs 493 | cleanupDirs := []string{COORDINATOR_BADGER, COHORT_BADGER, "./tmp"} 494 | for _, dir := range cleanupDirs { 495 | if _, err := os.Stat(dir); !os.IsNotExist(err) { 496 | failfast(os.RemoveAll(dir)) 497 | } 498 | } 499 | 500 | // create dirs 501 | createDirs := []string{COORDINATOR_BADGER, COHORT_BADGER, "./tmp", "./tmp/cohort", "./tmp/coord"} 502 | for _, dir := range createDirs { 503 | failfast(os.Mkdir(dir, os.FileMode(0777))) 504 | } 505 | 506 | stopfuncs := make([]func(), 0, len(nodes[COHORT_TYPE])+len(nodes[COORDINATOR_TYPE])) 507 | 508 | // start cohorts 509 | for i, node := range nodes[COHORT_TYPE] { 510 | if commitType == pb.CommitType_THREE_PHASE_COMMIT { 511 | // use proxy address of coordinator 512 | if proxyAddr := helper.getProxyAddress(nodes[COORDINATOR_TYPE][1].Nodeaddr); proxyAddr != "" { 513 | node.Coordinator = proxyAddr 514 | } else { 515 | node.Coordinator = nodes[COORDINATOR_TYPE][1].Nodeaddr 516 | } 517 | } 518 | node.DBPath = filepath.Join(COHORT_BADGER, strconv.Itoa(i)) 519 | failfast(os.MkdirAll(node.DBPath, os.FileMode(0o777))) 520 | 521 | walConfig := gowal.Config{ 522 | Dir: 
"./tmp/cohort/" + strconv.Itoa(i), 523 | Prefix: "msgs_", 524 | SegmentThreshold: 100, 525 | MaxSegments: 100, 526 | IsInSyncDiskMode: false, 527 | } 528 | c, err := gowal.NewWAL(walConfig) 529 | failfast(err) 530 | 531 | stateStore, recovery, err := store.New(c, node.DBPath) 532 | failfast(err) 533 | 534 | ct := server.TWO_PHASE 535 | if commitType == pb.CommitType_THREE_PHASE_COMMIT { 536 | ct = server.THREE_PHASE 537 | } 538 | 539 | committer := commitalgo.NewCommitter(stateStore, ct, c, node.Timeout) 540 | committer.SetHeight(recovery.Height) 541 | cohortImpl := cohort.NewCohort(committer, cohort.Mode(node.CommitType)) 542 | 543 | cohortServer, err := server.New(node, cohortImpl, nil, stateStore) 544 | failfast(err) 545 | 546 | go cohortServer.Run(server.WhiteListChecker) 547 | stopfuncs = append(stopfuncs, cohortServer.Stop) 548 | } 549 | 550 | // start coordinators 551 | for i, coordConfig := range nodes[COORDINATOR_TYPE] { 552 | coordConfig.DBPath = filepath.Join(COORDINATOR_BADGER, strconv.Itoa(i)) 553 | failfast(os.MkdirAll(coordConfig.DBPath, os.FileMode(0o777))) 554 | // update cohorts addresses to use proxies 555 | updatedCohorts := make([]string, len(coordConfig.Cohorts)) 556 | for j, cohortAddr := range coordConfig.Cohorts { 557 | if proxyAddr := helper.getProxyAddress(cohortAddr); proxyAddr != "" { 558 | updatedCohorts[j] = proxyAddr 559 | } else { 560 | updatedCohorts[j] = cohortAddr 561 | } 562 | } 563 | coordConfig.Cohorts = updatedCohorts 564 | 565 | walConfig := gowal.Config{ 566 | Dir: "./tmp/coord/msgs" + strconv.Itoa(i), 567 | Prefix: "msgs", 568 | SegmentThreshold: 100, 569 | MaxSegments: 100, 570 | IsInSyncDiskMode: false, 571 | } 572 | 573 | c, err := gowal.NewWAL(walConfig) 574 | failfast(err) 575 | 576 | stateStore, recovery, err := store.New(c, coordConfig.DBPath) 577 | failfast(err) 578 | 579 | coord, err := coordinator.New(coordConfig, c, stateStore) 580 | failfast(err) 581 | coord.SetHeight(recovery.Height) 582 | 583 | coordServer, err := server.New(coordConfig, nil, coord, stateStore) 584 | failfast(err) 585 | 586 | go coordServer.Run(server.WhiteListChecker) 587 | time.Sleep(100 * time.Millisecond) 588 | stopfuncs = append(stopfuncs, coordServer.Stop) 589 | } 590 | 591 | return func() error { 592 | for _, f := range stopfuncs { 593 | f() 594 | } 595 | failfast(os.RemoveAll("./tmp")) 596 | return os.RemoveAll(BADGER_DIR) 597 | } 598 | } 599 | 600 | // checkValueOnCoordinator checks if a value exists on coordinator 601 | func checkValueOnCoordinator(t *testing.T, key string, expectedValue []byte) bool { 602 | t.Helper() 603 | 604 | coordClient, err := client.NewClientAPI(nodes[COORDINATOR_TYPE][1].Nodeaddr) 605 | require.NoError(t, err) 606 | 607 | coordValue, err := coordClient.Get(context.Background(), key) 608 | if err != nil { 609 | t.Logf("coordinator does not have value for key %s: %v", key, err) 610 | return false 611 | } 612 | 613 | require.Equal(t, expectedValue, coordValue.Value) 614 | t.Logf("coordinator has correct value for key %s", key) 615 | return true 616 | } 617 | 618 | // checkValueOnCohorts checks if a value exists on working cohorts (excluding failed one) 619 | func checkValueOnCohorts(t *testing.T, key string, expectedValue []byte, skipFailedIndex int) int { 620 | t.Helper() 621 | 622 | successCount := 0 623 | for i, cohortAddr := range nodes[COHORT_TYPE] { 624 | if i == skipFailedIndex { 625 | continue 626 | } 627 | 628 | cohortClient, err := client.NewClientAPI(cohortAddr.Nodeaddr) 629 | require.NoError(t, err) 630 | 631 | 
cohortValue, err := cohortClient.Get(context.Background(), key) 632 | require.NoError(t, err) 633 | require.Equal(t, expectedValue, cohortValue.Value) 634 | successCount++ 635 | } 636 | return successCount 637 | } 638 | 639 | // checkValueOnAllCohorts checks if a value exists on ALL cohorts 640 | func checkValueOnAllCohorts(t *testing.T, key string, expectedValue []byte) { 641 | t.Helper() 642 | 643 | for _, cohortAddr := range nodes[COHORT_TYPE] { 644 | cohortClient, err := client.NewClientAPI(cohortAddr.Nodeaddr) 645 | require.NoError(t, err) 646 | 647 | cohortValue, err := cohortClient.Get(context.Background(), key) 648 | require.NoError(t, err) 649 | require.Equal(t, expectedValue, cohortValue.Value) 650 | } 651 | } 652 | 653 | // checkValueNotOnNode checks that a value does NOT exist on specified node 654 | func checkValueNotOnNode(t *testing.T, nodeAddr string, key string) { 655 | t.Helper() 656 | 657 | nodeClient, err := client.NewClientAPI(nodeAddr) 658 | require.NoError(t, err) 659 | 660 | _, err = nodeClient.Get(context.Background(), key) 661 | if err != nil { 662 | t.Logf("node %s correctly does not have value (as expected for failed node)", nodeAddr) 663 | } else { 664 | t.Logf("WARNING: node %s unexpectedly has the value (should not happen for failed node)", nodeAddr) 665 | } 666 | } 667 | 668 | // checkFollowerStatesAfterCoordinatorFailure checks the state of cohort nodes after coordinator failure 669 | func checkFollowerStatesAfterCoordinatorFailure(t *testing.T, key string, expectedValue []byte) { 670 | t.Helper() 671 | 672 | committedCount := 0 673 | notCommittedCount := 0 674 | 675 | for _, cohortAddr := range nodes[COHORT_TYPE] { 676 | cohortClient, err := client.NewClientAPI(cohortAddr.Nodeaddr) 677 | require.NoError(t, err) 678 | 679 | cohortValue, err := cohortClient.Get(context.Background(), key) 680 | if err == nil { 681 | require.Equal(t, expectedValue, cohortValue.Value) 682 | committedCount++ 683 | } else { 684 | notCommittedCount++ 685 | } 686 | } 687 | 688 | totalCohorts := len(nodes[COHORT_TYPE]) 689 | t.Logf("cohort states: %d committed, %d not committed", committedCount, notCommittedCount) 690 | 691 | // validation: after coordinator failure, cohorts must be in consistent state 692 | // either all committed or none committed (no partial commits allowed) 693 | if committedCount > 0 && committedCount < totalCohorts { 694 | t.Errorf("inconsistent state detected: %d cohorts committed, %d did not commit. This violates consistency!", 695 | committedCount, notCommittedCount) 696 | } else { 697 | t.Logf("consistent state maintained: all cohorts are in the same state") 698 | } 699 | } 700 | --------------------------------------------------------------------------------
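// A minimal sketch (not a file in this repository) of a custom hook, assuming only
// what the tests above exercise: the hooks.Hook interface satisfied by testHook
// (OnPropose/OnCommit returning bool) and the committer.RegisterHook method.
// sizeLimitHook and maxBytes are hypothetical names used for illustration.
package main

import (
	log "github.com/sirupsen/logrus"

	"github.com/vadiminshakov/committer/core/dto"
)

// sizeLimitHook vetoes proposals whose value exceeds maxBytes.
type sizeLimitHook struct {
	maxBytes int
}

// OnPropose returns false for oversized values, which vetoes the propose
// before anything is precommitted.
func (h *sizeLimitHook) OnPropose(req *dto.ProposeRequest) bool {
	if len(req.Value) > h.maxBytes {
		log.Warnf("rejecting propose at height %d: value is %d bytes (limit %d)",
			req.Height, len(req.Value), h.maxBytes)
		return false
	}
	return true
}

// OnCommit performs no extra checks at commit time.
func (h *sizeLimitHook) OnCommit(req *dto.CommitRequest) bool {
	return true
}

// usage, mirroring TestNewCommitter_DynamicHookRegistration:
//
//	committer.RegisterHook(&sizeLimitHook{maxBytes: 1024})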