├── .github └── workflows │ ├── lint.yml │ └── test.yml ├── .gitignore ├── .golangci.yml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── benchmark_test.go ├── compose.go ├── compose_test.go ├── doc.go ├── go.mod ├── go.sum ├── hooks ├── loghooks │ ├── example_test.go │ ├── examples │ │ └── main.go │ └── loghooks.go └── othooks │ ├── examples │ └── main.go │ ├── othooks.go │ └── othooks_test.go ├── sqlhooks.go ├── sqlhooks_1_10.go ├── sqlhooks_1_10_interface_test.go ├── sqlhooks_interface_test.go ├── sqlhooks_mysql_test.go ├── sqlhooks_postgres_test.go ├── sqlhooks_pre_1_10.go ├── sqlhooks_sqlite3_test.go └── sqlhooks_test.go /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | on: 3 | pull_request: 4 | jobs: 5 | golangci: 6 | name: lint 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - name: golangci-lint 11 | uses: golangci/golangci-lint-action@v2 12 | with: 13 | # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version 14 | version: latest 15 | 16 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: "test" 2 | on: ["push","pull_request"] 3 | jobs: 4 | test: 5 | name: "Run unit tests" 6 | strategy: 7 | matrix: 8 | os: [ubuntu-latest] 9 | go-version: ["1.15.x", "1.16.x", "1.17.x"] 10 | runs-on: ${{ matrix.os }} 11 | 12 | services: 13 | mysql: 14 | image: mysql 15 | env: 16 | MYSQL_USER: test 17 | MYSQL_PASSWORD: test 18 | MYSQL_DATABASE: sqlhooks 19 | MYSQL_ALLOW_EMPTY_PASSWORD: true 20 | ports: 21 | - 3306:3306 22 | options: >- 23 | --health-cmd="mysqladmin -v ping" 24 | --health-interval=10s 25 | --health-timeout=5s 26 | --health-retries=5 27 | 28 | postgres: 29 | image: postgres 30 | env: 31 | POSTGRES_PASSWORD: test 32 | POSTGRES_DB: sqlhooks 33 | ports: 34 | - 
5432:5432 35 | options: >- 36 | --health-cmd pg_isready 37 | --health-interval 10s 38 | --health-timeout 5s 39 | --health-retries 5 40 | 41 | steps: 42 | - name: Install Go 43 | uses: actions/setup-go@v2 44 | with: 45 | go-version: ${{ matrix.go-version }} 46 | 47 | - name: Checkout code 48 | uses: actions/checkout@v2 49 | with: 50 | fetch-depth: 1 51 | 52 | - uses: actions/cache@v2 53 | with: 54 | path: | 55 | ~/go/pkg/mod 56 | ~/.cache/go-build 57 | key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} 58 | restore-keys: | 59 | ${{ runner.os }}-go- 60 | 61 | - name: Test 62 | env: 63 | SQLHOOKS_MYSQL_DSN: "test:test@/sqlhooks?interpolateParams=true" 64 | SQLHOOKS_POSTGRES_DSN: "postgres://postgres:test@localhost/sqlhooks?sslmode=disable" 65 | run: go test -race -covermode atomic -coverprofile=covprofile ./... 66 | - name: Install goveralls 67 | run: go get github.com/mattn/goveralls@v0.0.11 68 | - name: Send coverage 69 | env: 70 | COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} 71 | run: goveralls -coverprofile=covprofile -service=github 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | linters-settings: 2 | staticcheck: 3 | checks: ["all", "-SA1019"] 4 | issues: 5 | exclude-rules: 6 | - path: example_test.go 7 | linters: 8 | - errcheck 9 | 
-------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Change Log 2 | 3 | ## [Unreleased](https://github.com/qustavo/sqlhooks/tree/HEAD) 4 | 5 | [Full Changelog](https://github.com/qustavo/sqlhooks/compare/v1.0.0...HEAD) 6 | 7 | **Closed issues:** 8 | 9 | - Add Benchmarks [\#9](https://github.com/qustavo/sqlhooks/issues/9) 10 | ## [v1.0.0](https://github.com/qustavo/sqlhooks/tree/v1.0.0) (2017-05-08) 11 | [Full Changelog](https://github.com/qustavo/sqlhooks/compare/v0.4...v1.0.0) 12 | 13 | **Merged pull requests:** 14 | 15 | - Godoc [\#7](https://github.com/qustavo/sqlhooks/pull/7) ([qustavo](https://github.com/qustavo)) 16 | - Make covermode=count [\#6](https://github.com/qustavo/sqlhooks/pull/6) ([qustavo](https://github.com/qustavo)) 17 | - V1 [\#5](https://github.com/qustavo/sqlhooks/pull/5) ([qustavo](https://github.com/qustavo)) 18 | - Expose a WrapDriver function [\#4](https://github.com/qustavo/sqlhooks/issues/4) 19 | - Implement new 1.8 interfaces [\#3](https://github.com/qustavo/sqlhooks/issues/3) 20 | 21 | ## [v0.4](https://github.com/qustavo/sqlhooks/tree/v0.4) (2017-03-23) 22 | [Full Changelog](https://github.com/qustavo/sqlhooks/compare/v0.3...v0.4) 23 | 24 | ## [v0.3](https://github.com/qustavo/sqlhooks/tree/v0.3) (2016-06-02) 25 | [Full Changelog](https://github.com/qustavo/sqlhooks/compare/v0.2...v0.3) 26 | 27 | **Closed issues:** 28 | 29 | - Change Notifications [\#2](https://github.com/qustavo/sqlhooks/issues/2) 30 | 31 | ## [v0.2](https://github.com/qustavo/sqlhooks/tree/v0.2) (2016-05-01) 32 | [Full Changelog](https://github.com/qustavo/sqlhooks/compare/v0.1...v0.2) 33 | 34 | ## [v0.1](https://github.com/qustavo/sqlhooks/tree/v0.1) (2016-04-25) 35 | **Merged pull requests:** 36 | 37 | - Sqlite3 [\#1](https://github.com/qustavo/sqlhooks/pull/1) ([qustavo](https://github.com/qustavo)) 38 | 39 | 40 | 41 | \* *This Change 
Log was automatically generated by [github_changelog_generator](https://github.com/skywinder/Github-Changelog-Generator)* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Gustavo Chaín 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sqlhooks 2 | ![Build Status](https://github.com/qustavo/sqlhooks/actions/workflows/test.yml/badge.svg) 3 | [![Go Report Card](https://goreportcard.com/badge/github.com/qustavo/sqlhooks)](https://goreportcard.com/report/github.com/qustavo/sqlhooks) 4 | [![Coverage Status](https://coveralls.io/repos/github/qustavo/sqlhooks/badge.svg?branch=master)](https://coveralls.io/github/qustavo/sqlhooks?branch=master) 5 | 6 | Attach hooks to any database/sql driver. 7 | 8 | The purpose of sqlhooks is to provide a way to instrument your sql statements, making it really easy to log queries or measure execution time without modifying your actual code. 9 | 10 | # Install 11 | ```bash 12 | go get github.com/qustavo/sqlhooks/v2 13 | ``` 14 | Requires Go >= 1.14.x 15 | 16 | ## Breaking changes 17 | `V2` isn't backward compatible with previous versions. If you want to fetch old versions, you can use Go modules or get them from [gopkg.in](http://gopkg.in/): 18 | ```bash 19 | go get github.com/qustavo/sqlhooks 20 | go get gopkg.in/qustavo/sqlhooks.v1 21 | ``` 22 | 23 | # Usage [![GoDoc](https://godoc.org/github.com/qustavo/sqlhooks?status.svg)](https://godoc.org/github.com/qustavo/sqlhooks) 24 | 25 | ```go 26 | // This example shows how to instrument sql queries in order to display the time that they consume 27 | package main 28 | 29 | import ( 30 | "context" 31 | "database/sql" 32 | "fmt" 33 | "time" 34 | 35 | "github.com/qustavo/sqlhooks/v2" 36 | "github.com/mattn/go-sqlite3" 37 | ) 38 | 39 | // Hooks satisfies the sqlhooks.Hooks interface 40 | type Hooks struct {} 41 | 42 | // Before hook will print the query with its args and return the context with the timestamp 43 | func (h *Hooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 44 | fmt.Printf("> %s %q", query,
args) 45 | return context.WithValue(ctx, "begin", time.Now()), nil 46 | } 47 | 48 | // After hook will get the timestamp registered on the Before hook and print the elapsed time 49 | func (h *Hooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 50 | begin := ctx.Value("begin").(time.Time) 51 | fmt.Printf(". took: %s\n", time.Since(begin)) 52 | return ctx, nil 53 | } 54 | 55 | func main() { 56 | // First, register the wrapper 57 | sql.Register("sqlite3WithHooks", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &Hooks{})) 58 | 59 | // Connect to the registered wrapped driver 60 | db, _ := sql.Open("sqlite3WithHooks", ":memory:") 61 | 62 | // Do your stuff 63 | db.Exec("CREATE TABLE t (id INTEGER, text VARCHAR(16))") 64 | db.Exec("INSERT into t (text) VALUES(?), (?)", "foo", "bar") 65 | db.Query("SELECT id, text FROM t") 66 | } 67 | 68 | /* 69 | Output should look like: 70 | > CREATE TABLE t (id INTEGER, text VARCHAR(16)) []. took: 121.238µs 71 | > INSERT into t (text) VALUES(?), (?) ["foo" "bar"]. took: 36.364µs 72 | > SELECT id, text FROM t []. took: 4.653µs 73 | */ 74 | ``` 75 | 76 | # Benchmarks 77 | ``` 78 | go test -bench=.
-benchmem 79 | goos: linux 80 | goarch: amd64 81 | pkg: github.com/qustavo/sqlhooks/v2 82 | cpu: Intel(R) Xeon(R) W-10885M CPU @ 2.40GHz 83 | BenchmarkSQLite3/Without_Hooks-16 191196 6163 ns/op 456 B/op 14 allocs/op 84 | BenchmarkSQLite3/With_Hooks-16 189997 6329 ns/op 456 B/op 14 allocs/op 85 | BenchmarkMySQL/Without_Hooks-16 13278 83462 ns/op 309 B/op 7 allocs/op 86 | BenchmarkMySQL/With_Hooks-16 13460 87331 ns/op 309 B/op 7 allocs/op 87 | BenchmarkPostgres/Without_Hooks-16 13016 91421 ns/op 401 B/op 10 allocs/op 88 | BenchmarkPostgres/With_Hooks-16 12339 94033 ns/op 401 B/op 10 allocs/op 89 | PASS 90 | ok github.com/qustavo/sqlhooks/v2 10.294s 91 | ``` 92 | -------------------------------------------------------------------------------- /benchmark_test.go: -------------------------------------------------------------------------------- 1 | package sqlhooks 2 | 3 | import ( 4 | "database/sql" 5 | "os" 6 | "testing" 7 | 8 | "github.com/go-sql-driver/mysql" 9 | "github.com/lib/pq" 10 | "github.com/mattn/go-sqlite3" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func init() { 15 | hooks := &testHooks{} 16 | hooks.reset() 17 | 18 | sql.Register("sqlite3-benchmark", Wrap(&sqlite3.SQLiteDriver{}, hooks)) 19 | sql.Register("mysql-benchmark", Wrap(&mysql.MySQLDriver{}, hooks)) 20 | sql.Register("postgres-benchmark", Wrap(&pq.Driver{}, hooks)) 21 | } 22 | 23 | func benchmark(b *testing.B, driver, dsn string) { 24 | db, err := sql.Open(driver, dsn) 25 | require.NoError(b, err) 26 | defer db.Close() 27 | 28 | var query = "SELECT 'hello'" 29 | 30 | b.ResetTimer() 31 | for i := 0; i < b.N; i++ { 32 | rows, err := db.Query(query) 33 | require.NoError(b, err) 34 | require.NoError(b, rows.Close()) 35 | } 36 | } 37 | 38 | func BenchmarkSQLite3(b *testing.B) { 39 | b.Run("Without Hooks", func(b *testing.B) { 40 | benchmark(b, "sqlite3", ":memory:") 41 | }) 42 | 43 | b.Run("With Hooks", func(b *testing.B) { 44 | benchmark(b, "sqlite3-benchmark", ":memory:") 45 | }) 46 
| } 47 | 48 | func BenchmarkMySQL(b *testing.B) { 49 | dsn := os.Getenv("SQLHOOKS_MYSQL_DSN") 50 | if dsn == "" { 51 | b.Skipf("SQLHOOKS_MYSQL_DSN not set") 52 | } 53 | 54 | b.Run("Without Hooks", func(b *testing.B) { 55 | benchmark(b, "mysql", dsn) 56 | }) 57 | 58 | b.Run("With Hooks", func(b *testing.B) { 59 | benchmark(b, "mysql-benchmark", dsn) 60 | }) 61 | } 62 | 63 | func BenchmarkPostgres(b *testing.B) { 64 | dsn := os.Getenv("SQLHOOKS_POSTGRES_DSN") 65 | if dsn == "" { 66 | b.Skipf("SQLHOOKS_POSTGRES_DSN not set") 67 | } 68 | 69 | b.Run("Without Hooks", func(b *testing.B) { 70 | benchmark(b, "postgres", dsn) 71 | }) 72 | 73 | b.Run("With Hooks", func(b *testing.B) { 74 | benchmark(b, "postgres-benchmark", dsn) 75 | }) 76 | } 77 | -------------------------------------------------------------------------------- /compose.go: -------------------------------------------------------------------------------- 1 | package sqlhooks 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | ) 7 | 8 | // Compose allows for composing multiple Hooks into one. 9 | // It runs every callback on every hook in argument order, 10 | // even if previous hooks return an error. 11 | // If multiple hooks return errors, the error return value will be 12 | // MultipleErrors, which allows for introspecting the errors if necessary. 13 | func Compose(hooks ...Hooks) Hooks { 14 | return composed(hooks) 15 | } 16 | 17 | type composed []Hooks 18 | 19 | func (c composed) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 20 | var errors []error 21 | for _, hook := range c { 22 | c, err := hook.Before(ctx, query, args...) 
23 | if err != nil { 24 | errors = append(errors, err) 25 | } 26 | if c != nil { 27 | ctx = c 28 | } 29 | } 30 | return ctx, wrapErrors(nil, errors) 31 | } 32 | 33 | func (c composed) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 34 | var errors []error 35 | for _, hook := range c { 36 | var err error 37 | c, err := hook.After(ctx, query, args...) 38 | if err != nil { 39 | errors = append(errors, err) 40 | } 41 | if c != nil { 42 | ctx = c 43 | } 44 | } 45 | return ctx, wrapErrors(nil, errors) 46 | } 47 | 48 | func (c composed) OnError(ctx context.Context, cause error, query string, args ...interface{}) error { 49 | var errors []error 50 | for _, hook := range c { 51 | if onErrorer, ok := hook.(OnErrorer); ok { 52 | if err := onErrorer.OnError(ctx, cause, query, args...); err != nil && err != cause { 53 | errors = append(errors, err) 54 | } 55 | } 56 | } 57 | return wrapErrors(cause, errors) 58 | } 59 | 60 | func wrapErrors(def error, errors []error) error { 61 | switch len(errors) { 62 | case 0: 63 | return def 64 | case 1: 65 | return errors[0] 66 | default: 67 | return MultipleErrors(errors) 68 | } 69 | } 70 | 71 | // MultipleErrors is an error that contains multiple errors. 
72 | type MultipleErrors []error 73 | 74 | func (m MultipleErrors) Error() string { 75 | return fmt.Sprint("multiple errors:", []error(m)) 76 | } 77 | -------------------------------------------------------------------------------- /compose_test.go: -------------------------------------------------------------------------------- 1 | package sqlhooks 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | var ( 11 | oops = errors.New("oops") 12 | oopsHook = &testHooks{ 13 | before: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 14 | return ctx, oops 15 | }, 16 | after: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 17 | return ctx, oops 18 | }, 19 | onError: func(ctx context.Context, err error, query string, args ...interface{}) error { 20 | return oops 21 | }, 22 | } 23 | okHook = &testHooks{ 24 | before: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 25 | return ctx, nil 26 | }, 27 | after: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 28 | return ctx, nil 29 | }, 30 | onError: func(ctx context.Context, err error, query string, args ...interface{}) error { 31 | return nil 32 | }, 33 | } 34 | ) 35 | 36 | func TestCompose(t *testing.T) { 37 | for _, it := range []struct { 38 | name string 39 | hooks Hooks 40 | want error 41 | }{ 42 | {"happy case", Compose(okHook, okHook), nil}, 43 | {"no hooks", Compose(), nil}, 44 | {"multiple errors", Compose(oopsHook, okHook, oopsHook), MultipleErrors([]error{oops, oops})}, 45 | {"single error", Compose(okHook, oopsHook, okHook), oops}, 46 | } { 47 | t.Run(it.name, func(t *testing.T) { 48 | t.Run("Before", func(t *testing.T) { 49 | _, got := it.hooks.Before(context.Background(), "query") 50 | if !reflect.DeepEqual(it.want, got) { 51 | t.Errorf("unexpected error. 
want: %q, got: %q", it.want, got) 52 | } 53 | }) 54 | t.Run("After", func(t *testing.T) { 55 | _, got := it.hooks.After(context.Background(), "query") 56 | if !reflect.DeepEqual(it.want, got) { 57 | t.Errorf("unexpected error. want: %q, got: %q", it.want, got) 58 | } 59 | }) 60 | t.Run("OnError", func(t *testing.T) { 61 | cause := errors.New("crikey") 62 | want := it.want 63 | if want == nil { 64 | want = cause 65 | } 66 | got := it.hooks.(OnErrorer).OnError(context.Background(), cause, "query") 67 | if !reflect.DeepEqual(want, got) { 68 | t.Errorf("unexpected error. want: %q, got: %q", want, got) 69 | } 70 | }) 71 | }) 72 | } 73 | } 74 | 75 | func TestWrapErrors(t *testing.T) { 76 | var ( 77 | err1 = errors.New("oops") 78 | err2 = errors.New("oops2") 79 | ) 80 | for _, it := range []struct { 81 | name string 82 | def error 83 | errors []error 84 | want error 85 | }{ 86 | {"no errors", err1, nil, err1}, 87 | {"single error", nil, []error{err1}, err1}, 88 | {"multiple errors", nil, []error{err1, err2}, MultipleErrors([]error{err1, err2})}, 89 | } { 90 | t.Run(it.name, func(t *testing.T) { 91 | if want, got := it.want, wrapErrors(it.def, it.errors); !reflect.DeepEqual(want, got) { 92 | t.Errorf("unexpected wrapping. want: %q, got %q", want, got) 93 | } 94 | }) 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // package sqlhooks allows you to attach hooks to any database/sql driver. 2 | // The purpose of sqlhooks is to provide a way to instrument your sql statements, making really easy to log queries or measure execution time without modifying your actual code. 
3 | // 4 | // This example shows how to instrument sql queries in order to display the time that they consume 5 | // package main 6 | // 7 | // import ( 8 | // "context" 9 | // "database/sql" 10 | // "fmt" 11 | // "time" 12 | // 13 | // "github.com/qustavo/sqlhooks/v2" 14 | // "github.com/mattn/go-sqlite3" 15 | // ) 16 | // 17 | // // Hooks satisfies the sqlhooks.Hooks interface 18 | // type Hooks struct {} 19 | // 20 | // // Before hook will print the query with its args and return the context with the timestamp 21 | // func (h *Hooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 22 | // fmt.Printf("> %s %q", query, args) 23 | // return context.WithValue(ctx, "begin", time.Now()), nil 24 | // } 25 | // 26 | // // After hook will get the timestamp registered on the Before hook and print the elapsed time 27 | // func (h *Hooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 28 | // begin := ctx.Value("begin").(time.Time) 29 | // fmt.Printf(". took: %s\n", time.Since(begin)) 30 | // return ctx, nil 31 | // } 32 | // 33 | // func main() { 34 | // // First, register the wrapper 35 | // sql.Register("sqlite3WithHooks", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &Hooks{})) 36 | // 37 | // // Connect to the registered wrapped driver 38 | // db, _ := sql.Open("sqlite3WithHooks", ":memory:") 39 | // 40 | // // Do your stuff 41 | // db.Exec("CREATE TABLE t (id INTEGER, text VARCHAR(16))") 42 | // db.Exec("INSERT into t (text) VALUES(?), (?)", "foo", "bar") 43 | // db.Query("SELECT id, text FROM t") 44 | // } 45 | // 46 | // /* 47 | // Output should look like: 48 | // > CREATE TABLE t (id INTEGER, text VARCHAR(16)) []. took: 121.238µs 49 | // > INSERT into t (text) VALUES(?), (?) ["foo" "bar"]. took: 36.364µs 50 | // > SELECT id, text FROM t [].
took: 4.653µs 51 | // */ 52 | package sqlhooks 53 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/qustavo/sqlhooks/v2 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/go-sql-driver/mysql v1.4.1 7 | github.com/lib/pq v1.2.0 8 | github.com/mattn/go-sqlite3 v1.10.0 9 | github.com/opentracing/opentracing-go v1.1.0 10 | github.com/stretchr/testify v1.4.0 11 | golang.org/x/tools v0.1.7 // indirect 12 | ) 13 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= 4 | github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= 5 | github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= 6 | github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= 7 | github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= 8 | github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= 9 | github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q= 10 | github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= 11 | github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU= 12 | github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= 13 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 14 | github.com/pmezard/go-difflib v1.0.0/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 15 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 16 | github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= 17 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 18 | github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= 19 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 20 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 21 | golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 22 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 23 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 24 | golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 25 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 26 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 27 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 28 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 29 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 30 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 31 | golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e h1:WUoyKPm6nCo1BnNUvPGnFG3T5DUVem42yDJZZ4CNxMA= 32 | golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 33 | golang.org/x/term 
v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 34 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 35 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 36 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 37 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 38 | golang.org/x/tools v0.1.7 h1:6j8CgantCy3yc8JGBqkDLMKWqZ0RDU2g1HVgacojGWQ= 39 | golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo= 40 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 41 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 42 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 43 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 44 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 45 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 46 | -------------------------------------------------------------------------------- /hooks/loghooks/example_test.go: -------------------------------------------------------------------------------- 1 | package loghooks 2 | 3 | import ( 4 | "database/sql" 5 | 6 | "github.com/qustavo/sqlhooks/v2" 7 | sqlite3 "github.com/mattn/go-sqlite3" 8 | ) 9 | 10 | func Example() { 11 | driver := sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, New()) 12 | sql.Register("sqlite3-logger", driver) 13 | db, _ := sql.Open("sqlite3-logger", ":memory:") 14 | 15 | // This query will output logs 16 | db.Query("SELECT 1+1") 17 | } 18 | -------------------------------------------------------------------------------- 
/hooks/loghooks/examples/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "log" 6 | 7 | "github.com/qustavo/sqlhooks/v2" 8 | "github.com/qustavo/sqlhooks/v2/hooks/loghooks" 9 | "github.com/mattn/go-sqlite3" 10 | ) 11 | 12 | func main() { 13 | sql.Register("sqlite3log", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, loghooks.New())) 14 | db, err := sql.Open("sqlite3log", ":memory:") 15 | if err != nil { 16 | log.Fatal(err) 17 | } 18 | 19 | if _, err := db.Exec("CREATE TABLE users(ID int, name text)"); err != nil { 20 | log.Fatal(err) 21 | } 22 | 23 | if _, err := db.Exec(`INSERT INTO users (id, name) VALUES(?, ?)`, 1, "gus"); err != nil { 24 | log.Fatal(err) 25 | } 26 | 27 | if _, err := db.Query(`SELECT id, name FROM users`); err != nil { 28 | log.Fatal(err) 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- /hooks/loghooks/loghooks.go: -------------------------------------------------------------------------------- 1 | package loghooks 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "os" 7 | "time" 8 | ) 9 | 10 | var started int 11 | 12 | type logger interface { 13 | Printf(string, ...interface{}) 14 | } 15 | 16 | type Hook struct { 17 | log logger 18 | } 19 | 20 | func New() *Hook { 21 | return &Hook{ 22 | log: log.New(os.Stderr, "", log.LstdFlags), 23 | } 24 | } 25 | func (h *Hook) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 26 | return context.WithValue(ctx, &started, time.Now()), nil 27 | } 28 | 29 | func (h *Hook) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 30 | h.log.Printf("Query: `%s`, Args: `%q`. 
took: %s", query, args, time.Since(ctx.Value(&started).(time.Time))) 31 | return ctx, nil 32 | } 33 | 34 | func (h *Hook) OnError(ctx context.Context, err error, query string, args ...interface{}) error { 35 | h.log.Printf("Error: %v, Query: `%s`, Args: `%q`, Took: %s", 36 | err, query, args, time.Since(ctx.Value(&started).(time.Time))) 37 | return err 38 | } 39 | -------------------------------------------------------------------------------- /hooks/othooks/examples/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "log" 7 | 8 | "github.com/qustavo/sqlhooks/v2" 9 | "github.com/qustavo/sqlhooks/v2/hooks/othooks" 10 | "github.com/mattn/go-sqlite3" 11 | "github.com/opentracing/opentracing-go" 12 | ) 13 | 14 | func main() { 15 | tracer := opentracing.GlobalTracer() 16 | hooks := othooks.New(tracer) 17 | sql.Register("sqlite3ot", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, hooks)) 18 | db, err := sql.Open("sqlite3ot", ":memory:") 19 | if err != nil { 20 | log.Fatal(err) 21 | } 22 | 23 | span := tracer.StartSpan("sql") 24 | defer span.Finish() 25 | ctx := opentracing.ContextWithSpan(context.Background(), span) 26 | 27 | if _, err := db.ExecContext(ctx, "CREATE TABLE users(ID int, name text)"); err != nil { 28 | log.Fatal(err) 29 | } 30 | 31 | if _, err := db.ExecContext(ctx, `INSERT INTO users (id, name) VALUES(?, ?)`, 1, "gus"); err != nil { 32 | log.Fatal(err) 33 | } 34 | 35 | if _, err := db.QueryContext(ctx, `SELECT id, name FROM users`); err != nil { 36 | log.Fatal(err) 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /hooks/othooks/othooks.go: -------------------------------------------------------------------------------- 1 | package othooks 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/opentracing/opentracing-go" 7 | "github.com/opentracing/opentracing-go/log" 8 | ) 9 | 10 | type Hook struct { 11 | 
tracer opentracing.Tracer 12 | } 13 | 14 | func New(tracer opentracing.Tracer) *Hook { 15 | return &Hook{tracer: tracer} 16 | } 17 | 18 | func (h *Hook) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 19 | parent := opentracing.SpanFromContext(ctx) 20 | if parent == nil { 21 | return ctx, nil 22 | } 23 | 24 | span := h.tracer.StartSpan("sql", opentracing.ChildOf(parent.Context())) 25 | span.LogFields( 26 | log.String("query", query), 27 | log.Object("args", args), 28 | ) 29 | 30 | return opentracing.ContextWithSpan(ctx, span), nil 31 | } 32 | 33 | func (h *Hook) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 34 | span := opentracing.SpanFromContext(ctx) 35 | if span != nil { 36 | defer span.Finish() 37 | } 38 | 39 | return ctx, nil 40 | } 41 | 42 | func (h *Hook) OnError(ctx context.Context, err error, query string, args ...interface{}) error { 43 | span := opentracing.SpanFromContext(ctx) 44 | if span != nil { 45 | defer span.Finish() 46 | span.SetTag("error", true) 47 | span.LogFields( 48 | log.Error(err), 49 | ) 50 | } 51 | 52 | return err 53 | } 54 | -------------------------------------------------------------------------------- /hooks/othooks/othooks_test.go: -------------------------------------------------------------------------------- 1 | package othooks 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "testing" 7 | 8 | "github.com/qustavo/sqlhooks/v2" 9 | sqlite3 "github.com/mattn/go-sqlite3" 10 | opentracing "github.com/opentracing/opentracing-go" 11 | "github.com/opentracing/opentracing-go/mocktracer" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | var ( 17 | tracer *mocktracer.MockTracer 18 | ) 19 | 20 | func init() { 21 | tracer = mocktracer.New() 22 | driver := sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, New(tracer)) 23 | sql.Register("ot", driver) 24 | } 25 | 26 | func TestSpansAreRecorded(t *testing.T) { 27 | 
db, err := sql.Open("ot", ":memory:") 28 | require.NoError(t, err) 29 | defer db.Close() 30 | tracer.Reset() 31 | 32 | parent := tracer.StartSpan("parent") 33 | ctx := opentracing.ContextWithSpan(context.Background(), parent) 34 | 35 | { 36 | rows, err := db.QueryContext(ctx, "SELECT 1+?", "1") 37 | require.NoError(t, err) 38 | rows.Close() 39 | } 40 | 41 | { 42 | rows, err := db.QueryContext(ctx, "SELECT 1+?", "1") 43 | require.NoError(t, err) 44 | rows.Close() 45 | } 46 | 47 | parent.Finish() 48 | 49 | spans := tracer.FinishedSpans() 50 | require.Len(t, spans, 3) 51 | 52 | span := spans[1] 53 | assert.Equal(t, "sql", span.OperationName) 54 | 55 | logFields := span.Logs()[0].Fields 56 | assert.Equal(t, "query", logFields[0].Key) 57 | assert.Equal(t, "SELECT 1+?", logFields[0].ValueString) 58 | assert.Equal(t, "args", logFields[1].Key) 59 | assert.Equal(t, "[1]", logFields[1].ValueString) 60 | assert.NotEmpty(t, span.FinishTime) 61 | } 62 | 63 | func TestNoSpansAreRecorded(t *testing.T) { 64 | db, err := sql.Open("ot", ":memory:") 65 | require.NoError(t, err) 66 | defer db.Close() 67 | tracer.Reset() 68 | 69 | rows, err := db.QueryContext(context.Background(), "SELECT 1") 70 | require.NoError(t, err) 71 | rows.Close() 72 | 73 | assert.Empty(t, tracer.FinishedSpans()) 74 | } 75 | -------------------------------------------------------------------------------- /sqlhooks.go: -------------------------------------------------------------------------------- 1 | package sqlhooks 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | "errors" 7 | ) 8 | 9 | // Hook is the hook callback signature 10 | type Hook func(ctx context.Context, query string, args ...interface{}) (context.Context, error) 11 | 12 | // ErrorHook is the error handling callback signature 13 | type ErrorHook func(ctx context.Context, err error, query string, args ...interface{}) error 14 | 15 | // Hooks instances may be passed to Wrap() to define an instrumented driver 16 | type Hooks interface { 17 
| Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) 18 | After(ctx context.Context, query string, args ...interface{}) (context.Context, error) 19 | } 20 | 21 | // OnErrorer instances will be called if any error happens 22 | type OnErrorer interface { 23 | OnError(ctx context.Context, err error, query string, args ...interface{}) error 24 | } 25 | 26 | func handlerErr(ctx context.Context, hooks Hooks, err error, query string, args ...interface{}) error { 27 | h, ok := hooks.(OnErrorer) 28 | if !ok { 29 | return err 30 | } 31 | 32 | if err := h.OnError(ctx, err, query, args...); err != nil { 33 | return err 34 | } 35 | 36 | return err 37 | } 38 | 39 | // Driver implements a database/sql/driver.Driver 40 | type Driver struct { 41 | driver.Driver 42 | hooks Hooks 43 | } 44 | 45 | // Open opens a connection 46 | func (drv *Driver) Open(name string) (driver.Conn, error) { 47 | conn, err := drv.Driver.Open(name) 48 | if err != nil { 49 | return conn, err 50 | } 51 | 52 | // Drivers that don't implement driver.ConnBeginTx are not supported. 
53 | if _, ok := conn.(driver.ConnBeginTx); !ok { 54 | return nil, errors.New("driver must implement driver.ConnBeginTx") 55 | } 56 | 57 | wrapped := &Conn{conn, drv.hooks} 58 | if isExecer(conn) && isQueryer(conn) && isSessionResetter(conn) { 59 | return &ExecerQueryerContextWithSessionResetter{wrapped, 60 | &ExecerContext{wrapped}, &QueryerContext{wrapped}, 61 | &SessionResetter{wrapped}}, nil 62 | } else if isExecer(conn) && isQueryer(conn) { 63 | return &ExecerQueryerContext{wrapped, &ExecerContext{wrapped}, 64 | &QueryerContext{wrapped}}, nil 65 | } else if isExecer(conn) { 66 | // If conn implements an Execer interface, return a driver.Conn which 67 | // also implements Execer 68 | return &ExecerContext{wrapped}, nil 69 | } else if isQueryer(conn) { 70 | // If conn implements an Queryer interface, return a driver.Conn which 71 | // also implements Queryer 72 | return &QueryerContext{wrapped}, nil 73 | } 74 | return wrapped, nil 75 | } 76 | 77 | // Conn implements a database/sql.driver.Conn 78 | type Conn struct { 79 | Conn driver.Conn 80 | hooks Hooks 81 | } 82 | 83 | func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 84 | var ( 85 | stmt driver.Stmt 86 | err error 87 | ) 88 | 89 | if c, ok := conn.Conn.(driver.ConnPrepareContext); ok { 90 | stmt, err = c.PrepareContext(ctx, query) 91 | } else { 92 | stmt, err = conn.Prepare(query) 93 | } 94 | 95 | if err != nil { 96 | return stmt, err 97 | } 98 | 99 | return &Stmt{stmt, conn.hooks, query}, nil 100 | } 101 | 102 | func (conn *Conn) Prepare(query string) (driver.Stmt, error) { return conn.Conn.Prepare(query) } 103 | func (conn *Conn) Close() error { return conn.Conn.Close() } 104 | func (conn *Conn) Begin() (driver.Tx, error) { return conn.Conn.Begin() } 105 | func (conn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 106 | return conn.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts) 107 | } 108 | 109 | // ExecerContext implements a 
database/sql.driver.ExecerContext 110 | type ExecerContext struct { 111 | *Conn 112 | } 113 | 114 | func isExecer(conn driver.Conn) bool { 115 | switch conn.(type) { 116 | case driver.ExecerContext: 117 | return true 118 | case driver.Execer: 119 | return true 120 | default: 121 | return false 122 | } 123 | } 124 | 125 | func (conn *ExecerContext) execContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 126 | switch c := conn.Conn.Conn.(type) { 127 | case driver.ExecerContext: 128 | return c.ExecContext(ctx, query, args) 129 | case driver.Execer: 130 | dargs, err := namedValueToValue(args) 131 | if err != nil { 132 | return nil, err 133 | } 134 | return c.Exec(query, dargs) 135 | default: 136 | // This should not happen 137 | return nil, errors.New("ExecerContext created for a non Execer driver.Conn") 138 | } 139 | } 140 | 141 | func (conn *ExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 142 | var err error 143 | 144 | list := namedToInterface(args) 145 | 146 | // Exec `Before` Hooks 147 | if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil { 148 | return nil, err 149 | } 150 | 151 | results, err := conn.execContext(ctx, query, args) 152 | if err != nil { 153 | return results, handlerErr(ctx, conn.hooks, err, query, list...) 154 | } 155 | 156 | if _, err := conn.hooks.After(ctx, query, list...); err != nil { 157 | return nil, err 158 | } 159 | 160 | return results, err 161 | } 162 | 163 | func (conn *ExecerContext) Exec(query string, args []driver.Value) (driver.Result, error) { 164 | // We have to implement Exec since it is required in the current version of 165 | // Go for it to run ExecContext. From Go 10 it will be optional. However, 166 | // this code should never run since database/sql always prefers to run 167 | // ExecContext. 
168 | return nil, errors.New("Exec was called when ExecContext was implemented") 169 | } 170 | 171 | // QueryerContext implements a database/sql.driver.QueryerContext 172 | type QueryerContext struct { 173 | *Conn 174 | } 175 | 176 | func isQueryer(conn driver.Conn) bool { 177 | switch conn.(type) { 178 | case driver.QueryerContext: 179 | return true 180 | case driver.Queryer: 181 | return true 182 | default: 183 | return false 184 | } 185 | } 186 | 187 | func (conn *QueryerContext) queryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 188 | switch c := conn.Conn.Conn.(type) { 189 | case driver.QueryerContext: 190 | return c.QueryContext(ctx, query, args) 191 | case driver.Queryer: 192 | dargs, err := namedValueToValue(args) 193 | if err != nil { 194 | return nil, err 195 | } 196 | return c.Query(query, dargs) 197 | default: 198 | // This should not happen 199 | return nil, errors.New("QueryerContext created for a non Queryer driver.Conn") 200 | } 201 | } 202 | 203 | func (conn *QueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 204 | var err error 205 | 206 | list := namedToInterface(args) 207 | 208 | // Query `Before` Hooks 209 | if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil { 210 | return nil, err 211 | } 212 | 213 | results, err := conn.queryContext(ctx, query, args) 214 | if err != nil { 215 | return results, handlerErr(ctx, conn.hooks, err, query, list...) 
216 | } 217 | 218 | if _, err := conn.hooks.After(ctx, query, list...); err != nil { 219 | return nil, err 220 | } 221 | 222 | return results, err 223 | } 224 | 225 | // ExecerQueryerContext implements database/sql.driver.ExecerContext and 226 | // database/sql.driver.QueryerContext 227 | type ExecerQueryerContext struct { 228 | *Conn 229 | *ExecerContext 230 | *QueryerContext 231 | } 232 | 233 | // ExecerQueryerContext implements database/sql.driver.ExecerContext and 234 | // database/sql.driver.QueryerContext 235 | type ExecerQueryerContextWithSessionResetter struct { 236 | *Conn 237 | *ExecerContext 238 | *QueryerContext 239 | *SessionResetter 240 | } 241 | 242 | type SessionResetter struct { 243 | *Conn 244 | } 245 | 246 | // Stmt implements a database/sql/driver.Stmt 247 | type Stmt struct { 248 | Stmt driver.Stmt 249 | hooks Hooks 250 | query string 251 | } 252 | 253 | func (stmt *Stmt) execContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 254 | if s, ok := stmt.Stmt.(driver.StmtExecContext); ok { 255 | return s.ExecContext(ctx, args) 256 | } 257 | 258 | values := make([]driver.Value, len(args)) 259 | for _, arg := range args { 260 | values[arg.Ordinal-1] = arg.Value 261 | } 262 | 263 | return stmt.Exec(values) 264 | } 265 | 266 | func (stmt *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 267 | var err error 268 | 269 | list := namedToInterface(args) 270 | 271 | // Exec `Before` Hooks 272 | if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil { 273 | return nil, err 274 | } 275 | 276 | results, err := stmt.execContext(ctx, args) 277 | if err != nil { 278 | return results, handlerErr(ctx, stmt.hooks, err, stmt.query, list...) 
279 | } 280 | 281 | if _, err := stmt.hooks.After(ctx, stmt.query, list...); err != nil { 282 | return nil, err 283 | } 284 | 285 | return results, err 286 | } 287 | 288 | func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 289 | if s, ok := stmt.Stmt.(driver.StmtQueryContext); ok { 290 | return s.QueryContext(ctx, args) 291 | } 292 | 293 | values := make([]driver.Value, len(args)) 294 | for _, arg := range args { 295 | values[arg.Ordinal-1] = arg.Value 296 | } 297 | return stmt.Query(values) 298 | } 299 | 300 | func (stmt *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 301 | var err error 302 | 303 | list := namedToInterface(args) 304 | 305 | // Exec Before Hooks 306 | if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil { 307 | return nil, err 308 | } 309 | 310 | rows, err := stmt.queryContext(ctx, args) 311 | if err != nil { 312 | return rows, handlerErr(ctx, stmt.hooks, err, stmt.query, list...) 313 | } 314 | 315 | if _, err := stmt.hooks.After(ctx, stmt.query, list...); err != nil { 316 | return nil, err 317 | } 318 | 319 | return rows, err 320 | } 321 | 322 | func (stmt *Stmt) Close() error { return stmt.Stmt.Close() } 323 | func (stmt *Stmt) NumInput() int { return stmt.Stmt.NumInput() } 324 | func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { return stmt.Stmt.Exec(args) } 325 | func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { return stmt.Stmt.Query(args) } 326 | 327 | // Wrap is used to create a new instrumented driver, it takes a vendor specific driver, and a Hooks instance to produce a new driver instance. 
328 | // It's usually used inside a sql.Register() statement 329 | func Wrap(driver driver.Driver, hooks Hooks) driver.Driver { 330 | return &Driver{driver, hooks} 331 | } 332 | 333 | func namedToInterface(args []driver.NamedValue) []interface{} { 334 | list := make([]interface{}, len(args)) 335 | for i, a := range args { 336 | list[i] = a.Value 337 | } 338 | return list 339 | } 340 | 341 | // namedValueToValue copied from database/sql 342 | func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { 343 | dargs := make([]driver.Value, len(named)) 344 | for n, param := range named { 345 | if len(param.Name) > 0 { 346 | return nil, errors.New("sql: driver does not support the use of Named Parameters") 347 | } 348 | dargs[n] = param.Value 349 | } 350 | return dargs, nil 351 | } 352 | 353 | /* 354 | type hooks struct { 355 | } 356 | 357 | func (h *hooks) Before(ctx context.Context, query string, args ...interface{}) error { 358 | log.Printf("before> ctx = %+v, q=%s, args = %+v\n", ctx, query, args) 359 | return nil 360 | } 361 | 362 | func (h *hooks) After(ctx context.Context, query string, args ...interface{}) error { 363 | log.Printf("after> ctx = %+v, q=%s, args = %+v\n", ctx, query, args) 364 | return nil 365 | } 366 | 367 | func main() { 368 | sql.Register("sqlite3-proxy", Wrap(&sqlite3.SQLiteDriver{}, &hooks{})) 369 | db, err := sql.Open("sqlite3-proxy", ":memory:") 370 | if err != nil { 371 | log.Fatalln(err) 372 | } 373 | 374 | if _, ok := driver.Stmt(&Stmt{}).(driver.StmtExecContext); !ok { 375 | panic("NOPE") 376 | } 377 | 378 | if _, err := db.Exec("CREATE table users(id int)"); err != nil { 379 | log.Printf("|err| = %+v\n", err) 380 | } 381 | 382 | if _, err := db.QueryContext(context.Background(), "SELECT * FROM users WHERE id = ?", 1); err != nil { 383 | log.Printf("err = %+v\n", err) 384 | } 385 | 386 | } 387 | */ 388 | -------------------------------------------------------------------------------- /sqlhooks_1_10.go: 
-------------------------------------------------------------------------------- 1 | // +build go1.10 2 | 3 | package sqlhooks 4 | 5 | import ( 6 | "context" 7 | "database/sql/driver" 8 | ) 9 | // isSessionResetter reports whether conn implements driver.SessionResetter. 10 | func isSessionResetter(conn driver.Conn) bool { 11 | _, ok := conn.(driver.SessionResetter) 12 | return ok 13 | } 14 | // ResetSession forwards to the wrapped driver.Conn's ResetSession. Open only constructs a SessionResetter when isSessionResetter(conn) is true, so the unchecked type assertion is expected to hold. 15 | func (s *SessionResetter) ResetSession(ctx context.Context) error { 16 | c := s.Conn.Conn.(driver.SessionResetter) 17 | return c.ResetSession(ctx) 18 | } 19 | -------------------------------------------------------------------------------- /sqlhooks_1_10_interface_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.10 2 | 3 | package sqlhooks 4 | 5 | import "database/sql/driver" 6 | // init appends an interface test case that references driver.SessionResetter, hence the go1.10 build tag. 7 | func init() { 8 | interfaceTestCases = append(interfaceTestCases, 9 | struct { 10 | name string 11 | expectedInterfaces []interface{} 12 | }{ 13 | "ExecerQueryerContextSessionResetter", []interface{}{ 14 | (*driver.ExecerContext)(nil), 15 | (*driver.QueryerContext)(nil), 16 | (*driver.SessionResetter)(nil)}}) 17 | } 18 | -------------------------------------------------------------------------------- /sqlhooks_interface_test.go: -------------------------------------------------------------------------------- 1 | package sqlhooks 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | "errors" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | // interfaceTestCases lists, for each fakeDriver DSN, the interfaces the wrapped connection must implement. 13 | var interfaceTestCases = []struct { 14 | name string 15 | expectedInterfaces []interface{} 16 | }{ 17 | {"Basic", []interface{}{(*driver.Conn)(nil)}}, 18 | {"Execer", []interface{}{(*driver.Execer)(nil)}}, 19 | {"ExecerContext", []interface{}{(*driver.ExecerContext)(nil)}}, 20 | {"Queryer", []interface{}{(*driver.QueryerContext)(nil)}}, 21 | {"QueryerContext", []interface{}{(*driver.QueryerContext)(nil)}}, 22 | {"ExecerQueryerContext", []interface{}{ 23 | (*driver.ExecerContext)(nil), 24 |
(*driver.QueryerContext)(nil)}}, 25 | } 26 | // fakeDriver opens a stub connection whose implemented driver interfaces are chosen by the DSN; unknown DSNs yield an error. 27 | type fakeDriver struct{} 28 | 29 | func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { 30 | switch dsn { 31 | case "Basic": 32 | return &struct{ *FakeConnBasic }{}, nil 33 | case "Execer": 34 | return &struct { 35 | *FakeConnBasic 36 | *FakeConnExecer 37 | }{}, nil 38 | case "ExecerContext": 39 | return &struct { 40 | *FakeConnBasic 41 | *FakeConnExecerContext 42 | }{}, nil 43 | case "Queryer": 44 | return &struct { 45 | *FakeConnBasic 46 | *FakeConnQueryer 47 | }{}, nil 48 | case "QueryerContext": 49 | return &struct { 50 | *FakeConnBasic 51 | *FakeConnQueryerContext 52 | }{}, nil 53 | case "ExecerQueryerContext": 54 | return &struct { 55 | *FakeConnBasic 56 | *FakeConnExecerContext 57 | *FakeConnQueryerContext 58 | }{}, nil 59 | case "ExecerQueryerContextSessionResetter": 60 | return &struct { 61 | *FakeConnBasic 62 | *FakeConnExecer 63 | *FakeConnQueryer 64 | *FakeConnSessionResetter 65 | }{}, nil 66 | case "NonConnBeginTx": 67 | return &FakeConnUnsupported{}, nil 68 | } 69 | 70 | return nil, errors.New("Fake driver not implemented") 71 | } 72 | 73 | // FakeConnBasic implements database/sql/driver.Conn and driver.ConnBeginTx; every method returns a "Not implemented" error. 74 | type FakeConnBasic struct{} 75 | 76 | func (*FakeConnBasic) Prepare(query string) (driver.Stmt, error) { 77 | return nil, errors.New("Not implemented") 78 | } 79 | func (*FakeConnBasic) Close() error { 80 | return errors.New("Not implemented") 81 | } 82 | func (*FakeConnBasic) Begin() (driver.Tx, error) { 83 | return nil, errors.New("Not implemented") 84 | } 85 | func (*FakeConnBasic) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) { 86 | return nil, errors.New("Not implemented") 87 | } 88 | // FakeConnExecer implements driver.Execer. 89 | type FakeConnExecer struct{} 90 | 91 | func (*FakeConnExecer) Exec(query string, args []driver.Value) (driver.Result, error) { 92 | return nil, errors.New("Not implemented") 93 | } 94 | // FakeConnExecerContext implements driver.ExecerContext. 95 | type FakeConnExecerContext struct{} 96 | 97 | func (*FakeConnExecerContext) ExecContext(ctx context.Context,
query string, args []driver.NamedValue) (driver.Result, error) { 98 | return nil, errors.New("Not implemented") 99 | } 100 | // FakeConnQueryer implements driver.Queryer. 101 | type FakeConnQueryer struct{} 102 | 103 | func (*FakeConnQueryer) Query(query string, args []driver.Value) (driver.Rows, error) { 104 | return nil, errors.New("Not implemented") 105 | } 106 | // FakeConnQueryerContext implements driver.QueryerContext. 107 | type FakeConnQueryerContext struct{} 108 | 109 | func (*FakeConnQueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 110 | return nil, errors.New("Not implemented") 111 | } 112 | // FakeConnSessionResetter implements driver.SessionResetter. 113 | type FakeConnSessionResetter struct{} 114 | 115 | func (*FakeConnSessionResetter) ResetSession(ctx context.Context) error { 116 | return errors.New("Not implemented") 117 | } 118 | 119 | // FakeConnUnsupported implements database/sql/driver.Conn but not 120 | // driver.ConnBeginTx, so the wrapping driver's Open must reject it. 121 | type FakeConnUnsupported struct{} 122 | 123 | func (*FakeConnUnsupported) Prepare(query string) (driver.Stmt, error) { 124 | return nil, errors.New("Not implemented") 125 | } 126 | func (*FakeConnUnsupported) Close() error { 127 | return errors.New("Not implemented") 128 | } 129 | func (*FakeConnUnsupported) Begin() (driver.Tx, error) { 130 | return nil, errors.New("Not implemented") 131 | } 132 | // TestInterfaces checks that each wrapped fake connection implements the interfaces listed for its DSN. 133 | func TestInterfaces(t *testing.T) { 134 | drv := Wrap(&fakeDriver{}, &testHooks{}) 135 | 136 | for _, c := range interfaceTestCases { 137 | conn, err := drv.Open(c.name) 138 | require.NoErrorf(t, err, "Driver name %s", c.name) 139 | 140 | for _, i := range c.expectedInterfaces { 141 | assert.Implements(t, i, conn) 142 | } 143 | } 144 | } 145 | // TestUnsupportedDrivers checks that Open rejects a connection without driver.ConnBeginTx. 146 | func TestUnsupportedDrivers(t *testing.T) { 147 | drv := Wrap(&fakeDriver{}, &testHooks{}) 148 | _, err := drv.Open("NonConnBeginTx") 149 | require.EqualError(t, err, "driver must implement driver.ConnBeginTx") 150 | } 151 | -------------------------------------------------------------------------------- /sqlhooks_mysql_test.go:
-------------------------------------------------------------------------------- 1 | package sqlhooks 2 | 3 | import ( 4 | "database/sql" 5 | "os" 6 | "testing" 7 | 8 | "github.com/go-sql-driver/mysql" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | // setUpMySQL creates the users table in the MySQL database at dsn if it does not already exist. 13 | func setUpMySQL(t *testing.T, dsn string) { 14 | db, err := sql.Open("mysql", dsn) 15 | require.NoError(t, err) 16 | defer db.Close() // deferred before Ping so the pool is released even if Ping fails 17 | require.NoError(t, db.Ping()) 18 | 19 | _, err = db.Exec("CREATE table IF NOT EXISTS users(id int, name text)") 20 | require.NoError(t, err) 21 | } 22 | // TestMySQL runs the shared hook suite against MySQL; it is skipped unless SQLHOOKS_MYSQL_DSN is set. 23 | func TestMySQL(t *testing.T) { 24 | dsn := os.Getenv("SQLHOOKS_MYSQL_DSN") 25 | if dsn == "" { 26 | t.Skip("SQLHOOKS_MYSQL_DSN not set") 27 | } 28 | 29 | setUpMySQL(t, dsn) 30 | 31 | s := newSuite(t, &mysql.MySQLDriver{}, dsn) 32 | 33 | s.TestHooksExecution(t, "SELECT * FROM users WHERE id = ?", 1) 34 | s.TestHooksArguments(t, "SELECT * FROM users WHERE id = ? AND name = ?", int64(1), "Gus") 35 | s.TestHooksErrors(t, "SELECT 1+1") 36 | s.TestErrHookHook(t, "SELECT * FROM users WHERE id = $2", "INVALID_ARGS") 37 | 38 | t.Run("DBWorks", func(t *testing.T) { 39 | s.hooks.reset() 40 | if _, err := s.db.Exec("DELETE FROM users"); err != nil { 41 | t.Fatal(err) 42 | } 43 | stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES(?, ?)") 44 | require.NoError(t, err) 45 | defer stmt.Close() 46 | for i := range [5]struct{}{} { 47 | _, err := stmt.Exec(i, "gus") 48 | require.NoError(t, err) 49 | } 50 | 51 | var count int 52 | require.NoError(t, 53 | s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count), 54 | ) 55 | assert.Equal(t, 5, count) 56 | }) 57 | } 58 | -------------------------------------------------------------------------------- /sqlhooks_postgres_test.go: -------------------------------------------------------------------------------- 1 | package sqlhooks 2 | 3 | import ( 4 | "database/sql" 5 | "os" 6 | "testing" 7 | 8 | "github.com/lib/pq" 9 | "github.com/stretchr/testify/assert" 10
| "github.com/stretchr/testify/require" 11 | ) 12 | // setUpPostgres creates the users table in the PostgreSQL database at dsn if it does not already exist. 13 | func setUpPostgres(t *testing.T, dsn string) { 14 | db, err := sql.Open("postgres", dsn) 15 | require.NoError(t, err) 16 | defer db.Close() // deferred before Ping so the pool is released even if Ping fails 17 | require.NoError(t, db.Ping()) 18 | 19 | _, err = db.Exec("CREATE table IF NOT EXISTS users(id int, name text)") 20 | require.NoError(t, err) 21 | } 22 | // TestPostgres runs the shared hook suite against PostgreSQL; it is skipped unless SQLHOOKS_POSTGRES_DSN is set. 23 | func TestPostgres(t *testing.T) { 24 | dsn := os.Getenv("SQLHOOKS_POSTGRES_DSN") 25 | if dsn == "" { 26 | t.Skip("SQLHOOKS_POSTGRES_DSN not set") 27 | } 28 | 29 | setUpPostgres(t, dsn) 30 | 31 | s := newSuite(t, &pq.Driver{}, dsn) 32 | 33 | s.TestHooksExecution(t, "SELECT * FROM users WHERE id = $1", 1) 34 | s.TestHooksArguments(t, "SELECT * FROM users WHERE id = $1 AND name = $2", int64(1), "Gus") 35 | s.TestHooksErrors(t, "SELECT 1+1") 36 | s.TestErrHookHook(t, "SELECT * FROM users WHERE id = $2", "INVALID_ARGS") 37 | 38 | t.Run("DBWorks", func(t *testing.T) { 39 | s.hooks.reset() 40 | if _, err := s.db.Exec("DELETE FROM users"); err != nil { 41 | t.Fatal(err) 42 | } 43 | stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES($1, $2)") 44 | require.NoError(t, err) 45 | defer stmt.Close() 46 | for i := range [5]struct{}{} { 47 | _, err := stmt.Exec(i, "gus") 48 | require.NoError(t, err) 49 | } 50 | 51 | var count int 52 | require.NoError(t, 53 | s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count), 54 | ) 55 | assert.Equal(t, 5, count) 56 | }) 57 | } 58 | -------------------------------------------------------------------------------- /sqlhooks_pre_1_10.go: -------------------------------------------------------------------------------- 1 | // +build !go1.10 2 | 3 | package sqlhooks 4 | 5 | import ( 6 | "context" 7 | "database/sql/driver" 8 | "errors" 9 | ) 10 | // isSessionResetter always reports false in builds older than Go 1.10 (see the !go1.10 build tag). 11 | func isSessionResetter(conn driver.Conn) bool { 12 | return false 13 | } 14 | // ResetSession is the pre-Go 1.10 stub; it always returns an error. 15 | func (s *SessionResetter) ResetSession(ctx context.Context) error { 16 | return errors.New("SessionResetter not implemented") 17 | } 18 |
-------------------------------------------------------------------------------- /sqlhooks_sqlite3_test.go: -------------------------------------------------------------------------------- 1 | package sqlhooks 2 | 3 | import ( 4 | "database/sql" 5 | "os" 6 | "testing" 7 | "time" 8 | 9 | "github.com/mattn/go-sqlite3" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | // setUp creates the users table in the sqlite3test.db file and returns a func that deletes that file. 14 | func setUp(t *testing.T) func() { 15 | dbName := "sqlite3test.db" 16 | 17 | db, err := sql.Open("sqlite3", dbName) 18 | require.NoError(t, err) 19 | defer db.Close() 20 | 21 | _, err = db.Exec("CREATE table users(id int, name text)") 22 | require.NoError(t, err) 23 | 24 | return func() { os.Remove(dbName) } 25 | } 26 | // TestSQLite3 runs the shared hook suite against a file-backed SQLite database. 27 | func TestSQLite3(t *testing.T) { 28 | defer setUp(t)() 29 | s := newSuite(t, &sqlite3.SQLiteDriver{}, "sqlite3test.db") 30 | 31 | s.TestHooksExecution(t, "SELECT * FROM users WHERE id = ?", 1) 32 | s.TestHooksArguments(t, "SELECT * FROM users WHERE id = ? AND name = ?", int64(1), "Gus") 33 | s.TestHooksErrors(t, "SELECT 1+1") 34 | s.TestErrHookHook(t, "SELECT * FROM users WHERE id = $2", "INVALID_ARGS") 35 | 36 | t.Run("DBWorks", func(t *testing.T) { 37 | s.hooks.reset() 38 | if _, err := s.db.Exec("DELETE FROM users"); err != nil { 39 | t.Fatal(err) 40 | } 41 | stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES(?, ?)") 42 | require.NoError(t, err) 43 | defer stmt.Close() 44 | for range [5]struct{}{} { 45 | _, err := stmt.Exec(time.Now().UnixNano(), "gus") 46 | require.NoError(t, err) 47 | } 48 | 49 | var count int 50 | require.NoError(t, 51 | s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count), 52 | ) 53 | assert.Equal(t, 5, count) 54 | }) 55 | } 56 | -------------------------------------------------------------------------------- /sqlhooks_test.go: -------------------------------------------------------------------------------- 1 | package sqlhooks 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "database/sql/driver" 7 | "errors" 8 | "fmt" 9 | "testing" 10 | "time" 11 | 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | // testHooks implements Hooks and OnErrorer by delegating to callbacks that each test can replace. 16 | type testHooks struct { 17 | before Hook 18 | after Hook 19 | onError ErrorHook 20 | } 21 | // newTestHooks returns a testHooks whose callbacks are the no-ops installed by reset. 22 | func newTestHooks() *testHooks { 23 | th := &testHooks{} 24 | th.reset() 25 | return th 26 | } 27 | // reset installs no-op callbacks: Before and After return ctx unchanged and OnError returns err unchanged. 28 | func (h *testHooks) reset() { 29 | noop := func(ctx context.Context, _ string, _ ...interface{}) (context.Context, error) { 30 | return ctx, nil 31 | } 32 | 33 | noopErr := func(_ context.Context, err error, _ string, _ ...interface{}) error { 34 | return err 35 | } 36 | 37 | h.before, h.after, h.onError = noop, noop, noopErr 38 | } 39 | 40 | func (h *testHooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 41 | return h.before(ctx, query, args...) 42 | } 43 | 44 | func (h *testHooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 45 | return h.after(ctx, query, args...) 46 | } 47 | 48 | func (h *testHooks) OnError(ctx context.Context, err error, query string, args ...interface{}) error { 49 | return h.onError(ctx, err, query, args...) 50 | } 51 | // suite bundles a *sql.DB opened through a wrapped driver with the testHooks that driver calls. 52 | type suite struct { 53 | db *sql.DB 54 | hooks *testHooks 55 | } 56 | // newSuite registers driver wrapped with fresh testHooks under a time-based name, then opens and pings dsn through it. 57 | func newSuite(t *testing.T, driver driver.Driver, dsn string) *suite { 58 | hooks := newTestHooks() 59 | 60 | driverName := fmt.Sprintf("sqlhooks-%s", time.Now().String()) 61 | sql.Register(driverName, Wrap(driver, hooks)) 62 | 63 | db, err := sql.Open(driverName, dsn) 64 | require.NoError(t, err) 65 | require.NoError(t, db.Ping()) 66 | 67 | return &suite{db, hooks} 68 | }
69 | // TestHooksExecution checks that Before and After run for Query, QueryContext, Exec, ExecContext and a prepared statement's Query. 70 | func (s *suite) TestHooksExecution(t *testing.T, query string, args ...interface{}) { 71 | var before, after bool 72 | 73 | s.hooks.before = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) { 74 | before = true 75 | return ctx, nil 76 | } 77 | s.hooks.after = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) { 78 | after = true 79 | return ctx, nil 80 | } 81 | 82 | t.Run("Query", func(t *testing.T) { 83 | before, after = false, false 84 | rows, err := s.db.Query(query, args...) 85 | require.NoError(t, err) 86 | require.NoError(t, rows.Close()) // close so the connection returns to the pool 87 | assert.True(t, before, "Before Hook did not run for query: "+query) 88 | assert.True(t, after, "After Hook did not run for query: "+query) 89 | }) 90 | 91 | t.Run("QueryContext", func(t *testing.T) { 92 | before, after = false, false 93 | rows, err := s.db.QueryContext(context.Background(), query, args...) 94 | require.NoError(t, err) 95 | require.NoError(t, rows.Close()) 96 | assert.True(t, before, "Before Hook did not run for query: "+query) 97 | assert.True(t, after, "After Hook did not run for query: "+query) 98 | }) 99 | 100 | t.Run("Exec", func(t *testing.T) { 101 | before, after = false, false 102 | _, err := s.db.Exec(query, args...) 103 | require.NoError(t, err) 104 | assert.True(t, before, "Before Hook did not run for query: "+query) 105 | assert.True(t, after, "After Hook did not run for query: "+query) 106 | }) 107 | 108 | t.Run("ExecContext", func(t *testing.T) { 109 | before, after = false, false 110 | _, err := s.db.ExecContext(context.Background(), query, args...) 111 | require.NoError(t, err) 112 | assert.True(t, before, "Before Hook did not run for query: "+query) 113 | assert.True(t, after, "After Hook did not run for query: "+query) 114 | }) 115 | 116 | t.Run("Statements", func(t *testing.T) { 117 | before, after = false, false 118 | stmt, err := s.db.Prepare(query) 119 | require.NoError(t, err) 120 | defer stmt.Close() 121 | 122 | // Hooks only run when the stmt is executed (Query or Exec), not when it is prepared. 123 | assert.False(t, before, "Before Hook run before execution: "+query) 124 | assert.False(t, after, "After Hook run before execution: "+query) 125 | 126 | rows, err := stmt.Query(args...) 127 | require.NoError(t, err) 128 | require.NoError(t, rows.Close()) 129 | assert.True(t, before, "Before Hook did not run for query: "+query) 130 | assert.True(t, after, "After Hook did not run for query: "+query) 131 | }) 132 | }
133 | // testHooksArguments checks that both hooks receive the caller's query, args and context. 134 | func (s *suite) testHooksArguments(t *testing.T, query string, args ...interface{}) { 135 | hook := func(ctx context.Context, q string, a ...interface{}) (context.Context, error) { 136 | assert.Equal(t, query, q) 137 | assert.Equal(t, args, a) 138 | assert.Equal(t, "val", ctx.Value("key").(string)) 139 | return ctx, nil 140 | } 141 | s.hooks.before = hook 142 | s.hooks.after = hook 143 | 144 | ctx := context.WithValue(context.Background(), "key", "val") //nolint:staticcheck 145 | { 146 | rows, err := s.db.QueryContext(ctx, query, args...) 147 | require.NoError(t, err) 148 | require.NoError(t, rows.Close()) 149 | } 150 | 151 | { 152 | _, err := s.db.ExecContext(ctx, query, args...) 153 | require.NoError(t, err) 154 | } 155 | } 156 | 157 | func (s *suite) TestHooksArguments(t *testing.T, query string, args ...interface{}) { 158 | t.Run("TestHooksArguments", func(t *testing.T) { s.testHooksArguments(t, query, args...) }) 159 | }
160 | // testHooksErrors checks that an error from Before is returned by Query and that After does not run. 161 | func (s *suite) testHooksErrors(t *testing.T, query string) { 162 | boom := errors.New("boom") 163 | s.hooks.before = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 164 | return ctx, boom 165 | } 166 | 167 | s.hooks.after = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 168 | assert.False(t, true, "this should not run") 169 | return ctx, nil 170 | } 171 | 172 | _, err := s.db.Query(query) 173 | assert.Equal(t, boom, err) 174 | } 175 | 176 | func (s *suite) TestHooksErrors(t *testing.T, query string) { 177 | t.Run("TestHooksErrors", func(t *testing.T) { s.testHooksErrors(t, query) }) 178 | }
179 | // testErrHookHook checks that a query failing in the driver calls OnError and skips After. args is accepted but not passed to Query. 180 | func (s *suite) testErrHookHook(t *testing.T, query string, args ...interface{}) { 181 | s.hooks.before = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 182 | return ctx, nil 183 | } 184 | 185 | s.hooks.after = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 186 | assert.False(t, true, "after hook should not run") 187 | return ctx, nil 188 | } 189 | 190 | var onErrorCalled bool 191 | s.hooks.onError = func(ctx context.Context, err error, query string, args ...interface{}) error { 192 | onErrorCalled = true 193 | return err 194 | } 195 | 196 | _, err := s.db.Query(query) 197 | require.Error(t, err) 198 | assert.True(t, onErrorCalled, "onError hook should run") 199 | } 200 | 201 | func (s *suite) TestErrHookHook(t *testing.T, query string, args ...interface{}) { 202 | t.Run("TestErrHookHook", func(t *testing.T) { s.testErrHookHook(t, query, args...) }) 203 | } 204 | 205 | func TestNamedValueToValue(t *testing.T) { 206 | named := []driver.NamedValue{ 207 | {Ordinal: 1, Value: "foo"}, 208 | {Ordinal: 2, Value: 42}, 209 | } 210 | want := []driver.Value{"foo", 42} 211 | dargs, err := namedValueToValue(named) 212 | require.NoError(t, err) 213 | assert.Equal(t, want, dargs) 214 | } 215 | --------------------------------------------------------------------------------