├── CODEOWNERS ├── .github ├── FUNDING.yml ├── dependabot.yml └── workflows │ ├── go.yml │ ├── go_cross.yml │ └── go_fuzz.yml ├── generate_test.go ├── .gitignore ├── loadbalancer_test.go ├── go.mod ├── resolver.go ├── resolver_test.go ├── query.go ├── helper.go ├── LICENSE ├── fuzz_test.go ├── Makefile ├── options_test.go ├── helper_test.go ├── loadbalancer.go ├── go.sum ├── options.go ├── conn.go ├── example_wrap_dbs_mutli_primary_multi_replicas_test.go ├── misc └── makefile │ └── tools.Makefile ├── tx.go ├── .golangci.yaml ├── README.md ├── stmt.go ├── db_test.go └── db.go /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @bxcodec @Nasfame 2 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [bxcodec] 2 | -------------------------------------------------------------------------------- /generate_test.go: -------------------------------------------------------------------------------- 1 | //go:build generate 2 | 3 | //go:generate go install github.com/mfridman/tparse@v0.15.0 4 | //go:generate go install gotest.tools/gotestsum@latest 5 | //go:generate tparse -v 6 | //go:generate gotestsum --version 7 | 8 | /* 9 | Installs test deps 10 | */ 11 | 12 | package dbresolver_test 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Macos 2 | .DS_Store 3 | 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | vendor/ 19 | bin/ 20 | coverage.txt 21 | oryxBuildBinary 22 | .idea/ 23 | 
.dist/ 24 | 25 | cover.txt 26 | 27 | 28 | # Created by go-fuzz 29 | testdata/* 30 | -------------------------------------------------------------------------------- /loadbalancer_test.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import ( 4 | "database/sql" 5 | "testing" 6 | "testing/quick" 7 | ) 8 | 9 | func TestReplicaRoundRobin(t *testing.T) { 10 | db := &RoundRobinLoadBalancer[*sql.DB]{} 11 | last := -1 12 | 13 | err := quick.Check(func(n int) bool { 14 | index := db.predict(n) 15 | if n <= 1 { 16 | return index == 0 17 | } 18 | 19 | result := index > 0 && index < n && index != last 20 | last = index 21 | 22 | return result 23 | }, nil) 24 | 25 | if err != nil { 26 | t.Error(err) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "gomod" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/bxcodec/dbresolver/v2 2 | 3 | go 1.22 4 | 5 | require ( 6 | github.com/DATA-DOG/go-sqlmock v1.5.2 7 | github.com/google/gofuzz v1.2.0 8 | github.com/lib/pq v1.10.9 9 | go.uber.org/multierr v1.11.0 10 | ) 11 | 12 | require github.com/stretchr/testify v1.8.1 // indirect 13 | 14 | retract ( 15 | // below versions doesn't support Update,Insert queries with "RETURNING CLAUSE" 16 | // v1.0.0 17 | // v1.0.0-beta 18 | // v1.0.1 19 | // v1.0.2 20 | // v1.1.0 21 | v2.0.0 22 | v2.0.0-beta.2 23 | v2.0.0-beta 24 | v2.0.0-alpha.5 25 | v2.0.0-alpha.4 26 | v2.0.0-alpha.3 27 | v2.0.0-alpha.2 28 | v2.0.0-alpha 29 | ) 30 | -------------------------------------------------------------------------------- /resolver.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | // New will resolve all the passed connection with configurable parameters 4 | func New(opts ...OptionFunc) DB { 5 | opt := defaultOption() 6 | for _, optFunc := range opts { 7 | optFunc(opt) 8 | } 9 | 10 | if len(opt.PrimaryDBs) == 0 { 11 | panic("required primary db connection, set the primary db " + 12 | "connection with dbresolver.New(dbresolver.WithPrimaryDBs(primaryDB))") 13 | } 14 | return &sqlDB{ 15 | primaries: opt.PrimaryDBs, 16 | replicas: opt.ReplicaDBs, 17 | loadBalancer: opt.DBLB, 18 | stmtLoadBalancer: opt.StmtLB, 19 | queryTypeChecker: opt.QueryTypeChecker, 20 | } 21 
| } 22 | -------------------------------------------------------------------------------- /resolver_test.go: -------------------------------------------------------------------------------- 1 | package dbresolver_test 2 | 3 | import ( 4 | "database/sql" 5 | "testing" 6 | 7 | "github.com/bxcodec/dbresolver/v2" 8 | ) 9 | 10 | func TestWrapDBWithMultiDBs(t *testing.T) { 11 | db1 := &sql.DB{} 12 | db2 := &sql.DB{} 13 | db3 := &sql.DB{} 14 | 15 | db := dbresolver.New(dbresolver.WithPrimaryDBs(db1), dbresolver.WithReplicaDBs(db2, db3)) 16 | 17 | if db == nil { 18 | t.Errorf("expected %v, got %v", "not nil", db) 19 | } 20 | } 21 | 22 | func TestWrapDBWithOneDB(t *testing.T) { 23 | db1 := &sql.DB{} 24 | 25 | db := dbresolver.New(dbresolver.WithPrimaryDBs(db1)) 26 | 27 | if db == nil { 28 | t.Errorf("expected %v, got %v", "not nil", db) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go Test 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | workflow_dispatch: 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref}} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | build: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v3 20 | 21 | - name: Set up Go 22 | uses: actions/setup-go@v4 23 | with: 24 | cache-dependency-path: go.sum 25 | go-version-file: go.mod 26 | check-latest: true 27 | 28 | - name: Linter 29 | run: make lint-prepare && make lint 30 | 31 | - name: Test 32 | run: make test 33 | -------------------------------------------------------------------------------- /query.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import "strings" 4 | 5 | type QueryType int 6 | 7 | const ( 8 | QueryTypeUnknown QueryType = iota 9 | QueryTypeRead 10 | 
QueryTypeWrite 11 | ) 12 | 13 | // QueryTypeChecker is used to try to detect the query type, like for detecting RETURNING clauses in 14 | // INSERT/UPDATE clauses. 15 | type QueryTypeChecker interface { 16 | Check(query string) QueryType 17 | } 18 | 19 | // DefaultQueryTypeChecker searches for a "RETURNING" string inside the query to detect a write query. 20 | type DefaultQueryTypeChecker struct { 21 | } 22 | 23 | func (c DefaultQueryTypeChecker) Check(query string) QueryType { 24 | if strings.Contains(strings.ToUpper(query), "RETURNING") { 25 | return QueryTypeWrite 26 | } 27 | return QueryTypeUnknown 28 | } 29 | -------------------------------------------------------------------------------- /helper.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | 7 | "go.uber.org/multierr" 8 | ) 9 | 10 | func doParallely(n int, fn func(i int) error) error { 11 | errors := make(chan error, n) 12 | wg := &sync.WaitGroup{} 13 | wg.Add(n) 14 | for i := 0; i < n; i++ { 15 | go func(i int) { 16 | errors <- fn(i) 17 | wg.Done() 18 | }(i) 19 | } 20 | 21 | go func(wg *sync.WaitGroup) { 22 | wg.Wait() 23 | close(errors) 24 | }(wg) 25 | 26 | var arrErrs []error 27 | for err := range errors { 28 | if err != nil { 29 | arrErrs = append(arrErrs, err) 30 | } 31 | } 32 | 33 | return multierr.Combine(arrErrs...) 
34 | } 35 | 36 | func isDBConnectionError(err error) bool { 37 | if _, ok := err.(net.Error); ok { 38 | return ok 39 | } 40 | 41 | if _, ok := err.(*net.OpError); ok { 42 | return ok 43 | } 44 | return false 45 | } 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Iman Tumorang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /fuzz_test.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | fuzz "github.com/google/gofuzz" 8 | ) 9 | 10 | func FuzzMultiWrite(f *testing.F) { 11 | func() { // generate corpus 12 | 13 | var rdbCount, wdbCount, lbPolicyID uint8 = 1, 1, 1 14 | 15 | if !testing.Short() { 16 | fuzzer := fuzz.New() 17 | fuzzer.Fuzz(&rdbCount) 18 | fuzzer.Fuzz(&wdbCount) 19 | fuzzer.Fuzz(&lbPolicyID) 20 | } 21 | 22 | f.Add(wdbCount, rdbCount, lbPolicyID) 23 | }() 24 | 25 | f.Fuzz(func(t *testing.T, wdbCount, rdbCount, lbPolicyID uint8) { //next-release: uint8 -> uint 26 | 27 | policyID := lbPolicyID % uint8(len(LoadBalancerPolicies)) 28 | 29 | config := DBConfig{ 30 | wdbCount, rdbCount, LoadBalancerPolicies[policyID], 31 | } 32 | 33 | if config.primaryDBCount == 0 { 34 | t.Skipf("skipping due to mising primary db") 35 | } 36 | 37 | t.Log("dbConf", config) 38 | 39 | t.Run(fmt.Sprintf("%v", config), func(t *testing.T) { 40 | 41 | dbConf := config 42 | 43 | testMW(t, dbConf) 44 | }) 45 | 46 | }) 47 | } 48 | 49 | /*func FuzzTest(f *testing.F) { 50 | 51 | f.Add(1) 52 | 53 | f.Fuzz(func(t *testing.T, dbCount int) { 54 | t.Fatal(dbCount) 55 | }) 56 | 57 | } 58 | */ 59 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Exporting bin folder to the path for makefile 2 | export PATH := $(PWD)/bin:$(PATH) 3 | # Default Shell 4 | export SHELL := bash 5 | # Type of OS: Linux or Darwin. 6 | export OSTYPE := $(shell uname -s | tr A-Z a-z) 7 | export ARCH := $(shell uname -m) 8 | 9 | ifeq ($(OSTYPE),Darwin) 10 | export MallocNanoZone=0 11 | endif 12 | 13 | include ./misc/makefile/tools.Makefile 14 | 15 | build: test 16 | @go build ./... 
17 | 18 | install-deps: gotestsum tparse ## Install Development Dependencies (localy). 19 | deps: $(GOTESTSUM) $(TPARSE) ## Checks for Global Development Dependencies. 20 | deps: 21 | @echo "Required Tools Are Available" 22 | 23 | TESTS_ARGS := --format testname --jsonfile gotestsum.json.out 24 | TESTS_ARGS += --max-fails 2 25 | TESTS_ARGS += -- ./... 26 | TESTS_ARGS += -parallel 2 27 | TESTS_ARGS += -count 1 28 | TESTS_ARGS += -failfast 29 | TESTS_ARGS += -coverprofile coverage.out 30 | TESTS_ARGS += -timeout 60s 31 | TESTS_ARGS += -race 32 | run-tests: $(GOTESTSUM) 33 | @ gotestsum $(TESTS_ARGS) -short 34 | 35 | test: run-tests $(TPARSE) ## Run Tests & parse details 36 | @cat gotestsum.json.out | $(TPARSE) -all -notests 37 | 38 | 39 | lint: $(GOLANGCI) ## Runs golangci-lint with predefined configuration 40 | @echo "Applying linter" 41 | golangci-lint version 42 | golangci-lint run -c .golangci.yaml ./... 43 | 44 | .PHONY: lint lint-prepare clean build unittest -------------------------------------------------------------------------------- /options_test.go: -------------------------------------------------------------------------------- 1 | package dbresolver_test 2 | 3 | import ( 4 | "database/sql" 5 | "testing" 6 | 7 | "github.com/bxcodec/dbresolver/v2" 8 | ) 9 | 10 | func TestOptionWithPrimaryDBs(t *testing.T) { 11 | dbPrimary := &sql.DB{} 12 | optFunc := dbresolver.WithPrimaryDBs(dbPrimary) 13 | opt := &dbresolver.Option{} 14 | optFunc(opt) 15 | 16 | if len(opt.PrimaryDBs) != 1 { 17 | t.Errorf("want %v, got %v", 1, len(opt.PrimaryDBs)) 18 | } 19 | } 20 | 21 | func TestOptionWithReplicaDBs(t *testing.T) { 22 | dbReplica := &sql.DB{} 23 | optFunc := dbresolver.WithReplicaDBs(dbReplica) 24 | opt := &dbresolver.Option{} 25 | optFunc(opt) 26 | 27 | if len(opt.ReplicaDBs) != 1 { 28 | t.Errorf("want %v, got %v", 1, len(opt.PrimaryDBs)) 29 | } 30 | } 31 | 32 | func TestOptionWithLoadBalancer(t *testing.T) { 33 | optFunc := 
dbresolver.WithLoadBalancer(dbresolver.RoundRobinLB) 34 | opt := &dbresolver.Option{} 35 | optFunc(opt) 36 | 37 | if opt.DBLB.Name() != dbresolver.RoundRobinLB { 38 | t.Errorf("want %v, got %v", dbresolver.RoundRobinLB, opt.DBLB.Name()) 39 | } 40 | } 41 | 42 | func TestOptionWithLoadBalancerNonExist(t *testing.T) { 43 | defer func() { 44 | if r := recover(); r == nil { 45 | t.Errorf("Should throw panic, but it does not") 46 | } 47 | }() 48 | 49 | optFunc := dbresolver.WithLoadBalancer(dbresolver.LoadBalancerPolicy("NON_EXIST")) 50 | opt := &dbresolver.Option{} 51 | optFunc(opt) 52 | } 53 | -------------------------------------------------------------------------------- /helper_test.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net" 7 | "runtime" 8 | "testing" 9 | ) 10 | 11 | func TestParallelFunction(t *testing.T) { 12 | runtime.GOMAXPROCS(runtime.NumCPU()) 13 | seq := []int{1, 2, 3, 4, 5, 6, 7, 8} 14 | err := doParallely(len(seq), func(i int) error { 15 | if seq[i]%2 == 1 { 16 | seq[i] *= seq[i] 17 | return nil 18 | } 19 | return fmt.Errorf("%d is an even number", seq[i]) 20 | }) 21 | 22 | if err == nil { 23 | t.Fatal("Expected error, got nil") 24 | } 25 | 26 | // this is the expected end result 27 | want := []int{1, 2, 9, 4, 25, 6, 49, 8} 28 | for i, wanted := range want { 29 | if wanted != seq[i] { 30 | t.Errorf("Wrong value at position %d. 
Want: %d, Got: %d", i, wanted, seq[i]) 31 | } 32 | } 33 | } 34 | 35 | func TestIsDBConnectionError(t *testing.T) { 36 | // test connection timeout error 37 | timeoutError := &net.OpError{Op: "dial", Net: "tcp", Err: &net.DNSError{IsTimeout: true}} 38 | if !isDBConnectionError(timeoutError) { 39 | t.Error("Expected true for timeout error") 40 | } 41 | 42 | // test general network error 43 | networkError := &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("network error")} 44 | if !isDBConnectionError(networkError) { 45 | t.Error("Expected true for network error") 46 | } 47 | 48 | // test non-network error 49 | otherError := errors.New("other error") 50 | if isDBConnectionError(otherError) { 51 | t.Error("Expected false for non-network error") 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /.github/workflows/go_cross.yml: -------------------------------------------------------------------------------- 1 | name: Cross Compatibility Test 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | workflow_dispatch: 9 | 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.ref}} 12 | cancel-in-progress: true 13 | 14 | jobs: 15 | build: 16 | strategy: 17 | matrix: 18 | go-version: ["1.21.x", "1.22.x"] 19 | arch: [x64, arm, arm64] 20 | os: [macos-latest, ubuntu-latest] #windows-latest 21 | 22 | include: 23 | - os: ubuntu-latest 24 | gocache: /tmp/go/gocache 25 | # - os: windows-latest 26 | # gocache: C:/gocache 27 | - os: macos-latest 28 | gocache: /tmp/go/gocache 29 | 30 | fail-fast: true 31 | max-parallel: 5 32 | 33 | runs-on: ${{ matrix.os }} 34 | 35 | timeout-minutes: 10 36 | 37 | env: 38 | GOCACHE: ${{matrix.gocache}} 39 | 40 | steps: 41 | - uses: actions/checkout@v4 42 | with: 43 | fetch-depth: 0 44 | 45 | - name: Fetch latest changes 46 | run: git fetch --all 47 | - name: Set up Go 48 | uses: actions/setup-go@v4 49 | with: 50 | go-version: ${{ matrix.go-version }} 51 | 
check-latest: true 52 | 53 | # - name: Clear Go cache 54 | # run: go clean -cache 55 | - name: Cache Go tests 56 | uses: actions/cache@v3 57 | with: 58 | path: | 59 | ${{env.GOCACHE}} 60 | key: ${{ github.workflow }}-${{ runner.os }}-${{ matrix.arch }}-go-${{matrix.go-version}}-${{ hashFiles('**/go.mod','*_test.go') }} 61 | restore-keys: | 62 | ${{ github.workflow }}-${{ runner.os }}-${{ matrix.arch }}-go-${{matrix.go-version}}-${{ hashFiles('**/go.mod','*_test.go') }} 63 | 64 | - name: Linter 65 | continue-on-error: true 66 | run: make lint-prepare && make lint 67 | 68 | - name: Test 69 | run: make test 70 | -------------------------------------------------------------------------------- /loadbalancer.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import ( 4 | "database/sql" 5 | "math/rand" 6 | "sync/atomic" 7 | "time" 8 | ) 9 | 10 | // DBConnection is the generic type for DB and Stmt operation 11 | type DBConnection interface { 12 | *sql.DB | *sql.Stmt 13 | } 14 | 15 | // LoadBalancer define the load balancer contract 16 | type LoadBalancer[T DBConnection] interface { 17 | Resolve([]T) T 18 | Name() LoadBalancerPolicy 19 | predict(n int) int 20 | } 21 | 22 | // RandomLoadBalancer represent for Random LB policy 23 | type RandomLoadBalancer[T DBConnection] struct { 24 | randInt chan int 25 | } 26 | 27 | // RandomLoadBalancer return the LB policy name 28 | func (lb RandomLoadBalancer[T]) Name() LoadBalancerPolicy { 29 | return RandomLB 30 | } 31 | 32 | // Resolve return the resolved option for Random LB. 33 | // Marked with go:nosplit to prevent preemption. 
34 | // 35 | //go:nosplit 36 | func (lb RandomLoadBalancer[T]) Resolve(dbs []T) T { 37 | if len(lb.randInt) == 0 { 38 | lb.predict(len(dbs)) 39 | } 40 | randomInt := <-lb.randInt 41 | return dbs[randomInt] 42 | } 43 | 44 | func (lb RandomLoadBalancer[T]) predict(n int) int { 45 | r := rand.New(rand.NewSource(time.Now().UnixNano())) 46 | max := n - 1 //nolint 47 | min := 0 //nolint 48 | idx := r.Intn(max-min+1) + min 49 | lb.randInt <- idx 50 | return idx 51 | } 52 | 53 | // RoundRobinLoadBalancer represent for RoundRobin LB policy 54 | type RoundRobinLoadBalancer[T DBConnection] struct { 55 | counter uint64 // Monotonically incrementing counter on every call 56 | } 57 | 58 | // Name return the LB policy name 59 | func (lb RoundRobinLoadBalancer[T]) Name() LoadBalancerPolicy { 60 | return RoundRobinLB 61 | } 62 | 63 | // Resolve return the resolved option for RoundRobin LB 64 | func (lb *RoundRobinLoadBalancer[T]) Resolve(dbs []T) T { 65 | idx := lb.predict(len(dbs)) 66 | return dbs[idx] 67 | } 68 | 69 | func (lb *RoundRobinLoadBalancer[T]) predict(n int) int { 70 | if n <= 1 { 71 | return 0 72 | } 73 | // counter := lb.counter 74 | return int(atomic.AddUint64(&lb.counter, 1) % uint64(n)) 75 | } 76 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= 2 | github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= 7 | github.com/google/gofuzz v1.2.0/go.mod 
h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= 8 | github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= 9 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 10 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 11 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 12 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 13 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 14 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 15 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 16 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 17 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 18 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 19 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 20 | go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= 21 | go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= 22 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 23 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 24 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 25 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 26 | -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | ) 7 | 8 | // 
LoadBalancerPolicy define the loadbalancer policy data type 9 | type LoadBalancerPolicy string 10 | 11 | // Supported Loadbalancer policy 12 | const ( 13 | RoundRobinLB LoadBalancerPolicy = "ROUND_ROBIN" 14 | RandomLB LoadBalancerPolicy = "RANDOM" 15 | ) 16 | 17 | // Option define the option property 18 | type Option struct { 19 | PrimaryDBs []*sql.DB 20 | ReplicaDBs []*sql.DB 21 | StmtLB StmtLoadBalancer 22 | DBLB DBLoadBalancer 23 | QueryTypeChecker QueryTypeChecker 24 | } 25 | 26 | // OptionFunc used for option chaining 27 | type OptionFunc func(opt *Option) 28 | 29 | // WithPrimaryDBs add primaryDBs to the resolver 30 | func WithPrimaryDBs(primaryDBs ...*sql.DB) OptionFunc { 31 | return func(opt *Option) { 32 | opt.PrimaryDBs = primaryDBs 33 | } 34 | } 35 | 36 | // WithReplicaDBs add replica DBs to the resolver 37 | func WithReplicaDBs(replicaDBs ...*sql.DB) OptionFunc { 38 | return func(opt *Option) { 39 | opt.ReplicaDBs = replicaDBs 40 | } 41 | } 42 | 43 | // WithQueryTypeChecker sets the query type checker instance. 44 | // The default one just checks for the presence of the string "RETURNING" in the uppercase query. 
45 | func WithQueryTypeChecker(checker QueryTypeChecker) OptionFunc { 46 | return func(opt *Option) { 47 | opt.QueryTypeChecker = checker 48 | } 49 | } 50 | 51 | // WithLoadBalancer configure the loadbalancer for the resolver 52 | func WithLoadBalancer(lb LoadBalancerPolicy) OptionFunc { 53 | return func(opt *Option) { 54 | switch lb { 55 | case RoundRobinLB: 56 | opt.DBLB = &RoundRobinLoadBalancer[*sql.DB]{} 57 | opt.StmtLB = &RoundRobinLoadBalancer[*sql.Stmt]{} 58 | case RandomLB: 59 | opt.DBLB = &RandomLoadBalancer[*sql.DB]{ 60 | randInt: make(chan int, 1), 61 | } 62 | opt.StmtLB = &RandomLoadBalancer[*sql.Stmt]{ 63 | randInt: make(chan int, 1), 64 | } 65 | default: 66 | panic(fmt.Sprintf("LoadBalancer: %s is not supported", lb)) 67 | } 68 | } 69 | } 70 | 71 | func defaultOption() *Option { 72 | return &Option{ 73 | DBLB: &RoundRobinLoadBalancer[*sql.DB]{}, 74 | StmtLB: &RoundRobinLoadBalancer[*sql.Stmt]{}, 75 | QueryTypeChecker: &DefaultQueryTypeChecker{}, 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "strings" 7 | ) 8 | 9 | // Conn is a *sql.Conn wrapper. 10 | // Its main purpose is to be able to return the internal Tx and Stmt interfaces. 
11 | type Conn interface { 12 | Close() error 13 | BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) 14 | ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 15 | PingContext(ctx context.Context) error 16 | PrepareContext(ctx context.Context, query string) (Stmt, error) 17 | QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) 18 | QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row 19 | Raw(f func(driverConn interface{}) error) (err error) 20 | } 21 | 22 | type conn struct { 23 | sourceDB *sql.DB 24 | conn *sql.Conn 25 | } 26 | 27 | func (c *conn) Close() error { 28 | return c.conn.Close() 29 | } 30 | 31 | func (c *conn) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { 32 | stx, err := c.conn.BeginTx(ctx, opts) 33 | if err != nil { 34 | return nil, err 35 | } 36 | 37 | return &tx{ 38 | sourceDB: c.sourceDB, 39 | tx: stx, 40 | }, nil 41 | } 42 | 43 | func (c *conn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { 44 | return c.conn.ExecContext(ctx, query, args...) 45 | } 46 | 47 | func (c *conn) PingContext(ctx context.Context) error { 48 | return c.conn.PingContext(ctx) 49 | } 50 | 51 | func (c *conn) PrepareContext(ctx context.Context, query string) (Stmt, error) { 52 | pstmt, err := c.conn.PrepareContext(ctx, query) 53 | if err != nil { 54 | return nil, err 55 | } 56 | 57 | _query := strings.ToUpper(query) 58 | writeFlag := strings.Contains(_query, "RETURNING") 59 | 60 | return newSingleDBStmt(c.sourceDB, pstmt, writeFlag), nil 61 | } 62 | 63 | func (c *conn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { 64 | return c.conn.QueryContext(ctx, query, args...) 65 | } 66 | 67 | func (c *conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { 68 | return c.conn.QueryRowContext(ctx, query, args...) 
69 | } 70 | 71 | func (c *conn) Raw(f func(driverConn interface{}) error) (err error) { 72 | return c.conn.Raw(f) 73 | } 74 | -------------------------------------------------------------------------------- /example_wrap_dbs_mutli_primary_multi_replicas_test.go: -------------------------------------------------------------------------------- 1 | package dbresolver_test 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "log" 8 | 9 | "github.com/bxcodec/dbresolver/v2" 10 | _ "github.com/lib/pq" 11 | ) 12 | 13 | func ExampleNew_multiPrimaryMultiReplicas() { 14 | var ( 15 | host1 = "localhost" 16 | port1 = 5432 17 | user1 = "postgresrw" 18 | password1 = "" 19 | host2 = "localhost" 20 | port2 = 5433 21 | user2 = "postgresro" 22 | password2 = "" 23 | dbname = "" 24 | ) 25 | // connection string 26 | rwPrimary := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", host1, port1, user1, password1, dbname) 27 | readOnlyReplica := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", host2, port2, user2, password2, dbname) 28 | 29 | // open database for primary 30 | dbPrimary1, err := sql.Open("postgres", rwPrimary) 31 | if err != nil { 32 | log.Print("go error when connecting to the DB") 33 | } 34 | // open database for primary 35 | dbPrimary2, err := sql.Open("postgres", rwPrimary) 36 | if err != nil { 37 | log.Print("go error when connecting to the DB") 38 | } 39 | 40 | // configure the DBs for other setup eg, tracing, etc 41 | // eg, tracing.Postgres(dbPrimary) 42 | 43 | // open database for replica 44 | dbReadOnlyReplica1, err := sql.Open("postgres", readOnlyReplica) 45 | if err != nil { 46 | log.Print("go error when connecting to the DB") 47 | } 48 | // open database for replica 49 | dbReadOnlyReplica2, err := sql.Open("postgres", readOnlyReplica) 50 | if err != nil { 51 | log.Print("go error when connecting to the DB") 52 | } 53 | // configure the DBs for other setup eg, tracing, etc 54 | // eg, 
tracing.Postgres(dbReadOnlyReplica) 55 | 56 | connectionDB := dbresolver.New( 57 | dbresolver.WithPrimaryDBs(dbPrimary1, dbPrimary2), 58 | dbresolver.WithReplicaDBs(dbReadOnlyReplica1, dbReadOnlyReplica2), 59 | dbresolver.WithLoadBalancer(dbresolver.RoundRobinLB)) 60 | 61 | // now you can use the connection for all DB operation 62 | _, err = connectionDB.ExecContext(context.Background(), "DELETE FROM book WHERE id=$1") // will use primaryDB 63 | if err != nil { 64 | log.Print("go error when executing the query to the DB", err) 65 | } 66 | _ = connectionDB.QueryRowContext(context.Background(), "SELECT * FROM book WHERE id=$1") // will use replicaReadOnlyDB 67 | 68 | // Output: 69 | // 70 | } 71 | -------------------------------------------------------------------------------- /misc/makefile/tools.Makefile: -------------------------------------------------------------------------------- 1 | # This makefile should be used to hold functions/variables 2 | 3 | ifeq ($(ARCH),x86_64) 4 | ARCH := amd64 5 | else ifeq ($(ARCH),aarch64) 6 | ARCH := arm64 7 | endif 8 | 9 | define github_url 10 | https://github.com/$(GITHUB)/releases/download/v$(VERSION)/$(ARCHIVE) 11 | endef 12 | 13 | # creates a directory bin. 14 | bin: 15 | @ mkdir -p $@ 16 | 17 | # ~~~ Tools ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 18 | 19 | # ~~ [ gotestsum ] ~~~ https://github.com/gotestyourself/gotestsum ~~~~~~~~~~~~~~~~~~~~~~~ 20 | 21 | GOTESTSUM := $(shell command -v gotestsum || echo "bin/gotestsum") 22 | gotestsum: bin/gotestsum ## Installs gotestsum (testing go code) 23 | 24 | bin/gotestsum: VERSION := 1.12.0 25 | bin/gotestsum: GITHUB := gotestyourself/gotestsum 26 | bin/gotestsum: ARCHIVE := gotestsum_$(VERSION)_$(OSTYPE)_$(ARCH).tar.gz 27 | bin/gotestsum: bin 28 | @ printf "Install gotestsum... " 29 | @ printf "$(github_url)\n" 30 | @ curl -Ls $(shell echo $(call github_url) | tr A-Z a-z) | tar -zOxf - gotestsum > $@ && chmod +x $@ 31 | @ echo "done." 
32 | 33 | # ~~ [ tparse ] ~~~ https://github.com/mfridman/tparse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 34 | 35 | TPARSE := $(shell command -v tparse || echo "bin/tparse") 36 | tparse: bin/tparse ## Installs tparse (testing go code) 37 | 38 | # eg https://github.com/mfridman/tparse/releases/download/v0.13.2/tparse_darwin_arm64 39 | export TPARSE_ARCH := $(shell uname -m) 40 | bin/tparse: VERSION := 0.13.3 41 | bin/tparse: GITHUB := mfridman/tparse 42 | bin/tparse: ARCHIVE := tparse_$(OSTYPE)_$(TPARSE_ARCH) #this is custom 43 | bin/tparse: bin 44 | @ printf "Install tparse... " 45 | @ printf "$(github_url)\n" 46 | @ curl -Ls $(call github_url) > $@ && chmod +x $@ 47 | @ echo "done." 48 | 49 | # ~~ [ golangci-lint ] ~~~ https://github.com/golangci/golangci-lint ~~~~~~~~~~~~~~~~~~~~~ 50 | 51 | GOLANGCI := $(shell command -v golangci-lint || echo "bin/golangci-lint") 52 | golangci-lint: bin/golangci-lint ## Installs golangci-lint (linter) 53 | 54 | bin/golangci-lint: VERSION := 1.59.0 55 | bin/golangci-lint: GITHUB := golangci/golangci-lint 56 | bin/golangci-lint: ARCHIVE := golangci-lint-$(VERSION)-$(OSTYPE)-$(ARCH).tar.gz 57 | bin/golangci-lint: bin 58 | @ printf "Install golangci-linter... " 59 | @ printf "$(github_url)\n" 60 | @ curl -Ls $(shell echo $(call github_url) | tr A-Z a-z) | tar -zOxf - $(shell printf golangci-lint-$(VERSION)-$(OSTYPE)-$(ARCH)/golangci-lint | tr A-Z a-z ) > $@ && chmod +x $@ 61 | @ echo "done." -------------------------------------------------------------------------------- /tx.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | ) 7 | 8 | // Tx is a *sql.Tx wrapper. 9 | // Its main purpose is to be able to return the internal Stmt interface. 
10 | type Tx interface { 11 | Commit() error 12 | Rollback() error 13 | Exec(query string, args ...interface{}) (sql.Result, error) 14 | ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 15 | Prepare(query string) (Stmt, error) 16 | PrepareContext(ctx context.Context, query string) (Stmt, error) 17 | Query(query string, args ...interface{}) (*sql.Rows, error) 18 | QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) 19 | QueryRow(query string, args ...interface{}) *sql.Row 20 | QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row 21 | Stmt(stmt Stmt) Stmt 22 | StmtContext(ctx context.Context, stmt Stmt) Stmt 23 | } 24 | 25 | type tx struct { 26 | sourceDB *sql.DB 27 | tx *sql.Tx 28 | } 29 | 30 | func (t *tx) Commit() error { 31 | return t.tx.Commit() 32 | } 33 | 34 | func (t *tx) Rollback() error { 35 | return t.tx.Rollback() 36 | } 37 | 38 | func (t *tx) Exec(query string, args ...interface{}) (sql.Result, error) { 39 | return t.ExecContext(context.Background(), query, args...) 40 | } 41 | 42 | func (t *tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { 43 | return t.tx.ExecContext(ctx, query, args...) 44 | } 45 | 46 | func (t *tx) Prepare(query string) (Stmt, error) { 47 | return t.PrepareContext(context.Background(), query) 48 | } 49 | 50 | func (t *tx) PrepareContext(ctx context.Context, query string) (Stmt, error) { 51 | txstmt, err := t.tx.PrepareContext(ctx, query) 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | return newSingleDBStmt(t.sourceDB, txstmt, true), nil 57 | } 58 | 59 | func (t *tx) Query(query string, args ...interface{}) (*sql.Rows, error) { 60 | return t.QueryContext(context.Background(), query, args...) 61 | } 62 | 63 | func (t *tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { 64 | return t.tx.QueryContext(ctx, query, args...) 
65 | } 66 | 67 | func (t *tx) QueryRow(query string, args ...interface{}) *sql.Row { 68 | return t.QueryRowContext(context.Background(), query, args...) 69 | } 70 | 71 | func (t *tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { 72 | return t.tx.QueryRowContext(ctx, query, args...) 73 | } 74 | 75 | func (t *tx) Stmt(s Stmt) Stmt { 76 | return t.StmtContext(context.Background(), s) 77 | } 78 | 79 | func (t *tx) StmtContext(ctx context.Context, s Stmt) Stmt { 80 | if rstmt, ok := s.(*stmt); ok { 81 | return newSingleDBStmt(t.sourceDB, t.tx.StmtContext(ctx, rstmt.stmtForDB(t.sourceDB)), true) 82 | } 83 | return s 84 | } 85 | -------------------------------------------------------------------------------- /.golangci.yaml: -------------------------------------------------------------------------------- 1 | linters-settings: 2 | dupl: 3 | threshold: 100 4 | funlen: 5 | lines: 100 6 | statements: 50 7 | goconst: 8 | min-len: 2 9 | min-occurrences: 3 10 | gocritic: 11 | enabled-tags: 12 | - diagnostic 13 | - experimental 14 | - opinionated 15 | - performance 16 | - style 17 | disabled-checks: 18 | - dupImport # https://github.com/go-critic/go-critic/issues/845 19 | - ifElseChain 20 | - octalLiteral 21 | - whyNoLint 22 | - wrapperFunc 23 | gocyclo: 24 | min-complexity: 15 25 | goimports: 26 | local-prefixes: github.com/golangci/golangci-lint 27 | mnd: 28 | # don't include the "operation" and "assign" 29 | checks: 30 | - argument 31 | - case 32 | - condition 33 | - return 34 | ignored-numbers: 35 | - "0" 36 | - "1" 37 | - "2" 38 | - "3" 39 | ignored-functions: 40 | - strings.SplitN 41 | 42 | lll: 43 | line-length: 140 44 | misspell: 45 | locale: US 46 | nolintlint: 47 | allow-unused: false # report any unused nolint directives 48 | require-explanation: false # don't require an explanation for nolint directives 49 | require-specific: false # don't require nolint directives to be specific about which linter is being skipped 50 | 51 | 
linters: 52 | # Disable all linters. 53 | # Default: false 54 | disable-all: true 55 | # Enable specific linter 56 | # https://golangci-lint.run/usage/linters/#enabled-by-default-linters 57 | enable: 58 | - asciicheck 59 | - bodyclose 60 | # - deadcode #deprecated 61 | - depguard 62 | - dogsled 63 | - dupl 64 | - errcheck 65 | - exportloopref 66 | - funlen 67 | - gochecknoinits 68 | - goconst 69 | - gocritic 70 | - gocyclo 71 | - gofmt 72 | - goimports 73 | - mnd 74 | - goprintffuncname 75 | - gosec 76 | - gosimple 77 | - govet 78 | - ineffassign 79 | - lll 80 | - misspell 81 | - nakedret 82 | - noctx 83 | - nolintlint 84 | - staticcheck 85 | # - structcheck #deprecated 86 | - stylecheck 87 | - typecheck 88 | - unconvert 89 | - unparam 90 | - unused 91 | - revive 92 | # - varcheck #deprecated 93 | - whitespace 94 | 95 | # don't enable: 96 | # - asciicheck 97 | # - scopelint 98 | # - gochecknoglobals 99 | # - gocognit 100 | # - godot 101 | # - godox 102 | # - goerr113 103 | # - interfacer 104 | # - maligned 105 | # - nestif 106 | # - prealloc 107 | # - testpackage 108 | # - revive 109 | # - wsl 110 | 111 | issues: 112 | # Excluding configuration per-path, per-linter, per-text and per-source 113 | exclude-rules: 114 | - path: _test\.go 115 | linters: 116 | - gomnd 117 | - revive 118 | - depguard 119 | - path: db_test.go 120 | text: "deferInLoop: Possible resource leak, 'defer' is called in the 'for' loop" 121 | 122 | - path: \.go 123 | text: "commentedOutCode: may want to remove commented-out code" 124 | 125 | - path: \.go 126 | text: "commentFormatting: put a space between `//` and comment text" 127 | 128 | - path: _test.go 129 | linters: 130 | - whitespace 131 | text: "unnecessary (leading|trailing) newline" 132 | 133 | - path: db_test.go 134 | linters: 135 | - goconst 136 | - funlen 137 | - gocyclo 138 | - errcheck 139 | - gocritic 140 | - govet 141 | 142 | - path: fuzz_test.go 143 | linters: 144 | - goconst 145 | - funlen 146 | - gocyclo 147 | - errcheck 148 | - 
gocritic 149 | - govet 150 | 151 | - path: \.go 152 | text: "G404: Use of weak random number generator" #expected, just for randomLB policy 153 | 154 | run: 155 | timeout: 5m 156 | go: "1.22" 157 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dbresolver 2 | 3 | Golang Database Resolver and Wrapper for any multi-database connection topology, e.g. master-slave replication databases or cross-region applications. 4 | 5 | [![Go](https://github.com/bxcodec/dbresolver/actions/workflows/go.yml/badge.svg?branch=main)](https://github.com/bxcodec/dbresolver/actions/workflows/go.yml) 6 | [![Go.Dev](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white)](https://pkg.go.dev/github.com/bxcodec/dbresolver/v2?tab=doc) 7 | 8 | ## Idea and Inspiration 9 | 10 | This DBResolver library routes your connections to the correct configured DBs. E.g., all read queries will be routed to a read-only replica DB, and all write operations (Insert, Update, Delete) will be routed to the primary/master DB. 11 | 12 | **Read More** 13 | |Items| Link| 14 | ------|-----| 15 | |Blogpost| [blog post](https://betterprogramming.pub/create-a-cross-region-rdbms-connection-library-with-dbresolver-5072bed6a7b8) | 16 | |Excalidraw| [diagram](https://excalidraw.com/#json=DTs8yxHOGF6uLkjnZny4z,RVo8iwhO0Rk6DRGkKuNZTg)| 17 | |GoSG Meetup Demo| [repository](https://github.com/bxcodec/dbresolver-examples) | 18 | | GoSG Presentation | [deck](https://www.canva.com/design/DAFgbpc7tfw/bEXVFtcHEnlFxKVBdnUggA/edit?utm_content=DAFgbpc7tfw&utm_campaign=designshare&utm_medium=link2&utm_source=sharebutton) | 19 | | Instagram | [post](https://www.instagram.com/p/CnlDFPsBAJG/?utm_source=ig_web_copy_link&igsh=MzRlODBiNWFlZA==)| 20 | 21 | ### Usecase 1: Separated RW and RO Database connection 22 | 23 |
24 | 25 | Click to Expand 26 | 27 | - You have your application deployed 28 | - Your application is heavy on read operations 29 | - Your DBs are replicated to multiple replicas for faster queries 30 | - You separate the connections for optimized queries 31 | - ![readonly-readwrite](https://user-images.githubusercontent.com/11002383/206952018-dd393059-c42c-4ffc-913a-f21c3870bd80.png) 32 | 33 |
34 | 35 | ### Usecase 2: Cross Region Database 36 | 37 |
38 | 39 | Click to Expand 40 | 41 | - Your application is deployed to multiple regions. 42 | - You have your Databases configured globally. 43 | - ![cross-region](https://user-images.githubusercontent.com/11002383/206952598-ed21a6f8-5542-4f26-aaa6-67d9c2aa5940.png) 44 | 45 |
46 | 47 | ### Usecase 3: Multi-Master (Multi-Primary) Database 48 | 49 |
50 | 51 | Click to Expand 52 | 53 | - You're using a Multi-Master database topology eg, Aurora Multi-Master 54 | - ![multi-master](https://user-images.githubusercontent.com/11002383/206953082-c2b1bfa8-050e-4a6e-88e8-e5c7047edd71.png) 55 | 56 |
57 | 58 | ## Support 59 | 60 | You can file an [Issue](https://github.com/bxcodec/dbresolver/issues/new). 61 | See documentation in [Go.Dev](https://pkg.go.dev/github.com/bxcodec/dbresolver/v2?tab=doc) 62 | 63 | ## Getting Started 64 | 65 | #### Download 66 | 67 | ```shell 68 | go get -u github.com/bxcodec/dbresolver/v2 69 | ``` 70 | 71 | # Example 72 | 73 | ### Implementing DB Resolver using \*sql.DB 74 | 75 |
76 | 77 | Click to Expand 78 | 79 | ```go 80 | package main 81 | 82 | import ( 83 | "context" 84 | "database/sql" 85 | "fmt" 86 | "log" 87 | 88 | "github.com/bxcodec/dbresolver/v2" 89 | _ "github.com/lib/pq" 90 | ) 91 | 92 | func main() { 93 | var ( 94 | host1 = "localhost" 95 | port1 = 5432 96 | user1 = "postgresrw" 97 | password1 = "" 98 | host2 = "localhost" 99 | port2 = 5433 100 | user2 = "postgresro" 101 | password2 = "" 102 | dbname = "" 103 | ) 104 | // connection string 105 | rwPrimary := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", host1, port1, user1, password1, dbname) 106 | readOnlyReplica := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", host2, port2, user2, password2, dbname) 107 | 108 | // open database for primary 109 | dbPrimary, err := sql.Open("postgres", rwPrimary) 110 | if err != nil { 111 | log.Print("go error when connecting to the DB") 112 | } 113 | // configure the DBs for other setup eg, tracing, etc 114 | // eg, tracing.Postgres(dbPrimary) 115 | 116 | // open database for replica 117 | dbReadOnlyReplica, err := sql.Open("postgres", readOnlyReplica) 118 | if err != nil { 119 | log.Print("go error when connecting to the DB") 120 | } 121 | // configure the DBs for other setup eg, tracing, etc 122 | // eg, tracing.Postgres(dbReadOnlyReplica) 123 | 124 | connectionDB := dbresolver.New( 125 | dbresolver.WithPrimaryDBs(dbPrimary), 126 | dbresolver.WithReplicaDBs(dbReadOnlyReplica), 127 | dbresolver.WithLoadBalancer(dbresolver.RoundRobinLB)) 128 | 129 | defer connectionDB.Close() 130 | // now you can use the connection for all DB operation 131 | _, err = connectionDB.ExecContext(context.Background(), "DELETE FROM book WHERE id=$1") // will use primaryDB 132 | if err != nil { 133 | log.Print("go error when executing the query to the DB", err) 134 | } 135 | connectionDB.QueryRowContext(context.Background(), "SELECT * FROM book WHERE id=$1") // will use replicaReadOnlyDB 136 | } 137 | ``` 
138 | 139 |
140 | 141 | ## Important Notes 142 | 143 | - Primary Database will be used when you call these functions 144 | - `Exec` 145 | - `ExecContext` 146 | - `Begin` (transaction will use primary) 147 | - `BeginTx` 148 | - Queries with `"RETURNING"` clause 149 | - `Query` 150 | - `QueryContext` 151 | - `QueryRow` 152 | - `QueryRowContext` 153 | - Replica Databases will be used when you call these functions 154 | - `Query` 155 | - `QueryContext` 156 | - `QueryRow` 157 | - `QueryRowContext` 158 | 159 | ## Contribution 160 | 161 | To contrib to this project, you can open a PR or an issue. 162 | -------------------------------------------------------------------------------- /stmt.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | 7 | "go.uber.org/multierr" 8 | ) 9 | 10 | // Stmt is an aggregate prepared statement. 11 | // It holds a prepared statement for each underlying physical db. 12 | type Stmt interface { 13 | Close() error 14 | Exec(...interface{}) (sql.Result, error) 15 | ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) 16 | Query(...interface{}) (*sql.Rows, error) 17 | QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) 18 | QueryRow(args ...interface{}) *sql.Row 19 | QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row 20 | } 21 | 22 | type stmt struct { 23 | loadBalancer StmtLoadBalancer 24 | primaryStmts []*sql.Stmt 25 | replicaStmts []*sql.Stmt 26 | writeFlag bool 27 | dbStmt map[*sql.DB]*sql.Stmt 28 | } 29 | 30 | // Close closes the statement by concurrently closing all underlying 31 | // statements concurrently, returning the first non nil error. 
32 | func (s *stmt) Close() error { 33 | errPrimaries := doParallely(len(s.primaryStmts), func(i int) error { 34 | return s.primaryStmts[i].Close() 35 | }) 36 | errReplicas := doParallely(len(s.replicaStmts), func(i int) error { 37 | return s.replicaStmts[i].Close() 38 | }) 39 | 40 | return multierr.Combine(errPrimaries, errReplicas) 41 | } 42 | 43 | // Exec executes a prepared statement with the given arguments 44 | // and returns a Result summarizing the effect of the statement. 45 | // Exec uses the master as the underlying physical db. 46 | func (s *stmt) Exec(args ...interface{}) (sql.Result, error) { 47 | return s.ExecContext(context.Background(), args...) 48 | } 49 | 50 | // ExecContext executes a prepared statement with the given arguments 51 | // and returns a Result summarizing the effect of the statement. 52 | // Exec uses the master as the underlying physical db. 53 | func (s *stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { 54 | return s.RWStmt().ExecContext(ctx, args...) 55 | } 56 | 57 | // Query executes a prepared query statement with the given 58 | // arguments and returns the query results as a *sql.Rows. 59 | // Query uses the read only DB as the underlying physical db. 60 | func (s *stmt) Query(args ...interface{}) (*sql.Rows, error) { 61 | return s.QueryContext(context.Background(), args...) 62 | } 63 | 64 | // QueryContext executes a prepared query statement with the given 65 | // arguments and returns the query results as a *sql.Rows. 66 | // Query uses the read only DB as the underlying physical db. 67 | func (s *stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) { 68 | var curStmt *sql.Stmt 69 | if s.writeFlag { 70 | curStmt = s.RWStmt() 71 | } else { 72 | curStmt = s.ROStmt() 73 | } 74 | 75 | rows, err := curStmt.QueryContext(ctx, args...) 76 | if isDBConnectionError(err) && !s.writeFlag { 77 | rows, err = s.RWStmt().QueryContext(ctx, args...) 
78 | } 79 | return rows, err 80 | } 81 | 82 | // QueryRow executes a prepared query statement with the given arguments. 83 | // If an error occurs during the execution of the statement, that error 84 | // will be returned by a call to Scan on the returned *Row, which is always non-nil. 85 | // If the query selects no rows, the *Row's Scan will return ErrNoRows. 86 | // Otherwise, the *sql.Row's Scan scans the first selected row and discards the rest. 87 | // QueryRow uses the read only DB as the underlying physical db. 88 | func (s *stmt) QueryRow(args ...interface{}) *sql.Row { 89 | return s.QueryRowContext(context.Background(), args...) 90 | } 91 | 92 | // QueryRowContext executes a prepared query statement with the given arguments. 93 | // If an error occurs during the execution of the statement, that error 94 | // will be returned by a call to Scan on the returned *Row, which is always non-nil. 95 | // If the query selects no rows, the *Row's Scan will return ErrNoRows. 96 | // Otherwise, the *sql.Row's Scan scans the first selected row and discards the rest. 97 | // QueryRowContext uses the read only DB as the underlying physical db. 98 | func (s *stmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row { 99 | var curStmt *sql.Stmt 100 | if s.writeFlag { 101 | curStmt = s.RWStmt() 102 | } else { 103 | curStmt = s.ROStmt() 104 | } 105 | 106 | row := curStmt.QueryRowContext(ctx, args...) 107 | if isDBConnectionError(row.Err()) && !s.writeFlag { 108 | row = s.RWStmt().QueryRowContext(ctx, args...) 
109 | } 110 | return row 111 | } 112 | 113 | // ROStmt return the replica statement 114 | func (s *stmt) ROStmt() *sql.Stmt { 115 | totalStmtsConn := len(s.replicaStmts) + len(s.primaryStmts) 116 | if totalStmtsConn == len(s.primaryStmts) { 117 | return s.loadBalancer.Resolve(s.primaryStmts) 118 | } 119 | return s.loadBalancer.Resolve(s.replicaStmts) 120 | } 121 | 122 | // RWStmt return the primary statement 123 | func (s *stmt) RWStmt() *sql.Stmt { 124 | return s.loadBalancer.Resolve(s.primaryStmts) 125 | } 126 | 127 | // stmtForDB returns the corresponding *sql.Stmt instance for the given *sql.DB. 128 | // Ihis is needed because sql.Tx.Stmt() requires that the passed *sql.Stmt be from the same database 129 | // as the transaction. 130 | func (s *stmt) stmtForDB(db *sql.DB) *sql.Stmt { 131 | xsm, ok := s.dbStmt[db] 132 | if ok { 133 | return xsm 134 | } 135 | 136 | // return any statement so errors can be detected by Tx.Stmt() 137 | return s.RWStmt() 138 | } 139 | 140 | // newSingleDBStmt creates a new stmt for a single DB connection. 141 | // This is used by statements return by transaction and connections. 
142 | func newSingleDBStmt(sourceDB *sql.DB, st *sql.Stmt, writeFlag bool) *stmt { 143 | return &stmt{ 144 | loadBalancer: &RoundRobinLoadBalancer[*sql.Stmt]{}, 145 | primaryStmts: []*sql.Stmt{st}, 146 | dbStmt: map[*sql.DB]*sql.Stmt{ 147 | sourceDB: st, 148 | }, 149 | writeFlag: writeFlag, 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /.github/workflows/go_fuzz.yml: -------------------------------------------------------------------------------- 1 | name: Go Fuzz 2 | 3 | on: 4 | push: 5 | branches: 6 | - "*" 7 | - "**" 8 | pull_request: 9 | branches: 10 | - "main" 11 | workflow_dispatch: 12 | 13 | schedule: 14 | - cron: "0 * */1 * *" 15 | 16 | concurrency: 17 | group: ${{ github.workflow }}-${{ github.ref }} 18 | cancel-in-progress: true 19 | 20 | #TODO: minimize jobs with all having -race and then of course committing testdata/ or o/p testdata into log in case of failure 21 | env: 22 | GOCACHE: /tmp/go/gocache 23 | GOBIN: ${{ github.workspace }}/bin 24 | 25 | jobs: 26 | setup: 27 | runs-on: ubuntu-latest 28 | timeout-minutes: 8 29 | 30 | steps: 31 | - uses: actions/checkout@v3 32 | 33 | - name: Set up Go 34 | uses: actions/setup-go@v4 35 | with: 36 | go-version: 1.20.4 37 | check-latest: true 38 | cache-dependency-path: go.sum 39 | 40 | # - name: Cache Go 41 | # uses: actions/cache@v3 42 | # with: 43 | # path: | 44 | # ${{ env.GOCACHE }} 45 | # ${{ env.GOBIN }} 46 | # key: ${{ github.workflow }}-${{ runner.os }}-${{ hashFiles('*_test.go') }} 47 | # restore-keys: | 48 | # ${{ github.workflow }}-${{ runner.os }}-${{ hashFiles('*_test.go') }} 49 | 50 | - name: Build 51 | timeout-minutes: 2 52 | run: | 53 | go build -v 54 | echo "${{ env.GOBIN }}" >> $GITHUB_PATH 55 | 56 | - name: Go Generate 57 | run: | 58 | go generate 59 | # go generate fuzz.go 60 | - name: Upload Artifacts 61 | uses: actions/upload-artifact@v4 62 | with: 63 | name: go-test-utils 64 | path: ${{ env.GOBIN }} 65 | 66 | - name: Test Fuzz 
Functions 67 | run: | 68 | go test -cover -covermode=atomic -timeout=8m -race -run="Fuzz*" -json -short | \ 69 | tparse -follow -all -sort=elapsed 70 | 71 | fuzz: 72 | needs: [setup] 73 | runs-on: ubuntu-latest 74 | 75 | env: 76 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 77 | 78 | permissions: 79 | contents: write 80 | pull-requests: write 81 | issues: read 82 | packages: none 83 | 84 | steps: 85 | - uses: actions/checkout@v3 86 | 87 | # - name: Set up Go 88 | # uses: actions/setup-go@v4 89 | # with: 90 | # go-version: 1.20.4 91 | # check-latest: true 92 | # 93 | # - name: Go Generate #Artifacts are not reliable 94 | # continue-on-error: true 95 | # run: | 96 | # echo "${{ env.GOBIN }}" >> $GITHUB_PATH 97 | # go generate 98 | 99 | - name: Download Artifacts 100 | uses: actions/download-artifact@v4 101 | with: 102 | name: go-test-utils 103 | path: ${{ env.GOBIN }} 104 | 105 | - run: chmod +x ${{ env.GOBIN }}/* 106 | 107 | - name: FuzzMultiWrite 108 | continue-on-error: true 109 | run: | 110 | go test -json -short -fuzztime=3m -timeout=15m -cover github.com/bxcodec/dbresolver/v2 -fuzz=FuzzMultiWrite -covermode=count -run ^$ | \ 111 | bin/tparse -follow -all -sort=elapsed 112 | 113 | # TODO: similarly write for every fuzz function 114 | 115 | - name: Check file existence 116 | id: check_files 117 | uses: andstor/file-existence-action@v1 118 | with: 119 | files: "testdata/fuzz" 120 | 121 | - name: Run All Corpus 122 | continue-on-error: true 123 | if: steps.check_files.outputs.files_exists == 'true' 124 | run: | 125 | for dir in $(find testdata/fuzz/* -type d); do 126 | echo "Walking $dir" 127 | for file in $(find "$dir" -type f); do 128 | echo "Running $file" 129 | go test -run="$(basename "$dir")/$(basename "$file")" -v 130 | rm "$file" 131 | done 132 | rm -r "$dir" 133 | done 134 | rm -r testdata/fuzz 135 | 136 | - name: Collect testdata 137 | continue-on-error: true 138 | # if: ${{ failure() }} || true 139 | if: github.event_name == 'push' && 
github.event.pull_request == null 140 | run: | 141 | if [ -d "testdata" ]; then 142 | echo "Fuzz tests have failed" 143 | git config --global user.email laciferin@gmail.com 144 | git config --global user.name GithubActions 145 | git add -f testdata 146 | git commit -m "ci: fuzz tests updated on $date" 147 | git push 148 | else 149 | echo "All Fuzz Tests have passed" 150 | fi 151 | 152 | - name: Upload TestCases 153 | uses: actions/upload-artifact@v4 154 | if: steps.check_files.outputs.files_exists == 'true' 155 | with: 156 | name: go-fuzz-testdata 157 | path: ${{ github.workspace }}/testdata 158 | 159 | - name: Fail Test 160 | if: steps.check_files.outputs.files_exists == 'true' 161 | run: | 162 | if [ -d "testdata/fuzz" ]; then 163 | echo "Failing this run" 164 | exit 1 165 | else 166 | echo "testdata dir present" 167 | echo "fuzz tests have passed on 2nd run" 168 | fi 169 | 170 | # fuzz: 171 | # needs: [ build, fuzz-multiwrite ] #will fail if more than 1 fuzz function is present 172 | # runs-on: ubuntu-latest 173 | # 174 | # steps: 175 | # - uses: actions/checkout@v3 176 | # 177 | # - name: Download Artifacts 178 | # uses: actions/download-artifact@v2 179 | # with: 180 | # name: go-test-utils 181 | # path: ${{ env.GOBIN }} 182 | # 183 | # - run: chmod +x bin/* 184 | # - name: Fuzz Short 185 | # run: | 186 | # # go test -fuzz=Fuzz -short -v -fuzztime=1m -timeout=15m -cover -run="Fuzz*" 187 | # go test -fuzz=Fuzz -fuzztime=1m -timeout=15m -cover -covermode=count -run="Fuzz*" -json | \ 188 | # bin/tparse -follow -all -sort=elapsed 189 | # 190 | # 191 | # 192 | # race-fuzz: 193 | # runs-on: ubuntu-latest 194 | # needs: [ fuzz ] 195 | # 196 | # steps: 197 | # - uses: actions/checkout@v3 198 | # 199 | # - name: Race Short Fuzz 200 | # continue-on-error: true 201 | # run: | 202 | # go test -fuzz=Fuzz -short -race -v -fuzztime=30s -timeout=15m -cover -covermode=count -run="Fuzz*" 203 | # 204 | # - name: Fuzz normalize 205 | # if: ${{ failure() }} 206 | # uses: 
nick-fields/retry@v2 207 | # with: 208 | # max_attempts: 10 209 | # retry_on: error 210 | # timeout_minutes: 360m 211 | # #working-directory: ${{ github.workspace }} 212 | # command: | 213 | # echo "go fuzz intensive failed" 214 | 215 | # Fails if multiple Fuzz Functions match 216 | # go test -fuzz=Fuzz -fuzztime=30s -cover -covermode=count -run="Fuzz*" -json -short | \ 217 | # tparse -follow -all -sort=elapsed 218 | -------------------------------------------------------------------------------- /db_test.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "testing" 8 | "time" 9 | 10 | "github.com/DATA-DOG/go-sqlmock" 11 | ) 12 | 13 | type DBConfig struct { 14 | primaryDBCount uint8 15 | replicaDBCount uint8 16 | lbPolicy LoadBalancerPolicy 17 | } 18 | 19 | var LoadBalancerPolicies = []LoadBalancerPolicy{ 20 | RandomLB, 21 | RoundRobinLB, 22 | } 23 | 24 | func handleDBError(t *testing.T, err error) { 25 | if err != nil { 26 | t.Errorf("db error: %s", err) 27 | } 28 | 29 | } 30 | 31 | func testMW(t *testing.T, config DBConfig) { 32 | 33 | noOfPrimaries, noOfReplicas := int(config.primaryDBCount), int(config.replicaDBCount) 34 | lbPolicy := config.lbPolicy 35 | 36 | primaries := make([]*sql.DB, noOfPrimaries) 37 | replicas := make([]*sql.DB, noOfReplicas) 38 | 39 | mockPimaries := make([]sqlmock.Sqlmock, noOfPrimaries) 40 | mockReplicas := make([]sqlmock.Sqlmock, noOfReplicas) 41 | 42 | for i := 0; i < noOfPrimaries; i++ { 43 | db, mock, err := createMock() 44 | 45 | if err != nil { 46 | t.Fatal("creating of mock failed") 47 | } 48 | 49 | primaries[i] = db 50 | mockPimaries[i] = mock 51 | } 52 | 53 | for i := 0; i < noOfReplicas; i++ { 54 | db, mock, err := createMock() 55 | if err != nil { 56 | t.Fatal("creating of mock failed") 57 | } 58 | 59 | replicas[i] = db 60 | mockReplicas[i] = mock 61 | } 62 | 63 | resolver := 
New(WithPrimaryDBs(primaries...), WithReplicaDBs(replicas...), WithLoadBalancer(lbPolicy)).(*sqlDB) 64 | 65 | t.Run("primary dbs", func(t *testing.T) { 66 | var err error 67 | 68 | for i := 0; i < noOfPrimaries*6; i++ { 69 | robin := resolver.loadBalancer.predict(noOfPrimaries) 70 | mock := mockPimaries[robin] 71 | 72 | switch i % 6 { 73 | case 0: 74 | query := "SET timezone TO 'Asia/Tokyo'" 75 | mock.ExpectExec(query).WillReturnResult(sqlmock.NewResult(0, 0)) 76 | _, err = resolver.Exec(query) 77 | case 1: 78 | query := "CREATE DATABASE test; use test" 79 | mock.ExpectExec(query).WillReturnResult(sqlmock.NewResult(0, 0)).WillDelayFor(time.Millisecond * 50) 80 | _, err = resolver.ExecContext(context.Background(), query) 81 | case 2: 82 | t.Log("transactions:begin") 83 | 84 | mock.ExpectBegin() 85 | tx, err := resolver.Begin() 86 | handleDBError(t, err) 87 | 88 | query := `CREATE TABLE users (id serial PRIMARY KEY, name varchar(50) unique)` 89 | mock.ExpectExec(query).WillReturnResult(sqlmock.NewResult(0, 0)) 90 | 91 | _, err = tx.Exec(query) 92 | handleDBError(t, err) 93 | 94 | mock.ExpectCommit() 95 | tx.Commit() 96 | 97 | case 3: 98 | t.Log("tx: query-return clause") 99 | 100 | mock.ExpectBegin() 101 | tx, err1 := resolver.BeginTx(context.TODO(), &sql.TxOptions{ 102 | Isolation: sql.LevelDefault, 103 | ReadOnly: false, 104 | }) 105 | handleDBError(t, err1) 106 | 107 | query := "INSERT INTO users(id,name) VALUES ($1,$2) RETURNING id" 108 | mock.ExpectQuery(query). 109 | WithArgs(1, "Hiro"). 
110 | WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) 111 | 112 | _, err = tx.Query(query, 1, "Hiro") 113 | 114 | mock.ExpectCommit() 115 | tx.Commit() 116 | case 4: 117 | query := `UPDATE users SET name='Hiro' where id=1 RETURNING id,name` 118 | mock.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) 119 | _, err = resolver.Query(query) 120 | 121 | case 5: 122 | query := `delete from users where id=1 returning id,name` 123 | mock.ExpectQuery(query).WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) 124 | resolver.QueryRow(query) 125 | default: 126 | t.Fatal("developer needs to work on the tests") 127 | } 128 | 129 | handleDBError(t, err) 130 | 131 | if err := mock.ExpectationsWereMet(); err != nil { 132 | t.Skipf("sqlmock:unmet expectations: %s", err) 133 | } 134 | } 135 | }) 136 | 137 | t.Run("replica dbs", func(t *testing.T) { 138 | 139 | var query string 140 | 141 | for i := 0; i < noOfReplicas*5; i++ { 142 | robin := resolver.loadBalancer.predict(noOfReplicas) 143 | mock := mockReplicas[robin] 144 | 145 | switch i % 4 { 146 | case 0: 147 | query = "select '1'" 148 | mock.ExpectQuery(query) 149 | resolver.Query(query) 150 | case 1: 151 | query := "select 'row'" 152 | mock.ExpectQuery(query) 153 | resolver.QueryRow(query) 154 | case 2: 155 | query = "select 'query-ctx' " 156 | mock.ExpectQuery(query) 157 | resolver.QueryContext(context.TODO(), query) 158 | case 3: 159 | query = "select 'row'" 160 | mock.ExpectQuery(query) 161 | resolver.QueryRowContext(context.TODO(), query) 162 | } 163 | if err := mock.ExpectationsWereMet(); err != nil { 164 | t.Logf("failed query-%s", query) 165 | t.Skipf("sqlmock:unmet expectations: %s", err) 166 | } 167 | } 168 | }) 169 | 170 | t.Run("prepare", func(t *testing.T) { 171 | query := "select 1" 172 | 173 | for _, mock := range mockPimaries { 174 | mock.ExpectPrepare(query) 175 | defer func(mock sqlmock.Sqlmock) { 176 | if err := mock.ExpectationsWereMet(); err != nil { 177 | 
t.Errorf("sqlmock:unmet expectations: %s", err) 178 | } 179 | }(mock) 180 | } 181 | for _, mock := range mockReplicas { 182 | mock.ExpectPrepare(query) 183 | defer func(mock sqlmock.Sqlmock) { 184 | if err := mock.ExpectationsWereMet(); err != nil { 185 | t.Errorf("sqlmock:unmet expectations: %s", err) 186 | } 187 | }(mock) 188 | } 189 | 190 | stmt, err := resolver.Prepare(query) 191 | if err != nil { 192 | t.Error("prepare failed") 193 | return 194 | } 195 | 196 | robin := resolver.stmtLoadBalancer.predict(noOfPrimaries) 197 | mock := mockPimaries[robin] 198 | 199 | mock.ExpectExec(query) 200 | 201 | stmt.Exec() 202 | }) 203 | 204 | t.Run("prepare tx", func(t *testing.T) { 205 | query := "select 1" 206 | 207 | for _, mock := range mockPimaries { 208 | mock.ExpectPrepare(query) 209 | defer func(mock sqlmock.Sqlmock) { 210 | if err := mock.ExpectationsWereMet(); err != nil { 211 | t.Errorf("sqlmock:unmet expectations: %s", err) 212 | } 213 | }(mock) 214 | } 215 | for _, mock := range mockReplicas { 216 | mock.ExpectPrepare(query) 217 | defer func(mock sqlmock.Sqlmock) { 218 | if err := mock.ExpectationsWereMet(); err != nil { 219 | t.Errorf("sqlmock:unmet expectations: %s", err) 220 | } 221 | }(mock) 222 | } 223 | 224 | stmt, err := resolver.Prepare(query) 225 | if err != nil { 226 | t.Error("prepare failed") 227 | return 228 | } 229 | 230 | robin := resolver.loadBalancer.predict(noOfPrimaries) 231 | mock := mockPimaries[robin] 232 | 233 | mock.ExpectBegin() 234 | 235 | tx, err := resolver.Begin() 236 | if err != nil { 237 | t.Error("begin failed", err) 238 | return 239 | } 240 | 241 | txstmt := tx.Stmt(stmt) 242 | 243 | mock.ExpectExec(query).WillReturnResult(sqlmock.NewResult(0, 0)) 244 | _, err = txstmt.Exec() 245 | if err != nil { 246 | t.Error("stmt exec failed", err) 247 | return 248 | } 249 | 250 | mock.ExpectCommit() 251 | tx.Commit() 252 | }) 253 | 254 | t.Run("ping", func(t *testing.T) { 255 | for _, mock := range mockPimaries { 256 | mock.ExpectPing() 257 
| mock.ExpectPing() 258 | defer func(mock sqlmock.Sqlmock) { 259 | if err := mock.ExpectationsWereMet(); err != nil { 260 | t.Errorf("sqlmock:unmet expectations: %s", err) 261 | } 262 | }(mock) 263 | } 264 | for _, mock := range mockReplicas { 265 | mock.ExpectPing() 266 | mock.ExpectPing() 267 | defer func(mock sqlmock.Sqlmock) { 268 | if err := mock.ExpectationsWereMet(); err != nil { 269 | t.Errorf("sqlmock:unmet expectations: %s", err) 270 | } 271 | }(mock) 272 | } 273 | 274 | err := resolver.Ping() 275 | if err != nil { 276 | t.Errorf("ping failed %s", err) 277 | } 278 | err = resolver.PingContext(context.TODO()) 279 | if err != nil { 280 | t.Errorf("ping failed %s", err) 281 | } 282 | }) 283 | 284 | t.Run("close", func(t *testing.T) { 285 | for _, mock := range mockPimaries { 286 | mock.ExpectClose() 287 | } 288 | for _, mock := range mockReplicas { 289 | mock.ExpectClose() 290 | } 291 | err := resolver.Close() 292 | handleDBError(t, err) 293 | 294 | t.Logf("closed:DB-CLUSTER-%dP%dR", noOfPrimaries, noOfReplicas) 295 | }) 296 | 297 | } 298 | 299 | func TestMultiWrite(t *testing.T) { 300 | t.Parallel() 301 | 302 | loadBalancerPolices := []LoadBalancerPolicy{ 303 | RoundRobinLB, 304 | RandomLB, 305 | } 306 | 307 | retrieveLoadBalancer := func() (loadBalancerPolicy LoadBalancerPolicy) { 308 | loadBalancerPolicy = loadBalancerPolices[0] 309 | loadBalancerPolices = loadBalancerPolices[1:] 310 | return 311 | } 312 | 313 | BEGIN_TEST: 314 | loadBalancerPolicy := retrieveLoadBalancer() 315 | 316 | t.Logf("LoadBalancer-%s", loadBalancerPolicy) 317 | 318 | testCases := []DBConfig{ 319 | {1, 0, ""}, 320 | {1, 1, ""}, 321 | {1, 2, ""}, 322 | {1, 10, ""}, 323 | {2, 0, ""}, 324 | {2, 1, ""}, 325 | {3, 0, ""}, 326 | {3, 1, ""}, 327 | {3, 2, ""}, 328 | {3, 3, ""}, 329 | {3, 6, ""}, 330 | {5, 6, ""}, 331 | {7, 20, ""}, 332 | {10, 10, ""}, 333 | {10, 20, ""}, 334 | } 335 | 336 | retrieveTestCase := func() DBConfig { 337 | testCase := testCases[0] 338 | testCases = 
testCases[1:] 339 | return testCase 340 | } 341 | 342 | BEGIN_TEST_CASE: 343 | if len(testCases) == 0 { 344 | if len(loadBalancerPolices) == 0 { 345 | return 346 | } 347 | goto BEGIN_TEST 348 | } 349 | 350 | dbConfig := retrieveTestCase() 351 | 352 | dbConfig.lbPolicy = loadBalancerPolicy 353 | 354 | t.Run(fmt.Sprintf("DBCluster P%dR%d", dbConfig.primaryDBCount, dbConfig.replicaDBCount), func(t *testing.T) { 355 | testMW(t, dbConfig) 356 | }) 357 | 358 | if testing.Short() { 359 | return 360 | } 361 | 362 | goto BEGIN_TEST_CASE 363 | } 364 | 365 | func createMock() (db *sql.DB, mock sqlmock.Sqlmock, err error) { 366 | db, mock, err = sqlmock.New(sqlmock.MonitorPingsOption(true), sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) 367 | return 368 | } 369 | 370 | type QueryMatcher struct { 371 | } 372 | 373 | func (*QueryMatcher) Match(expectedSQL, actualSQL string) error { 374 | return nil 375 | } 376 | -------------------------------------------------------------------------------- /db.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "database/sql/driver" 7 | "strings" 8 | "sync" 9 | "time" 10 | 11 | "go.uber.org/multierr" 12 | ) 13 | 14 | // DB interface is a contract that supported by this library. 15 | // All offered function of this library defined here. 
16 | // This supposed to be aligned with sql.DB, but since some of the functions is not relevant 17 | // with multi dbs connection, we decided to forward all single connection DB related function to the first primary DB 18 | // For example, function like, `Conn()“, or `Stats()` only available for the primary DB, or the first primary DB (if using multi-primary) 19 | type DB interface { 20 | Begin() (Tx, error) 21 | BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) 22 | Close() error 23 | // Conn only available for the primary db or the first primary db (if using multi-primary) 24 | Conn(ctx context.Context) (Conn, error) 25 | Driver() driver.Driver 26 | Exec(query string, args ...interface{}) (sql.Result, error) 27 | ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 28 | Ping() error 29 | PingContext(ctx context.Context) error 30 | Prepare(query string) (Stmt, error) 31 | PrepareContext(ctx context.Context, query string) (Stmt, error) 32 | Query(query string, args ...interface{}) (*sql.Rows, error) 33 | QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) 34 | QueryRow(query string, args ...interface{}) *sql.Row 35 | QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row 36 | SetConnMaxIdleTime(d time.Duration) 37 | SetConnMaxLifetime(d time.Duration) 38 | SetMaxIdleConns(n int) 39 | SetMaxOpenConns(n int) 40 | PrimaryDBs() []*sql.DB 41 | ReplicaDBs() []*sql.DB 42 | // Stats only available for the primary db or the first primary db (if using multi-primary) 43 | Stats() sql.DBStats 44 | } 45 | 46 | // DBLoadBalancer is loadbalancer for physical DBs 47 | type DBLoadBalancer LoadBalancer[*sql.DB] 48 | 49 | // StmtLoadBalancer is loadbalancer for query prepared statements 50 | type StmtLoadBalancer LoadBalancer[*sql.Stmt] 51 | 52 | // sqlDB is a logical database with multiple underlying physical databases 53 | // forming a single ReadWrite (primary) with 
multiple ReadOnly(replicas) db. 54 | // Reads and writes are automatically directed to the correct db connection 55 | 56 | type sqlDB struct { 57 | primaries []*sql.DB 58 | replicas []*sql.DB 59 | loadBalancer DBLoadBalancer 60 | stmtLoadBalancer StmtLoadBalancer 61 | queryTypeChecker QueryTypeChecker 62 | } 63 | 64 | // PrimaryDBs return all the active primary DB 65 | func (db *sqlDB) PrimaryDBs() []*sql.DB { 66 | return db.primaries 67 | } 68 | 69 | // ReplicaDBs return all the active replica DB 70 | func (db *sqlDB) ReplicaDBs() []*sql.DB { 71 | return db.replicas 72 | } 73 | 74 | // Close closes all physical databases concurrently, releasing any open resources. 75 | func (db *sqlDB) Close() error { 76 | errPrimaries := doParallely(len(db.primaries), func(i int) error { 77 | return db.primaries[i].Close() 78 | }) 79 | errReplicas := doParallely(len(db.replicas), func(i int) error { 80 | return db.replicas[i].Close() 81 | }) 82 | return multierr.Combine(errPrimaries, errReplicas) 83 | } 84 | 85 | // Driver returns the physical database's underlying driver. 86 | func (db *sqlDB) Driver() driver.Driver { 87 | return db.ReadWrite().Driver() 88 | } 89 | 90 | // Begin starts a transaction on the RW-db. The isolation level is dependent on the driver. 91 | func (db *sqlDB) Begin() (Tx, error) { 92 | return db.BeginTx(context.Background(), nil) 93 | } 94 | 95 | // BeginTx starts a transaction with the provided context on the RW-db. 96 | // 97 | // The provided TxOptions is optional and may be nil if defaults should be used. 98 | // If a non-default isolation level is used that the driver doesn't support, 99 | // an error will be returned. 
100 | func (db *sqlDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { 101 | sourceDB := db.ReadWrite() 102 | 103 | stx, err := sourceDB.BeginTx(ctx, opts) 104 | if err != nil { 105 | return nil, err 106 | } 107 | 108 | return &tx{ 109 | sourceDB: sourceDB, 110 | tx: stx, 111 | }, nil 112 | } 113 | 114 | // Exec executes a query without returning any rows. 115 | // The args are for any placeholder parameters in the query. 116 | // Exec uses the RW-database as the underlying db connection 117 | func (db *sqlDB) Exec(query string, args ...interface{}) (sql.Result, error) { 118 | return db.ExecContext(context.Background(), query, args...) 119 | } 120 | 121 | // ExecContext executes a query without returning any rows. 122 | // The args are for any placeholder parameters in the query. 123 | // Exec uses the RW-database as the underlying db connection 124 | func (db *sqlDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { 125 | return db.ReadWrite().ExecContext(ctx, query, args...) 126 | } 127 | 128 | // Ping verifies if a connection to each physical database is still alive, 129 | // establishing a connection if necessary. 130 | func (db *sqlDB) Ping() error { 131 | return db.PingContext(context.Background()) 132 | } 133 | 134 | // PingContext verifies if a connection to each physical database is still 135 | // alive, establishing a connection if necessary. 136 | func (db *sqlDB) PingContext(ctx context.Context) error { 137 | errPrimaries := doParallely(len(db.primaries), func(i int) error { 138 | return db.primaries[i].PingContext(ctx) 139 | }) 140 | errReplicas := doParallely(len(db.replicas), func(i int) error { 141 | return db.replicas[i].PingContext(ctx) 142 | }) 143 | return multierr.Combine(errPrimaries, errReplicas) 144 | } 145 | 146 | // Prepare creates a prepared statement for later queries or executions 147 | // on each physical database, concurrently. 
148 | func (db *sqlDB) Prepare(query string) (_stmt Stmt, err error) { 149 | return db.PrepareContext(context.Background(), query) 150 | } 151 | 152 | // PrepareContext creates a prepared statement for later queries or executions 153 | // on each physical database, concurrently. 154 | // 155 | // The provided context is used for the preparation of the statement, not for 156 | // the execution of the statement. 157 | func (db *sqlDB) PrepareContext(ctx context.Context, query string) (_stmt Stmt, err error) { 158 | dbStmt := map[*sql.DB]*sql.Stmt{} 159 | var dbStmtLock sync.Mutex 160 | roStmts := make([]*sql.Stmt, len(db.replicas)) 161 | primaryStmts := make([]*sql.Stmt, len(db.primaries)) 162 | errPrimaries := doParallely(len(db.primaries), func(i int) (err error) { 163 | primaryStmts[i], err = db.primaries[i].PrepareContext(ctx, query) 164 | dbStmtLock.Lock() 165 | dbStmt[db.primaries[i]] = primaryStmts[i] 166 | dbStmtLock.Unlock() 167 | return 168 | }) 169 | 170 | errReplicas := doParallely(len(db.replicas), func(i int) (err error) { 171 | roStmts[i], err = db.replicas[i].PrepareContext(ctx, query) 172 | dbStmtLock.Lock() 173 | dbStmt[db.replicas[i]] = roStmts[i] 174 | dbStmtLock.Unlock() 175 | 176 | // if connection error happens on RO connection, 177 | // ignore and fallback to RW connection 178 | if isDBConnectionError(err) { 179 | roStmts[i] = primaryStmts[0] 180 | return nil 181 | } 182 | return err 183 | }) 184 | 185 | err = multierr.Combine(errPrimaries, errReplicas) 186 | if err != nil { 187 | return //nolint: nakedret 188 | } 189 | 190 | _query := strings.ToUpper(query) 191 | writeFlag := strings.Contains(_query, "RETURNING") 192 | 193 | _stmt = &stmt{ 194 | loadBalancer: db.stmtLoadBalancer, 195 | primaryStmts: primaryStmts, 196 | replicaStmts: roStmts, 197 | dbStmt: dbStmt, 198 | writeFlag: writeFlag, 199 | } 200 | return _stmt, nil 201 | } 202 | 203 | // Query executes a query that returns rows, typically a SELECT. 
204 | // The args are for any placeholder parameters in the query. 205 | func (db *sqlDB) Query(query string, args ...interface{}) (*sql.Rows, error) { 206 | return db.QueryContext(context.Background(), query, args...) 207 | } 208 | 209 | // QueryContext executes a query that returns rows, typically a SELECT. 210 | // The args are for any placeholder parameters in the query. 211 | func (db *sqlDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { 212 | var curDB *sql.DB 213 | writeFlag := db.queryTypeChecker.Check(query) == QueryTypeWrite 214 | 215 | if writeFlag { 216 | curDB = db.ReadWrite() 217 | } else { 218 | curDB = db.ReadOnly() 219 | } 220 | 221 | rows, err = curDB.QueryContext(ctx, query, args...) 222 | if isDBConnectionError(err) && !writeFlag { 223 | rows, err = db.ReadWrite().QueryContext(ctx, query, args...) 224 | } 225 | return 226 | } 227 | 228 | // QueryRow executes a query that is expected to return at most one row. 229 | // QueryRow always return a non-nil value. 230 | // Errors are deferred until Row's Scan method is called. 231 | func (db *sqlDB) QueryRow(query string, args ...interface{}) *sql.Row { 232 | return db.QueryRowContext(context.Background(), query, args...) 233 | } 234 | 235 | // QueryRowContext executes a query that is expected to return at most one row. 236 | // QueryRowContext always return a non-nil value. 237 | // Errors are deferred until Row's Scan method is called. 238 | func (db *sqlDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { 239 | var curDB *sql.DB 240 | writeFlag := db.queryTypeChecker.Check(query) == QueryTypeWrite 241 | 242 | if writeFlag { 243 | curDB = db.ReadWrite() 244 | } else { 245 | curDB = db.ReadOnly() 246 | } 247 | 248 | row := curDB.QueryRowContext(ctx, query, args...) 249 | if isDBConnectionError(row.Err()) && !writeFlag { 250 | row = db.ReadWrite().QueryRowContext(ctx, query, args...) 
251 | } 252 | 253 | return row 254 | } 255 | 256 | // SetMaxIdleConns sets the maximum number of connections in the idle 257 | // connection pool for each underlying db connection 258 | // If MaxOpenConns is greater than 0 but less than the new MaxIdleConns then the 259 | // new MaxIdleConns will be reduced to match the MaxOpenConns limit 260 | // If n <= 0, no idle connections are retained. 261 | func (db *sqlDB) SetMaxIdleConns(n int) { 262 | for i := range db.primaries { 263 | db.primaries[i].SetMaxIdleConns(n) 264 | } 265 | 266 | for i := range db.replicas { 267 | db.replicas[i].SetMaxIdleConns(n) 268 | } 269 | } 270 | 271 | // SetMaxOpenConns sets the maximum number of open connections 272 | // to each physical db. 273 | // If MaxIdleConns is greater than 0 and the new MaxOpenConns 274 | // is less than MaxIdleConns, then MaxIdleConns will be reduced to match 275 | // the new MaxOpenConns limit. If n <= 0, then there is no limit on the number 276 | // of open connections. The default is 0 (unlimited). 277 | func (db *sqlDB) SetMaxOpenConns(n int) { 278 | for i := range db.primaries { 279 | db.primaries[i].SetMaxOpenConns(n) 280 | } 281 | for i := range db.replicas { 282 | db.replicas[i].SetMaxOpenConns(n) 283 | } 284 | } 285 | 286 | // SetConnMaxLifetime sets the maximum amount of time a connection may be reused. 287 | // Expired connections may be closed lazily before reuse. 288 | // If d <= 0, connections are reused forever. 289 | func (db *sqlDB) SetConnMaxLifetime(d time.Duration) { 290 | for i := range db.primaries { 291 | db.primaries[i].SetConnMaxLifetime(d) 292 | } 293 | for i := range db.replicas { 294 | db.replicas[i].SetConnMaxLifetime(d) 295 | } 296 | } 297 | 298 | // SetConnMaxIdleTime sets the maximum amount of time a connection may be idle. 299 | // Expired connections may be closed lazily before reuse. 300 | // If d <= 0, connections are not closed due to a connection's idle time. 
301 | func (db *sqlDB) SetConnMaxIdleTime(d time.Duration) { 302 | for i := range db.primaries { 303 | db.primaries[i].SetConnMaxIdleTime(d) 304 | } 305 | 306 | for i := range db.replicas { 307 | db.replicas[i].SetConnMaxIdleTime(d) 308 | } 309 | } 310 | 311 | // ReadOnly returns the readonly database 312 | func (db *sqlDB) ReadOnly() *sql.DB { 313 | if len(db.replicas) == 0 { 314 | return db.loadBalancer.Resolve(db.primaries) 315 | } 316 | return db.loadBalancer.Resolve(db.replicas) 317 | } 318 | 319 | // ReadWrite returns the primary database 320 | func (db *sqlDB) ReadWrite() *sql.DB { 321 | return db.loadBalancer.Resolve(db.primaries) 322 | } 323 | 324 | // Conn returns a single connection by either opening a new connection or returning an existing connection from the 325 | // connection pool of the first primary db. 326 | func (db *sqlDB) Conn(ctx context.Context) (Conn, error) { 327 | c, err := db.primaries[0].Conn(ctx) 328 | if err != nil { 329 | return nil, err 330 | } 331 | 332 | return &conn{ 333 | sourceDB: db.primaries[0], 334 | conn: c, 335 | }, nil 336 | } 337 | 338 | // Stats returns database statistics for the first primary db 339 | func (db *sqlDB) Stats() sql.DBStats { 340 | return db.primaries[0].Stats() 341 | } 342 | --------------------------------------------------------------------------------