├── .github └── workflows │ └── go.yml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── README.md ├── conn.go ├── conn_go110.go ├── conn_go115.go ├── conn_go19.go ├── conn_test.go ├── connector.go ├── connector_test.go ├── contributors ├── driver.go ├── driver_go110.go ├── fakedb_test.go ├── go.mod ├── helpers.go ├── interceptor.go ├── result.go ├── rows.go ├── rows_picker.go ├── rows_test.go ├── stmt.go ├── stmt_go19.go ├── stmt_go19_test.go ├── stmt_test.go ├── tools └── rows_picker_gen.go └── tx.go /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: 8 | - "*" 9 | 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v2 15 | 16 | - name: Set up Go 17 | uses: actions/setup-go@v2 18 | with: 19 | go-version: 1.17 20 | 21 | - name: build 22 | run: go build -v ./... 23 | 24 | test: 25 | runs-on: ubuntu-latest 26 | steps: 27 | - uses: actions/checkout@v2 28 | 29 | - name: Set up Go 30 | uses: actions/setup-go@v2 31 | with: 32 | go-version: 1.17 33 | 34 | - name: Test 35 | run: go test -v -race ./... 
36 | 37 | check: 38 | runs-on: ubuntu-latest 39 | steps: 40 | - uses: actions/checkout@v2 41 | 42 | - name: golangci-lint 43 | run: docker run -v $GITHUB_WORKSPACE:/repo -w /repo golangci/golangci-lint:v1.42 golangci-lint run 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | # IDEs 27 | .idea -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | issues: 2 | exclude: 3 | - "SA1019: .* has been deprecated since Go 1.*: Drivers should implement .*" 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Expansive Worlds 4 | Copyright (c) 2017 Avalanche Studios 5 | Copyright (c) 2020 Alan Shreve 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the 
Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![GoDoc](https://godoc.org/github.com/ngrok/sqlmw?status.svg)](https://godoc.org/github.com/ngrok/sqlmw) 2 | 3 | # sqlmw 4 | sqlmw provides an absurdly simple API that allows a caller to wrap a `database/sql` driver 5 | with middleware. 6 | 7 | This provides an abstraction similar to HTTP middleware or gRPC interceptors, but for the database/sql package. 8 | This allows a caller to implement observability like tracing and logging easily. More importantly, it also enables 9 | powerful behaviors like transparently modifying arguments, results or query execution strategy. This power allows programmers to implement 10 | functionality like automatic sharding, selective tracing, automatic caching, transparent query mirroring, retries, fail-over 11 | in a reusable way, and more. 12 | 13 | ## Usage 14 | 15 | - Define a new type and embed the `sqlmw.NullInterceptor` type. 16 | - Add each method you want to intercept from the `sqlmw.Interceptor` interface. 17 | - Wrap the driver with your interceptor with `sqlmw.Driver` and then install it with `sql.Register`. 18 | - Use `sql.Open` with the driver name that was passed to `sql.Register`.
19 | 20 | Here's a complete example: 21 | 22 | ```go 23 | func run(dsn string) { 24 | // install the wrapped driver 25 | sql.Register("postgres-mw", sqlmw.Driver(pq.Driver{}, new(sqlInterceptor))) 26 | db, err := sql.Open("postgres-mw", dsn) 27 | ... 28 | } 29 | 30 | type sqlInterceptor struct { 31 | sqlmw.NullInterceptor 32 | } 33 | 34 | func (in *sqlInterceptor) StmtQueryContext(ctx context.Context, conn driver.StmtQueryContext, query string, args []driver.NamedValue) (driver.Rows, error) { 35 | startedAt := time.Now() 36 | rows, err := conn.QueryContext(ctx, args) 37 | log.Debug("executed sql query", "duration", time.Since(startedAt), "query", query, "args", args, "err", err) 38 | return rows, err 39 | } 40 | ``` 41 | 42 | You may override any subset of methods to intercept in the `Interceptor` interface (https://godoc.org/github.com/ngrok/sqlmw#Interceptor): 43 | 44 | ```go 45 | type Interceptor interface { 46 | // Connection interceptors 47 | ConnBeginTx(context.Context, driver.ConnBeginTx, driver.TxOptions) (context.Context, driver.Tx, error) 48 | ConnPrepareContext(context.Context, driver.ConnPrepareContext, string) (context.Context, driver.Stmt, error) 49 | ConnPing(context.Context, driver.Pinger) error 50 | ConnExecContext(context.Context, driver.ExecerContext, string, []driver.NamedValue) (driver.Result, error) 51 | ConnQueryContext(context.Context, driver.QueryerContext, string, []driver.NamedValue) (context.Context, driver.Rows, error) 52 | 53 | // Connector interceptors 54 | ConnectorConnect(context.Context, driver.Connector) (driver.Conn, error) 55 | 56 | // Results interceptors 57 | ResultLastInsertId(driver.Result) (int64, error) 58 | ResultRowsAffected(driver.Result) (int64, error) 59 | 60 | // Rows interceptors 61 | RowsNext(context.Context, driver.Rows, []driver.Value) error 62 | RowsClose(context.Context, driver.Rows) error 63 | // Stmt interceptors 64 | StmtExecContext(context.Context, driver.StmtExecContext, string, []driver.NamedValue) (driver.Result, error) 65 | StmtQueryContext(context.Context, driver.StmtQueryContext, string, 
[]driver.NamedValue) (driver.Rows, error) 66 | StmtClose(context.Context, driver.Stmt) error 67 | 68 | // Tx interceptors 69 | TxCommit(context.Context, driver.Tx) error 70 | TxRollback(context.Context, driver.Tx) error 71 | } 72 | ``` 73 | 74 | Bear in mind that because you are intercepting the calls entirely, you are responsible for passing control up to the wrapped 75 | driver in any function that you override, like so: 76 | 77 | ```go 78 | func (in *sqlInterceptor) ConnPing(ctx context.Context, conn driver.Pinger) error { 79 | return conn.Ping(ctx) 80 | } 81 | ``` 82 | 83 | ## Examples 84 | 85 | ### Logging 86 | 87 | ```go 88 | func (in *sqlInterceptor) StmtQueryContext(ctx context.Context, conn driver.StmtQueryContext, query string, args []driver.NamedValue) (driver.Rows, error) { 89 | startedAt := time.Now() 90 | rows, err := conn.QueryContext(ctx, args) 91 | log.Debug("executed sql query", "duration", time.Since(startedAt), "query", query, "args", args, "err", err) 92 | return rows, err 93 | } 94 | ``` 95 | 96 | ### Tracing 97 | 98 | ```go 99 | func (in *sqlInterceptor) StmtQueryContext(ctx context.Context, conn driver.StmtQueryContext, query string, args []driver.NamedValue) (driver.Rows, error) { 100 | span := trace.FromContext(ctx).NewSpan(ctx, "StmtQueryContext") 101 | span.Tags["query"] = query 102 | defer span.Finish() 103 | rows, err := conn.QueryContext(ctx, args) 104 | if err != nil { 105 | span.Error(err) 106 | } 107 | return rows, err 108 | } 109 | ``` 110 | 111 | ### Retries 112 | 113 | ```go 114 | func (in *sqlInterceptor) StmtQueryContext(ctx context.Context, conn driver.StmtQueryContext, query string, args []driver.NamedValue) (driver.Rows, error) { 115 | for { 116 | rows, err := conn.QueryContext(ctx, args) 117 | if err == nil { 118 | return rows, nil 119 | } 120 | if !isIdempotent(query) { 121 | return nil, err 122 | } 123 | select { 124 | case <-ctx.Done(): 125 | return nil, ctx.Err() 126 | case <-time.After(time.Second): 
127 | } 128 | } 129 | } 130 | ``` 131 | 132 | 133 | ## Comparison with similar projects 134 | 135 | There are a number of other packages that allow the programmer to wrap a `database/sql/driver.Driver` to add logging or tracing. 136 | 137 | Examples of tracing packages: 138 | - github.com/ExpansiveWorlds/instrumentedsql 139 | - contrib.go.opencensus.io/integrations/ocsql 140 | - gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql 141 | 142 | A few provide a much more flexible setup of arbitrary before/after hooks to facilitate custom observability. 143 | 144 | Packages that provide before/after hooks: 145 | - github.com/gchaincl/sqlhooks 146 | - github.com/shogo82148/go-sql-proxy 147 | 148 | None of these packages provide an interface with sufficient power. `sqlmw` passes control of executing the 149 | sql query to the caller which allows the caller to completely disintermediate the sql calls. This is what provides 150 | the power to implement advanced behaviors like caching, sharding, retries, etc. 151 | 152 | ## Go version support 153 | 154 | Go versions 1.9 and forward are supported. 
155 | 
156 | ## Fork
157 | 
158 | This project began by forking the code in github.com/luna-duclos/instrumentedsql, which itself is a fork of github.com/ExpansiveWorlds/instrumentedsql
159 | 
--------------------------------------------------------------------------------
/conn.go:
--------------------------------------------------------------------------------
1 | package sqlmw
2 | 
3 | import (
4 | 	"context"
5 | 	"database/sql/driver"
6 | 	"errors"
7 | )
8 | 
9 | // wrappedConn wraps a driver.Conn and routes its context-aware calls through an Interceptor.
10 | type wrappedConn struct {
11 | 	intr   Interceptor
12 | 	parent driver.Conn
13 | }
14 | 
15 | // Compile time validation that our types implement the expected interfaces
16 | var (
17 | 	_ driver.Conn               = wrappedConn{}
18 | 	_ driver.ConnBeginTx        = wrappedConn{}
19 | 	_ driver.ConnPrepareContext = wrappedConn{}
20 | 	_ driver.Execer             = wrappedConn{}
21 | 	_ driver.ExecerContext      = wrappedConn{}
22 | 	_ driver.Pinger             = wrappedConn{}
23 | 	_ driver.Queryer            = wrappedConn{}
24 | 	_ driver.QueryerContext     = wrappedConn{}
25 | )
26 | 
27 | // Prepare prepares query on the parent connection without calling the Interceptor. The
28 | // statement is given context.Background() so that hooks which later receive the statement's
29 | // stored context (such as StmtClose) never get a nil context.
30 | func (c wrappedConn) Prepare(query string) (driver.Stmt, error) {
31 | 	stmt, err := c.parent.Prepare(query)
32 | 	if err != nil {
33 | 		return nil, err
34 | 	}
35 | 	return wrappedStmt{intr: c.intr, ctx: context.Background(), query: query, parent: stmt, conn: c}, nil
36 | }
37 | 
38 | // Close closes the parent connection.
39 | func (c wrappedConn) Close() error {
40 | 	return c.parent.Close()
41 | }
42 | 
43 | // Begin starts a transaction with the parent's legacy Begin without calling the Interceptor.
44 | // The transaction is given context.Background() so that the TxCommit/TxRollback hooks never
45 | // get a nil context.
46 | func (c wrappedConn) Begin() (driver.Tx, error) {
47 | 	tx, err := c.parent.Begin()
48 | 	if err != nil {
49 | 		return nil, err
50 | 	}
51 | 	return wrappedTx{intr: c.intr, ctx: context.Background(), parent: tx}, nil
52 | }
53 | 
54 | // BeginTx starts a transaction through Interceptor.ConnBeginTx and stores the context the
55 | // interceptor returns on the transaction.
56 | func (c wrappedConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
57 | 	wrappedParent := wrappedParentConn{c.parent}
58 | 	ctx, tx, err = c.intr.ConnBeginTx(ctx, wrappedParent, opts)
59 | 	if err != nil {
60 | 		return nil, err
61 | 	}
62 | 	return wrappedTx{intr: c.intr, ctx: ctx, parent: tx}, nil
63 | }
64 | 
65 | // PrepareContext prepares query through Interceptor.ConnPrepareContext and stores the
66 | // context the interceptor returns on the statement.
67 | func (c wrappedConn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) {
68 | 	wrappedParent := wrappedParentConn{c.parent}
69 | 	ctx, stmt, err = c.intr.ConnPrepareContext(ctx, wrappedParent, query)
70 | 	if err != nil {
71 | 		return nil, err
72 | 	}
73 | 	return wrappedStmt{intr: c.intr, ctx: ctx, query: query, parent: stmt, conn: c}, nil
74 | }
75 | 
76 | // Exec runs query with the parent's legacy Execer without calling the Interceptor, or
77 | // returns driver.ErrSkip when the parent does not implement driver.Execer.
78 | func (c wrappedConn) Exec(query string, args []driver.Value) (driver.Result, error) {
79 | 	if execer, ok := c.parent.(driver.Execer); ok {
80 | 		res, err := execer.Exec(query, args)
81 | 		if err != nil {
82 | 			return nil, err
83 | 		}
84 | 
85 | 		return wrappedResult{intr: c.intr, ctx: context.Background(), parent: res}, nil
86 | 	}
87 | 	return nil, driver.ErrSkip
88 | }
89 | 
90 | // ExecContext runs query through Interceptor.ConnExecContext.
91 | func (c wrappedConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) {
92 | 	// Quick skip path: if the wrapped connection implements neither ExecerContext nor Execer, return
93 | 	// driver.ErrSkip so database/sql falls back to prepare+exec instead of the interceptor reaching
94 | 	// a fallback that has nothing to call.
95 | 	_, hasExecerContext := c.parent.(driver.ExecerContext)
96 | 	_, hasExecer := c.parent.(driver.Execer)
97 | 	if !hasExecerContext && !hasExecer {
98 | 		return nil, driver.ErrSkip
99 | 	}
100 | 
101 | 	wrappedParent := wrappedParentConn{c.parent}
102 | 	r, err = c.intr.ConnExecContext(ctx, wrappedParent, query, args)
103 | 	if err != nil {
104 | 		return nil, err
105 | 	}
106 | 	return wrappedResult{intr: c.intr, ctx: ctx, parent: r}, nil
107 | }
108 | 
109 | // Ping calls Interceptor.ConnPing when the parent implements driver.Pinger, and otherwise
110 | // reports success.
111 | func (c wrappedConn) Ping(ctx context.Context) (err error) {
112 | 	if pinger, ok := c.parent.(driver.Pinger); ok {
113 | 		return c.intr.ConnPing(ctx, pinger)
114 | 	}
115 | 	return nil
116 | }
117 | 
118 | // Query runs query with the parent's legacy Queryer without calling the Interceptor, or
119 | // returns driver.ErrSkip when the parent does not implement driver.Queryer. The returned
120 | // rows are still wrapped, using context.Background().
121 | func (c wrappedConn) Query(query string, args []driver.Value) (driver.Rows, error) {
122 | 	if queryer, ok := c.parent.(driver.Queryer); ok {
123 | 		rows, err := queryer.Query(query, args)
124 | 		if err != nil {
125 | 			return nil, err
126 | 		}
127 | 		return wrapRows(context.Background(), c.intr, rows), nil //nolint
128 | 	}
129 | 	return nil, driver.ErrSkip
130 | }
131 | 
132 | // QueryContext runs query through Interceptor.ConnQueryContext and wraps the returned rows
133 | // with the context the interceptor returns.
134 | func (c wrappedConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
135 | 	// Quick skip path: If the wrapped connection implements neither QueryerContext nor Queryer, we have absolutely nothing to do
136 | 	_, hasQueryerContext := c.parent.(driver.QueryerContext)
137 | 	_, hasQueryer := c.parent.(driver.Queryer)
138 | 	if !hasQueryerContext && !hasQueryer {
139 | 		return nil, driver.ErrSkip
140 | 	}
141 | 
142 | 	wrappedParent := wrappedParentConn{c.parent}
143 | 	ctx, rows, err = c.intr.ConnQueryContext(ctx, wrappedParent, query, args)
144 | 	if err != nil {
145 | 		return nil, err
146 | 	}
147 | 
148 | 	return wrapRows(ctx, c.intr, rows), nil
149 | }
150 | 
151 | // wrappedParentConn is the connection handed to Interceptor hooks. It provides the
152 | // context-aware driver interfaces on top of the parent connection, falling back to the
153 | // legacy methods when the parent does not implement them.
154 | type wrappedParentConn struct {
155 | 	driver.Conn
156 | }
157 | 
158 | // BeginTx uses the parent's ConnBeginTx when available, otherwise the legacy Begin.
159 | func (c wrappedParentConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
160 | 	if connBeginTx, ok := c.Conn.(driver.ConnBeginTx); ok {
161 | 		return connBeginTx.BeginTx(ctx, opts)
162 | 	}
163 | 	// Fallback implementation. The legacy Begin cannot honor TxOptions, so reject non-default
164 | 	// options (with the same errors database/sql uses for such drivers) instead of silently
165 | 	// starting a default transaction.
166 | 	if opts.Isolation != 0 {
167 | 		return nil, errors.New("sql: driver does not support non-default isolation level")
168 | 	}
169 | 	if opts.ReadOnly {
170 | 		return nil, errors.New("sql: driver does not support read-only transactions")
171 | 	}
172 | 	select {
173 | 	case <-ctx.Done():
174 | 		return nil, ctx.Err()
175 | 	default:
176 | 		return c.Conn.Begin()
177 | 	}
178 | }
179 | 
180 | // PrepareContext uses the parent's ConnPrepareContext when available, otherwise the legacy Prepare.
181 | func (c wrappedParentConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
182 | 	if connPrepareCtx, ok := c.Conn.(driver.ConnPrepareContext); ok {
183 | 		return connPrepareCtx.PrepareContext(ctx, query)
184 | 	}
185 | 	// Fallback implementation
186 | 	select {
187 | 	case <-ctx.Done():
188 | 		return nil, ctx.Err()
189 | 	default:
190 | 		return c.Conn.Prepare(query)
191 | 	}
192 | }
193 | 
194 | // ExecContext uses the parent's ExecerContext when available, otherwise the legacy Execer.
195 | // It returns driver.ErrSkip when the parent implements neither.
196 | func (c wrappedParentConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) {
197 | 	if execContext, ok := c.Conn.(driver.ExecerContext); ok {
198 | 		return execContext.ExecContext(ctx, query, args)
199 | 	}
200 | 	// Fallback implementation
201 | 	execer, ok := c.Conn.(driver.Execer)
202 | 	if !ok {
203 | 		return nil, driver.ErrSkip
204 | 	}
205 | 	dargs, err := namedValueToValue(args)
206 | 	if err != nil {
207 | 		return nil, err
208 | 	}
209 | 	select {
210 | 	case <-ctx.Done():
211 | 		return nil, ctx.Err()
212 | 	default:
213 | 		return execer.Exec(query, dargs)
214 | 	}
215 | }
216 | 
217 | // QueryContext uses the parent's QueryerContext when available, otherwise the legacy Queryer.
218 | // It returns driver.ErrSkip when the parent implements neither.
219 | func (c wrappedParentConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
220 | 	if queryerContext, ok := c.Conn.(driver.QueryerContext); ok {
221 | 		return queryerContext.QueryContext(ctx, query, args)
222 | 	}
223 | 	// Fallback implementation
224 | 	queryer, ok := c.Conn.(driver.Queryer)
225 | 	if !ok {
226 | 		return nil, driver.ErrSkip
227 | 	}
228 | 	dargs, err := namedValueToValue(args)
229 | 	if err != nil {
230 | 		return nil, err
231 | 	}
232 | 	select {
233 | 	case <-ctx.Done():
234 | 		return nil, ctx.Err()
235 | 	default:
236 | 		return queryer.Query(query, dargs)
237 | 	}
238 | }
239 | 
--------------------------------------------------------------------------------
/conn_go110.go:
--------------------------------------------------------------------------------
1 | // +build go1.10
2 | 
3 | package sqlmw
4 | 
5 | import (
6 | 	"context"
7 | 	"database/sql/driver"
8 | )
9 | 
10 | var _ driver.SessionResetter = wrappedConn{}
11 | 
12 | // ResetSession forwards to the parent's driver.SessionResetter, or returns nil when the
13 | // parent does not implement it.
14 | func (c wrappedConn) ResetSession(ctx context.Context) error {
15 | 	conn, ok := c.parent.(driver.SessionResetter)
16 | 	if !ok {
17 | 		return nil
18 | 	}
19 | 
20 | 	return conn.ResetSession(ctx)
21 | }
22 | 
--------------------------------------------------------------------------------
/conn_go115.go:
--------------------------------------------------------------------------------
1 | // +build go1.15
2 | 
3 | package sqlmw
4 | 
5 | import (
6 | 	"database/sql/driver"
7 | )
8 | 
9 | var _ driver.Validator = wrappedConn{}
10 | 
11 | // IsValid forwards to the parent's driver.Validator. A parent without a Validator is
12 | // reported as valid.
13 | func (c wrappedConn) IsValid() bool {
14 | 	conn, ok := c.parent.(driver.Validator)
15 | 	if !ok {
16 | 		// the default if driver.Validator is not supported
17 | 		return true
18 | 	}
19 | 
20 | 	return conn.IsValid()
21 | }
22 | 
--------------------------------------------------------------------------------
/conn_go19.go:
--------------------------------------------------------------------------------
1 | // +build go1.9
2 | 
3 | package sqlmw
4 | 
5 | import "database/sql/driver"
6 | 
7 | var (
8 | 	_ driver.NamedValueChecker = wrappedConn{}
9 | )
10 | 
11 | func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
12 | 	nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value)
13 | 	return err
14 | }
15 | 
16 | func (c wrappedConn) CheckNamedValue(v *driver.NamedValue) error {
17 | 	if checker, ok := c.parent.(driver.NamedValueChecker); ok {
18 | 		return checker.CheckNamedValue(v)
19 | 	}
20 | 
21 | 	return defaultCheckNamedValue(v)
22 | }
23 | 
--------------------------------------------------------------------------------
/conn_test.go:
--------------------------------------------------------------------------------
1 | package sqlmw
2 | 
3 | import (
4 | 	"context"
5 | 	"database/sql"
6 | 	"database/sql/driver"
7 | 	"testing"
8 | )
9 | 
10 | type connCtxKey string
11 | 
12 | const (
13 | 	connRowContextKey    connCtxKey = "context"
14 | 	connRowContextValue  string     = "value"
15 | 	connStmtContextKey   connCtxKey = "stmtcontext"
16 | 	connStmtContextValue string     = "stmtvalue"
17 | 	connTxContextKey     connCtxKey = "txcontext"
18 | 	connTxContextValue   string     = "txvalue"
19 | )
20 | // connTestInterceptor records, per hook, whether the context it received carries the value stored by the Conn* hook that created the rows/stmt/tx.
21 | type connTestInterceptor struct {
22 | 	T               *testing.T
23 | 	RowsNextValid   bool
24 | 	RowsCloseValid  bool
25 | 	StmtCloseValid  bool
26 | 	TxCommitValid   bool
27 | 	TxRollbackValid bool
28 | 	NullInterceptor
29 | }
30 | 
31 | func (i *connTestInterceptor) ConnPrepareContext(ctx context.Context, conn driver.ConnPrepareContext, query string) (context.Context, driver.Stmt, error) {
32 | 	ctx = context.WithValue(ctx, connStmtContextKey, connStmtContextValue)
33 | 
34 | 	s, err := conn.PrepareContext(ctx, query)
35 | 	return ctx, s, err
36 | }
37 | 
38 | func (i *connTestInterceptor) ConnQueryContext(ctx context.Context, conn driver.QueryerContext, query string, args []driver.NamedValue) (context.Context, driver.Rows, error) {
39 | 	ctx = context.WithValue(ctx, connRowContextKey, connRowContextValue)
40 | 
41 | 	r, err := conn.QueryContext(ctx, query, args)
42 | 	return ctx, r, err
43 | }
44 | 
45 | func (i *connTestInterceptor) ConnBeginTx(ctx context.Context, conn driver.ConnBeginTx, txOpts driver.TxOptions) (context.Context, driver.Tx, error) {
46 | 	ctx = context.WithValue(ctx, connTxContextKey, connTxContextValue)
47 | 
48 | 	t, err := conn.BeginTx(ctx, txOpts)
49 | 	return ctx, t, err
50 | }
51 | 
52 | func (i *connTestInterceptor) RowsNext(ctx context.Context, rows driver.Rows, dest []driver.Value) error {
53 | 	if ctx.Value(connRowContextKey) == connRowContextValue {
54 | 		i.RowsNextValid = true
55 | 	}
56 | 
57 | 	return rows.Next(dest)
58 | }
59 | 
60 | func (i *connTestInterceptor) RowsClose(ctx context.Context, rows driver.Rows) error {
61 | 	if ctx.Value(connRowContextKey) == connRowContextValue {
62 | 		i.RowsCloseValid = true
63 | 	}
64 | 
65 | 	return rows.Close()
66 | }
67 | 
68 | func (i *connTestInterceptor) StmtClose(ctx context.Context, stmt driver.Stmt) error {
69 | 	if ctx.Value(connStmtContextKey) == connStmtContextValue {
70 | 		i.StmtCloseValid = true
71 | 	}
72 | 
73 | 	i.T.Log(ctx)
74 | 
75 | 	return stmt.Close()
76 | }
77 | 
78 | func (i *connTestInterceptor) TxCommit(ctx context.Context, tx driver.Tx) error {
79 | 	if ctx.Value(connTxContextKey) == connTxContextValue {
80 | 		i.TxCommitValid = true
81 | 	}
82 | 
83 | 	i.T.Log(ctx)
84 | 
85 | 	return tx.Commit()
86 | }
87 | 
88 | func (i *connTestInterceptor) TxRollback(ctx context.Context, tx driver.Tx) error {
89 | 	if ctx.Value(connTxContextKey) == connTxContextValue {
90 | 		i.TxRollbackValid = true
91 | 	}
92 | 
93 | 	i.T.Log(ctx)
94 | 
95 | 	return tx.Rollback()
96 | }
97 | 
98 | func TestConnQueryContext_PassWrappedRowContext(t *testing.T) {
99 | 	driverName := driverName(t)
100 | 
101 | 	con := &fakeConn{}
102 | 
103 | 	ti := &connTestInterceptor{T: t}
104 | 
105 | 	sql.Register(
106 | 		driverName,
107 | 		Driver(&fakeDriver{conn: con}, ti),
108 | 	)
109 | 
110 | 	db, err := sql.Open(driverName, "")
111 | 	if err != nil {
112 | 		t.Fatalf("Failed to open: %v", err)
113 | 	}
114 | 
115 | 	t.Cleanup(func() {
116 | 		if err := db.Close(); err != nil {
117 | 			t.Errorf("Failed to close db: %v", err)
118 | 		}
119 | 	})
120 | 
121 | 	rows, err := db.QueryContext(context.Background(), "")
122 | 	if err != nil {
123 | 		t.Fatalf("QueryContext failed: %v", err)
124 | 	}
125 | 
126 | 	rows.Next()
127 | 	rows.Close()
128 | 
129 | 	if !ti.RowsCloseValid {
130 | 		t.Error("RowsClose context not valid")
131 | 	}
132 | 	if !ti.RowsNextValid {
133 | 		t.Error("RowsNext context not valid")
134 | 	}
135 | }
136 | 
137 | func TestConnPrepareContext_PassWrappedStmtContext(t *testing.T) {
138 | 	driverName := driverName(t)
139 | 
140 | 	con := &fakeConn{}
141 | 	fakeStmt := &fakeStmt{
142 | 		rows: &fakeRows{
143 | 			con:  con,
144 | 			vals: [][]driver.Value{{}},
145 | 		},
146 | 	}
147 | 	con.stmt = fakeStmt
148 | 
149 | 	ti := &connTestInterceptor{T: t}
150 | 
151 | 	sql.Register(
152 | 		driverName,
153 | 		Driver(&fakeDriver{conn: con}, ti),
154 | 	)
155 | 
156 | 	db, err := sql.Open(driverName, "")
157 | 	if err != nil {
158 | 		t.Fatalf("Failed to open: %v", err)
159 | 	}
160 | 
161 | 	t.Cleanup(func() {
162 | 		if err := db.Close(); err != nil {
163 | 			t.Errorf("Failed to close db: %v", err)
164 | 		}
165 | 	})
166 | 
167 | 	stmt, err := db.PrepareContext(context.Background(), "")
168 | 	if err != nil {
169 | 		t.Fatalf("Prepare failed: %s", err)
170 | 	}
171 | 
172 | 	stmt.Close()
173 | 
174 | 	if !ti.StmtCloseValid {
175 | 		t.Error("StmtClose context not valid")
176 | 	}
177 | }
178 | 
179 | func TestConnBeginTx_PassWrappedTxContextCommit(t *testing.T) {
180 | 	driverName := driverName(t)
181 | 
182 | 	con := &fakeConn{}
183 | 	fakeTx := &fakeTx{}
184 | 	con.tx = fakeTx
185 | 
186 | 	ti := &connTestInterceptor{T: t}
187 | 
188 | 	sql.Register(
189 | 		driverName,
190 | 		Driver(&fakeDriver{conn: con}, ti),
191 | 	)
192 | 
193 | 	db, err := sql.Open(driverName, "")
194 | 	if err != nil {
195 | 		t.Fatalf("Failed to open: %v", err)
196 | 	}
197 | 
198 | 	t.Cleanup(func() {
199 | 		if err := db.Close(); err != nil {
200 | 			t.Errorf("Failed to close db: %v", err)
201 | 		}
202 | 	})
203 | 
204 | 	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
205 | 	if err != nil {
206 | 		t.Fatalf("BeginTx failed: %v", err)
207 | 	}
208 | 
209 | 	err = tx.Commit()
210 | 	if err != nil {
211 | 		t.Fatalf("Commit failed: %s", err)
212 | 	}
213 | 
214 | 	if !ti.TxCommitValid {
215 | 		t.Error("TxCommit context not valid")
216 | 	}
217 | }
218 | func TestConnBeginTx_PassWrappedTxContextRollback(t *testing.T) {
219 | 	driverName := driverName(t)
220 | 
221 | 	con := &fakeConn{}
222 | 	fakeTx := &fakeTx{}
223 | 	con.tx = fakeTx
224 | 
225 | 	ti := &connTestInterceptor{T: t}
226 | 
227 | 	sql.Register(
228 | 		driverName,
229 | 		Driver(&fakeDriver{conn: con}, ti),
230 | 	)
231 | 
232 | 	db, err := sql.Open(driverName, "")
233 | 	if err != nil {
234 | 		t.Fatalf("Failed to open: %v", err)
235 | 	}
236 | 
237 | 	t.Cleanup(func() {
238 | 		if err := db.Close(); err != nil {
239 | 			t.Errorf("Failed to close db: %v", err)
240 | 		}
241 | 	})
242 | 
243 | 	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
244 | 	if err != nil {
245 | 		t.Fatalf("BeginTx failed: %v", err)
246 | 	}
247 | 
248 | 	err = tx.Rollback()
249 | 	if err != nil {
250 | 		t.Fatalf("Rollback failed: %s", err)
251 | 	}
252 | 
253 | 	if !ti.TxRollbackValid {
254 | 		t.Error("TxRollback context not valid")
255 | 	}
256 | }
257 | 
--------------------------------------------------------------------------------
/connector.go:
--------------------------------------------------------------------------------
1 | // +build go1.10
2 | 
3 | package sqlmw
4 | 
5 | import (
6 | 	"context"
7 | 	"database/sql/driver"
8 | )
9 | 
10 | type wrappedConnector struct {
11 | 	parent    driver.Connector
12 | 	driverRef *wrappedDriver
13 | }
14 | 
15 | var (
16 | 	_ driver.Connector = wrappedConnector{}
17 | )
18 | 
19 | func (c wrappedConnector) Connect(ctx context.Context) (conn driver.Conn, err error) {
20 | 	conn, err = c.driverRef.intr.ConnectorConnect(ctx, c.parent)
21 | 	if err != nil {
22 | 		return nil, err
23 | 	}
24 | 
25 | 	return wrappedConn{intr: c.driverRef.intr, parent: conn}, nil
26 | }
27 | 
28 | func (c wrappedConnector) Driver() driver.Driver {
29 | 	return c.driverRef
30 | }
31 | 
32 | // dsnConnector is the fallback connector used as wrappedConnector.parent when the
33 | // wrapped Driver does not implement the driver.DriverContext interface.
34 | type dsnConnector struct {
35 | 	dsn    string
36 | 	driver driver.Driver
37 | }
38 | 
39 | func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
40 | 	return t.driver.Open(t.dsn)
41 | }
42 | 
43 | func (t dsnConnector) Driver() driver.Driver {
44 | 	return t.driver
45 | }
46 | 
--------------------------------------------------------------------------------
/connector_test.go:
--------------------------------------------------------------------------------
1 | // +build go1.10
2 | 
3 | package sqlmw
4 | 
5 | import (
6 | 	"context"
7 | 	"database/sql/driver"
8 | 	"errors"
9 | 	"testing"
10 | )
11 | 
12 | func TestConnectorWithDriverContext(t *testing.T) {
13 | 	err := errors.New("a generic error")
14 | 
15 | 	tests := []struct {
16 | 		name             string
17 | 		openConnectorErr error
18 | 		expectErr        bool
19 | 	}{
20 | 		{
21 | 			name: "should properly open connector and wrap it",
22 | 		},
23 | 		{
24 | 			name:             "should fail when calling OpenConnector",
25 | 			openConnectorErr: err,
26 | 			expectErr:        true,
27 | 		},
28 | 	}
29 | 	for _, test := range tests {
30 | 		t.Run(test.name, func(t *testing.T) {
31 | 			d := wrappedDriver{parent: &driverContextMock{err: test.openConnectorErr}}
32 | 			conn, err := d.OpenConnector("some-dsn")
33 | 			if (err != nil) != test.expectErr {
34 | 				t.Fatalf("wrapped OpenConnector returned error %v, want error: %t", err, test.expectErr)
35 | 			}
36 | 			if test.expectErr {
37 | 				return
38 | 			}
39 | 
40 | 			wc, ok := conn.(wrappedConnector)
41 | 			if !ok {
42 | 				t.Fatal("expected wrapped OpenConnector to return wrappedConnector instance")
43 | 			}
44 | 
45 | 			_, ok = wc.parent.(*connMock)
46 | 			if !ok {
47 | 				t.Error("expected wrappedConnector to have connMock as parent")
48 | 			}
49 | 		})
50 | 	}
51 | }
52 | 
53 | func TestConnectorWithDriver(t *testing.T) {
54 | 	d := wrappedDriver{parent: &driverMock{}}
55 | 	conn, err := d.OpenConnector("some-dsn")
56 | 	if err != nil {
57 | 		t.Fatalf("unexpected error from wrapped OpenConnector impl: %+v\n", err)
58 | 	}
59 | 
60 | 	wc, ok := conn.(wrappedConnector)
61 | 	if !ok {
62 | 		t.Fatal("expected 
wrapped OpenConnector to return wrappedConnector instance")
63 | 	}
64 | 
65 | 	_, ok = wc.parent.(dsnConnector)
66 | 	if !ok {
67 | 		t.Error("expected wrappedConnector to have dsnConnector as parent")
68 | 	}
69 | }
70 | 
71 | type driverMock struct{}
72 | 
73 | func (d *driverMock) Open(name string) (driver.Conn, error) {
74 | 	panic("not implemented")
75 | }
76 | 
77 | type driverContextMock struct {
78 | 	err error
79 | }
80 | 
81 | func (d *driverContextMock) Open(name string) (driver.Conn, error) {
82 | 	panic("not implemented")
83 | }
84 | 
85 | func (d *driverContextMock) OpenConnector(name string) (driver.Connector, error) {
86 | 	return &connMock{}, d.err
87 | }
88 | 
89 | type connMock struct{}
90 | 
91 | func (c *connMock) Connect(context.Context) (driver.Conn, error) {
92 | 	panic("not implemented")
93 | }
94 | 
95 | func (c *connMock) Driver() driver.Driver {
96 | 	panic("not implemented")
97 | }
98 | 
--------------------------------------------------------------------------------
/contributors:
--------------------------------------------------------------------------------
1 | Luna Duclos
2 | Dominik Honnef
3 | Daniel Cormier
4 | Alan Shreve
5 | 
--------------------------------------------------------------------------------
/driver.go:
--------------------------------------------------------------------------------
1 | package sqlmw
2 | 
3 | import "database/sql/driver"
4 | 
5 | // wrappedDriver wraps a driver.Driver so that the connections it opens route their calls
6 | // through an Interceptor.
7 | type wrappedDriver struct {
8 | 	intr   Interceptor
9 | 	parent driver.Driver
10 | }
11 | 
12 | // Compile time validation that our types implement the expected interfaces
13 | var (
14 | 	_ driver.Driver = wrappedDriver{}
15 | )
16 | 
17 | // Driver returns a driver.Driver that wraps parent so that the calls made through it are
18 | // intercepted by intr. The returned driver still has to be registered with sql.Register
19 | // before it can be opened with sql.Open.
20 | //
21 | // Important note: the legacy context-free methods (Prepare, Begin, Exec and Query) call the
22 | // wrapped driver directly and bypass the corresponding Interceptor hooks. Please be sure to use
23 | // the ___Context() and BeginTx() function calls added in Go 1.8 instead of the older calls
24 | // which do not accept a context.
25 | func Driver(parent driver.Driver, intr Interceptor) driver.Driver {
26 | 	return wrappedDriver{parent: parent, intr: intr}
27 | }
28 | 
29 | // Open opens a connection on the parent driver and wraps it so that its calls are routed
30 | // through the Interceptor. It implements database/sql/driver.Driver for wrappedDriver.
31 | func (d wrappedDriver) Open(name string) (driver.Conn, error) {
32 | 	conn, err := d.parent.Open(name)
33 | 	if err != nil {
34 | 		return nil, err
35 | 	}
36 | 
37 | 	return wrappedConn{intr: d.intr, parent: conn}, nil
38 | }
39 | 
--------------------------------------------------------------------------------
/driver_go110.go:
--------------------------------------------------------------------------------
1 | // +build go1.10
2 | 
3 | package sqlmw
4 | 
5 | import "database/sql/driver"
6 | 
7 | var _ driver.DriverContext = wrappedDriver{}
8 | 
9 | // OpenConnector implements driver.DriverContext. When the parent driver implements
10 | // driver.DriverContext its connector is wrapped; otherwise a dsnConnector that calls the
11 | // parent's Open with name is wrapped. Either way ConnectorConnect is called on Connect.
12 | func (d wrappedDriver) OpenConnector(name string) (driver.Connector, error) {
13 | 	dc, ok := d.parent.(driver.DriverContext)
14 | 	if !ok {
15 | 		return wrappedConnector{
16 | 			parent:    dsnConnector{dsn: name, driver: d.parent},
17 | 			driverRef: &d,
18 | 		}, nil
19 | 	}
20 | 	connector, err := dc.OpenConnector(name)
21 | 	if err != nil {
22 | 		return nil, err
23 | 	}
24 | 
25 | 	return wrappedConnector{parent: connector, driverRef: &d}, nil
26 | }
27 | 
--------------------------------------------------------------------------------
/fakedb_test.go:
--------------------------------------------------------------------------------
1 | package sqlmw
2 | 
3 | import (
4 | 	"context"
5 | 	"database/sql/driver"
6 | 	"fmt"
7 | 	"io"
8 | 
"reflect" 9 | ) 10 | 11 | type fakeDriver struct { 12 | conn driver.Conn 13 | } 14 | 15 | func (d *fakeDriver) Open(_ string) (driver.Conn, error) { 16 | return d.conn, nil 17 | } 18 | 19 | type fakeTx struct{} 20 | 21 | func (f fakeTx) Commit() error { return nil } 22 | 23 | func (f fakeTx) Rollback() error { return nil } 24 | 25 | type fakeStmt struct { 26 | rows driver.Rows 27 | called bool // nolint:structcheck // ignore unused warning, it is accessed via reflection 28 | } 29 | 30 | type fakeStmtWithCheckNamedValue struct { 31 | fakeStmt 32 | } 33 | 34 | type fakeStmtWithoutCheckNamedValue struct { 35 | fakeStmt 36 | } 37 | 38 | func (s fakeStmt) Close() error { 39 | return nil 40 | } 41 | 42 | func (s fakeStmt) NumInput() int { 43 | return 1 44 | } 45 | 46 | func (s fakeStmt) Exec(_ []driver.Value) (driver.Result, error) { 47 | return nil, nil 48 | } 49 | 50 | func (s fakeStmt) Query(_ []driver.Value) (driver.Rows, error) { 51 | return s.rows, nil 52 | } 53 | 54 | func (s fakeStmt) QueryContext(_ context.Context, _ []driver.NamedValue) (driver.Rows, error) { 55 | return s.rows, nil 56 | } 57 | 58 | func (s *fakeStmtWithCheckNamedValue) CheckNamedValue(_ *driver.NamedValue) (err error) { 59 | s.called = true 60 | return 61 | } 62 | 63 | type fakeRows struct { 64 | con *fakeConn 65 | vals [][]driver.Value 66 | closeCalled bool // nolint:structcheck,unused // ignore unused warning, it is accessed via reflection 67 | nextCalled bool // nolint:structcheck,unused // ignore unused warning, it is accessed via reflection 68 | 69 | //These are here so that we can check things have not been called 70 | hasNextResultSetCalled bool // nolint:structcheck,unused // ignore unused warning, it is accessed via reflection 71 | nextResultSetCalled bool // nolint:structcheck,unused // ignore unused warning, it is accessed via reflection 72 | columnTypeDatabaseNameCalled bool // nolint:structcheck,unused // ignore unused warning, it is accessed via reflection 73 | 
columnTypeLengthCalled bool // nolint:structcheck,unused // ignore unused warning, it is accessed via reflection 74 | columnTypePrecisionScaleCalled bool // nolint:structcheck,unused // ignore unused warning, it is accessed via reflection 75 | columnTypeNullable bool // nolint:structcheck,unused // ignore unused warning, it is accessed via reflection 76 | columnTypeScanTypeCalled bool // nolint:structcheck,unused // ignore unused warning, it is accessed via reflection 77 | } 78 | 79 | func (r *fakeRows) Close() error { 80 | r.con.rowsCloseCalled = true 81 | r.closeCalled = true 82 | return nil 83 | } 84 | 85 | func (r *fakeRows) Columns() []string { 86 | if len(r.vals) == 0 { 87 | return nil 88 | } 89 | 90 | var cols []string 91 | for i := range r.vals[0] { 92 | cols = append(cols, fmt.Sprintf("col%d", i)) 93 | } 94 | return cols 95 | } 96 | 97 | func (r *fakeRows) Next(dest []driver.Value) error { 98 | r.nextCalled = true 99 | if len(r.vals) == 0 { 100 | return io.EOF 101 | } 102 | copy(dest, r.vals[0]) 103 | r.vals = r.vals[1:] 104 | return nil 105 | } 106 | 107 | type fakeWithRowsNextResultSet struct { 108 | r *fakeRows 109 | } 110 | 111 | func (f *fakeWithRowsNextResultSet) HasNextResultSet() bool { 112 | f.r.hasNextResultSetCalled = true 113 | return false 114 | } 115 | 116 | func (f *fakeWithRowsNextResultSet) NextResultSet() error { 117 | f.r.nextResultSetCalled = true 118 | return nil 119 | } 120 | 121 | type fakeWithColumnTypeDatabaseName struct { 122 | r *fakeRows 123 | names []string 124 | } 125 | 126 | func (f *fakeWithColumnTypeDatabaseName) ColumnTypeDatabaseTypeName(index int) string { 127 | f.r.columnTypeDatabaseNameCalled = true 128 | return f.names[index] 129 | } 130 | 131 | type fakeWithColumnTypeLength struct { 132 | r *fakeRows 133 | lengths []int64 134 | bools []bool 135 | } 136 | 137 | func (f *fakeWithColumnTypeLength) ColumnTypeLength(index int) (length int64, ok bool) { 138 | f.r.columnTypeLengthCalled = true 139 | return f.lengths[index], 
f.bools[index] 140 | } 141 | 142 | type fakeWithColumnTypePrecisionScale struct { 143 | r *fakeRows 144 | precisions, scales []int64 145 | bools []bool 146 | } 147 | 148 | func (f *fakeWithColumnTypePrecisionScale) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { 149 | f.r.columnTypePrecisionScaleCalled = true 150 | return f.precisions[index], f.scales[index], f.bools[index] 151 | } 152 | 153 | type fakeWithColumnTypeNullable struct { 154 | r *fakeRows 155 | nullables []bool 156 | oks []bool 157 | } 158 | 159 | func (f *fakeWithColumnTypeNullable) ColumnTypeNullable(index int) (nullable, ok bool) { 160 | f.r.columnTypeNullable = true 161 | return f.nullables[index], f.oks[index] 162 | } 163 | 164 | type fakeWithColumnTypeScanType struct { 165 | r *fakeRows 166 | scanTypes []reflect.Type 167 | } 168 | 169 | func (f *fakeWithColumnTypeScanType) ColumnTypeScanType(index int) reflect.Type { 170 | f.r.columnTypeScanTypeCalled = true 171 | return f.scanTypes[index] 172 | } 173 | 174 | type fakeRowsLikeMysql struct { 175 | fakeRows 176 | fakeWithRowsNextResultSet 177 | fakeWithColumnTypeDatabaseName 178 | fakeWithColumnTypePrecisionScale 179 | fakeWithColumnTypeNullable 180 | fakeWithColumnTypeScanType 181 | } 182 | 183 | // The set of interfaces support by pgx and sqlite3 184 | type fakeRowsLikePgx struct { 185 | fakeRows 186 | fakeWithColumnTypeDatabaseName 187 | fakeWithColumnTypeLength 188 | fakeWithColumnTypePrecisionScale 189 | fakeWithColumnTypeNullable 190 | fakeWithColumnTypeScanType 191 | } 192 | 193 | type fakeConn struct { 194 | called bool // nolint:structcheck // ignore unused warning, it is accessed via reflection 195 | rowsCloseCalled bool 196 | stmt driver.Stmt 197 | tx driver.Tx 198 | } 199 | 200 | type fakeConnWithCheckNamedValue struct { 201 | fakeConn 202 | } 203 | 204 | type fakeConnWithoutCheckNamedValue struct { 205 | fakeConn 206 | } 207 | 208 | func (c *fakeConn) Prepare(_ string) (driver.Stmt, error) { 209 | return nil, 
nil 210 | } 211 | 212 | func (c *fakeConn) PrepareContext(_ context.Context, _ string) (driver.Stmt, error) { 213 | return c.stmt, nil 214 | } 215 | 216 | func (c *fakeConn) ExecContext(_ context.Context, _ string, _ []driver.NamedValue) (driver.Result, error) { 217 | return nil, nil 218 | } 219 | 220 | func (c *fakeConn) Close() error { return nil } 221 | 222 | func (c *fakeConn) Begin() (driver.Tx, error) { return c.tx, nil } 223 | 224 | func (c *fakeConn) QueryContext(_ context.Context, _ string, nvs []driver.NamedValue) (driver.Rows, error) { 225 | if c.stmt == nil { 226 | return &fakeRows{con: c}, nil 227 | } 228 | 229 | var args []driver.Value 230 | for _, nv := range nvs { 231 | args = append(args, nv.Value) 232 | } 233 | 234 | return c.stmt.Query(args) 235 | } 236 | 237 | func (c *fakeConnWithCheckNamedValue) CheckNamedValue(_ *driver.NamedValue) (err error) { 238 | c.called = true 239 | return 240 | } 241 | 242 | type fakeInterceptor struct { 243 | NullInterceptor 244 | } 245 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/ngrok/sqlmw 2 | 3 | go 1.13 4 | -------------------------------------------------------------------------------- /helpers.go: -------------------------------------------------------------------------------- 1 | package sqlmw 2 | 3 | import ( 4 | "database/sql/driver" 5 | "errors" 6 | ) 7 | 8 | // namedValueToValue is a helper function copied from the database/sql package 9 | func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { 10 | dargs := make([]driver.Value, len(named)) 11 | for n, param := range named { 12 | if len(param.Name) > 0 { 13 | return nil, errors.New("sql: driver does not support the use of Named Parameters") 14 | } 15 | dargs[n] = param.Value 16 | } 17 | return dargs, nil 18 | } 19 | 
-------------------------------------------------------------------------------- /interceptor.go: -------------------------------------------------------------------------------- 1 | package sqlmw 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | ) 7 | // Interceptor is the set of hooks sqlmw calls in place of the wrapped driver's own methods. Each hook receives the parent driver object and, as NullInterceptor shows, is expected to call through to it; embed NullInterceptor to override only the hooks you need. 8 | type Interceptor interface { 9 | // Connection interceptors 10 | ConnBeginTx(context.Context, driver.ConnBeginTx, driver.TxOptions) (context.Context, driver.Tx, error) 11 | ConnPrepareContext(context.Context, driver.ConnPrepareContext, string) (context.Context, driver.Stmt, error) 12 | ConnPing(context.Context, driver.Pinger) error 13 | ConnExecContext(context.Context, driver.ExecerContext, string, []driver.NamedValue) (driver.Result, error) 14 | ConnQueryContext(context.Context, driver.QueryerContext, string, []driver.NamedValue) (context.Context, driver.Rows, error) 15 | 16 | // Connector interceptors 17 | ConnectorConnect(context.Context, driver.Connector) (driver.Conn, error) 18 | 19 | // Result interceptors (these receive no context) 20 | ResultLastInsertId(driver.Result) (int64, error) 21 | ResultRowsAffected(driver.Result) (int64, error) 22 | 23 | // Rows interceptors 24 | RowsNext(context.Context, driver.Rows, []driver.Value) error 25 | RowsClose(context.Context, driver.Rows) error 26 | 27 | // Stmt interceptors 28 | StmtExecContext(context.Context, driver.StmtExecContext, string, []driver.NamedValue) (driver.Result, error) 29 | StmtQueryContext(context.Context, driver.StmtQueryContext, string, []driver.NamedValue) (context.Context, driver.Rows, error) 30 | StmtClose(context.Context, driver.Stmt) error 31 | 32 | // Tx interceptors 33 | TxCommit(context.Context, driver.Tx) error 34 | TxRollback(context.Context, driver.Tx) error 35 | } 36 | // Compile-time check that NullInterceptor implements Interceptor. 37 | var _ Interceptor = NullInterceptor{} 38 | 39 | // NullInterceptor is a complete passthrough interceptor that implements every method of the Interceptor 40 | // interface and performs no additional logic.
Users should embed it in their own interceptor so that they 41 | // only need to define the specific functions they are interested in intercepting. 42 | type NullInterceptor struct{} 43 | // ConnBeginTx calls conn.BeginTx and returns ctx unchanged. 44 | func (NullInterceptor) ConnBeginTx(ctx context.Context, conn driver.ConnBeginTx, txOpts driver.TxOptions) (context.Context, driver.Tx, error) { 45 | t, err := conn.BeginTx(ctx, txOpts) 46 | return ctx, t, err 47 | } 48 | // ConnPrepareContext calls conn.PrepareContext and returns ctx unchanged. 49 | func (NullInterceptor) ConnPrepareContext(ctx context.Context, conn driver.ConnPrepareContext, query string) (context.Context, driver.Stmt, error) { 50 | s, err := conn.PrepareContext(ctx, query) 51 | return ctx, s, err 52 | } 53 | // ConnPing calls conn.Ping. 54 | func (NullInterceptor) ConnPing(ctx context.Context, conn driver.Pinger) error { 55 | return conn.Ping(ctx) 56 | } 57 | // ConnExecContext calls conn.ExecContext. 58 | func (NullInterceptor) ConnExecContext(ctx context.Context, conn driver.ExecerContext, query string, args []driver.NamedValue) (driver.Result, error) { 59 | return conn.ExecContext(ctx, query, args) 60 | } 61 | // ConnQueryContext calls conn.QueryContext and returns ctx unchanged. 62 | func (NullInterceptor) ConnQueryContext(ctx context.Context, conn driver.QueryerContext, query string, args []driver.NamedValue) (context.Context, driver.Rows, error) { 63 | r, err := conn.QueryContext(ctx, query, args) 64 | return ctx, r, err 65 | } 66 | // ConnectorConnect calls connect.Connect. 67 | func (NullInterceptor) ConnectorConnect(ctx context.Context, connect driver.Connector) (driver.Conn, error) { 68 | return connect.Connect(ctx) 69 | } 70 | // ResultLastInsertId calls res.LastInsertId. 71 | func (NullInterceptor) ResultLastInsertId(res driver.Result) (int64, error) { 72 | return res.LastInsertId() 73 | } 74 | // ResultRowsAffected calls res.RowsAffected. 75 | func (NullInterceptor) ResultRowsAffected(res driver.Result) (int64, error) { 76 | return res.RowsAffected() 77 | } 78 | // RowsNext calls rows.Next; ctx is unused. 79 | func (NullInterceptor) RowsNext(ctx context.Context, rows driver.Rows, dest []driver.Value) error { 80 | return rows.Next(dest) 81 | } 82 | // RowsClose calls rows.Close; ctx is unused. 83 | func (NullInterceptor) RowsClose(ctx context.Context, rows driver.Rows) error { 84 | return rows.Close() 85 | } 86 | // StmtExecContext calls stmt.ExecContext; the string argument is ignored. 87 | func (NullInterceptor) StmtExecContext(ctx context.Context,
stmt driver.StmtExecContext, _ string, args []driver.NamedValue) (driver.Result, error) { 88 | return stmt.ExecContext(ctx, args) 89 | } 90 | // StmtQueryContext calls stmt.QueryContext and returns ctx unchanged; the string argument is ignored. 91 | func (NullInterceptor) StmtQueryContext(ctx context.Context, stmt driver.StmtQueryContext, _ string, args []driver.NamedValue) (context.Context, driver.Rows, error) { 92 | r, err := stmt.QueryContext(ctx, args) 93 | return ctx, r, err 94 | } 95 | // StmtClose calls stmt.Close; ctx is unused. 96 | func (NullInterceptor) StmtClose(ctx context.Context, stmt driver.Stmt) error { 97 | return stmt.Close() 98 | } 99 | // TxCommit calls tx.Commit; ctx is unused. 100 | func (NullInterceptor) TxCommit(ctx context.Context, tx driver.Tx) error { 101 | return tx.Commit() 102 | } 103 | // TxRollback calls tx.Rollback; ctx is unused. 104 | func (NullInterceptor) TxRollback(ctx context.Context, tx driver.Tx) error { 105 | return tx.Rollback() 106 | } 107 | -------------------------------------------------------------------------------- /result.go: -------------------------------------------------------------------------------- 1 | package sqlmw 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | ) 7 | // wrappedResult routes LastInsertId and RowsAffected through the Interceptor. 8 | type wrappedResult struct { 9 | intr Interceptor 10 | ctx context.Context // NOTE(review): not read by either method below, since the Result hooks take no context 11 | parent driver.Result 12 | } 13 | 14 | func (r wrappedResult) LastInsertId() (id int64, err error) { 15 | return r.intr.ResultLastInsertId(r.parent) 16 | } 17 | 18 | func (r wrappedResult) RowsAffected() (num int64, err error) { 19 | return r.intr.ResultRowsAffected(r.parent) 20 | } 21 | -------------------------------------------------------------------------------- /rows.go: -------------------------------------------------------------------------------- 1 | package sqlmw 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | "reflect" 7 | ) 8 | 9 | //go:generate go run ./tools/rows_picker_gen.go -o rows_picker.go 10 | 11 | // RowsUnwrapper must be used by any middleware that provides its own wrapping 12 | // for driver.Rows. Unwrap should return the original driver.Rows the 13 | // middleware received.
You may wish to wrap the driver.Rows returned by the 14 | // Query methods if you want to pass extra information from the Query call to 15 | // the subsequent RowsNext and RowsClose calls. 16 | // 17 | // sqlmw needs to retrieve the original driver.Rows in order to determine the 18 | // original set of optional methods supported by the driver.Rows of the 19 | // database driver in use by the caller. 20 | // 21 | // If a middleware returns a custom driver.Rows, the custom implementation 22 | // must support all the driver.Rows optional interfaces that are supported 23 | // by the drivers that will be used with it. To support any arbitrary driver 24 | // all the optional methods must be supported. 25 | type RowsUnwrapper interface { 26 | Unwrap() driver.Rows 27 | } 28 | // wrappedRows forwards Columns to parent and routes Close and Next through the Interceptor with the stored ctx. 29 | type wrappedRows struct { 30 | intr Interceptor 31 | ctx context.Context 32 | parent driver.Rows 33 | } 34 | 35 | func (r wrappedRows) Columns() []string { 36 | return r.parent.Columns() 37 | } 38 | 39 | func (r wrappedRows) Close() error { 40 | return r.intr.RowsClose(r.ctx, r.parent) 41 | } 42 | 43 | func (r wrappedRows) Next(dest []driver.Value) (err error) { 44 | return r.intr.RowsNext(r.ctx, r.parent, dest) 45 | } 46 | // Each wrappedRowsXxx type below forwards one optional driver.Rows interface to rows through an unchecked type assertion, so it must only be used when rows implements that interface; otherwise the call panics. 47 | type wrappedRowsNextResultSet struct { 48 | rows driver.Rows 49 | } 50 | 51 | func (r wrappedRowsNextResultSet) HasNextResultSet() bool { 52 | return r.rows.(driver.RowsNextResultSet).HasNextResultSet() 53 | } 54 | 55 | func (r wrappedRowsNextResultSet) NextResultSet() error { 56 | return r.rows.(driver.RowsNextResultSet).NextResultSet() 57 | } 58 | 59 | type wrappedRowsColumnTypeDatabaseTypeName struct { 60 | rows driver.Rows 61 | } 62 | 63 | func (r wrappedRowsColumnTypeDatabaseTypeName) ColumnTypeDatabaseTypeName(index int) string { 64 | return r.rows.(driver.RowsColumnTypeDatabaseTypeName).ColumnTypeDatabaseTypeName(index) 65 | } 66 | 67 | type wrappedRowsColumnTypeLength struct { 68 | rows driver.Rows 69 | } 70 | 71 | func (r wrappedRowsColumnTypeLength)
ColumnTypeLength(index int) (length int64, ok bool) { 72 | return r.rows.(driver.RowsColumnTypeLength).ColumnTypeLength(index) 73 | } 74 | 75 | type wrappedRowsColumnTypeNullable struct { 76 | rows driver.Rows 77 | } 78 | 79 | func (r wrappedRowsColumnTypeNullable) ColumnTypeNullable(index int) (nullable, ok bool) { 80 | return r.rows.(driver.RowsColumnTypeNullable).ColumnTypeNullable(index) 81 | } 82 | 83 | type wrappedRowsColumnTypePrecisionScale struct { 84 | rows driver.Rows 85 | } 86 | 87 | func (r wrappedRowsColumnTypePrecisionScale) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { 88 | return r.rows.(driver.RowsColumnTypePrecisionScale).ColumnTypePrecisionScale(index) 89 | } 90 | 91 | type wrappedRowsColumnTypeScanType struct { 92 | rows driver.Rows 93 | } 94 | 95 | func (r wrappedRowsColumnTypeScanType) ColumnTypeScanType(index int) reflect.Type { 96 | return r.rows.(driver.RowsColumnTypeScanType).ColumnTypeScanType(index) 97 | } 98 | -------------------------------------------------------------------------------- /rows_picker.go: -------------------------------------------------------------------------------- 1 | // Code generated using tool/rows_picker_gen.go DO NOT EDIT.
2 | // Date: Dec 20 09:54:15 3 | 4 | package sqlmw 5 | 6 | import ( 7 | "context" 8 | "database/sql/driver" 9 | ) 10 | 11 | const ( 12 | rowsNextResultSet = 1 << iota 13 | rowsColumnTypeDatabaseTypeName 14 | rowsColumnTypeLength 15 | rowsColumnTypeNullable 16 | rowsColumnTypePrecisionScale 17 | rowsColumnTypeScanType 18 | ) 19 | 20 | var pickRows = make([]func(*wrappedRows) driver.Rows, 64) 21 | 22 | func init() { 23 | 24 | // plain driver.Rows 25 | pickRows[0] = func(r *wrappedRows) driver.Rows { 26 | return r 27 | } 28 | 29 | // plain driver.Rows 30 | pickRows[1] = func(r *wrappedRows) driver.Rows { 31 | return struct { 32 | *wrappedRows 33 | wrappedRowsNextResultSet 34 | }{ 35 | r, 36 | wrappedRowsNextResultSet{r.parent}, 37 | } 38 | } 39 | 40 | // plain driver.Rows 41 | pickRows[2] = func(r *wrappedRows) driver.Rows { 42 | return struct { 43 | *wrappedRows 44 | wrappedRowsColumnTypeDatabaseTypeName 45 | }{ 46 | r, 47 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 48 | } 49 | } 50 | 51 | // plain driver.Rows 52 | pickRows[3] = func(r *wrappedRows) driver.Rows { 53 | return struct { 54 | *wrappedRows 55 | wrappedRowsNextResultSet 56 | wrappedRowsColumnTypeDatabaseTypeName 57 | }{ 58 | r, 59 | wrappedRowsNextResultSet{r.parent}, 60 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 61 | } 62 | } 63 | 64 | // plain driver.Rows 65 | pickRows[4] = func(r *wrappedRows) driver.Rows { 66 | return struct { 67 | *wrappedRows 68 | wrappedRowsColumnTypeLength 69 | }{ 70 | r, 71 | wrappedRowsColumnTypeLength{r.parent}, 72 | } 73 | } 74 | 75 | // plain driver.Rows 76 | pickRows[5] = func(r *wrappedRows) driver.Rows { 77 | return struct { 78 | *wrappedRows 79 | wrappedRowsNextResultSet 80 | wrappedRowsColumnTypeLength 81 | }{ 82 | r, 83 | wrappedRowsNextResultSet{r.parent}, 84 | wrappedRowsColumnTypeLength{r.parent}, 85 | } 86 | } 87 | 88 | // plain driver.Rows 89 | pickRows[6] = func(r *wrappedRows) driver.Rows { 90 | return struct { 91 | *wrappedRows 92 | 
wrappedRowsColumnTypeDatabaseTypeName 93 | wrappedRowsColumnTypeLength 94 | }{ 95 | r, 96 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 97 | wrappedRowsColumnTypeLength{r.parent}, 98 | } 99 | } 100 | 101 | // plain driver.Rows 102 | pickRows[7] = func(r *wrappedRows) driver.Rows { 103 | return struct { 104 | *wrappedRows 105 | wrappedRowsNextResultSet 106 | wrappedRowsColumnTypeDatabaseTypeName 107 | wrappedRowsColumnTypeLength 108 | }{ 109 | r, 110 | wrappedRowsNextResultSet{r.parent}, 111 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 112 | wrappedRowsColumnTypeLength{r.parent}, 113 | } 114 | } 115 | 116 | // plain driver.Rows 117 | pickRows[8] = func(r *wrappedRows) driver.Rows { 118 | return struct { 119 | *wrappedRows 120 | wrappedRowsColumnTypeNullable 121 | }{ 122 | r, 123 | wrappedRowsColumnTypeNullable{r.parent}, 124 | } 125 | } 126 | 127 | // plain driver.Rows 128 | pickRows[9] = func(r *wrappedRows) driver.Rows { 129 | return struct { 130 | *wrappedRows 131 | wrappedRowsNextResultSet 132 | wrappedRowsColumnTypeNullable 133 | }{ 134 | r, 135 | wrappedRowsNextResultSet{r.parent}, 136 | wrappedRowsColumnTypeNullable{r.parent}, 137 | } 138 | } 139 | 140 | // plain driver.Rows 141 | pickRows[10] = func(r *wrappedRows) driver.Rows { 142 | return struct { 143 | *wrappedRows 144 | wrappedRowsColumnTypeDatabaseTypeName 145 | wrappedRowsColumnTypeNullable 146 | }{ 147 | r, 148 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 149 | wrappedRowsColumnTypeNullable{r.parent}, 150 | } 151 | } 152 | 153 | // plain driver.Rows 154 | pickRows[11] = func(r *wrappedRows) driver.Rows { 155 | return struct { 156 | *wrappedRows 157 | wrappedRowsNextResultSet 158 | wrappedRowsColumnTypeDatabaseTypeName 159 | wrappedRowsColumnTypeNullable 160 | }{ 161 | r, 162 | wrappedRowsNextResultSet{r.parent}, 163 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 164 | wrappedRowsColumnTypeNullable{r.parent}, 165 | } 166 | } 167 | 168 | // plain driver.Rows 169 | pickRows[12] = 
func(r *wrappedRows) driver.Rows { 170 | return struct { 171 | *wrappedRows 172 | wrappedRowsColumnTypeLength 173 | wrappedRowsColumnTypeNullable 174 | }{ 175 | r, 176 | wrappedRowsColumnTypeLength{r.parent}, 177 | wrappedRowsColumnTypeNullable{r.parent}, 178 | } 179 | } 180 | 181 | // plain driver.Rows 182 | pickRows[13] = func(r *wrappedRows) driver.Rows { 183 | return struct { 184 | *wrappedRows 185 | wrappedRowsNextResultSet 186 | wrappedRowsColumnTypeLength 187 | wrappedRowsColumnTypeNullable 188 | }{ 189 | r, 190 | wrappedRowsNextResultSet{r.parent}, 191 | wrappedRowsColumnTypeLength{r.parent}, 192 | wrappedRowsColumnTypeNullable{r.parent}, 193 | } 194 | } 195 | 196 | // plain driver.Rows 197 | pickRows[14] = func(r *wrappedRows) driver.Rows { 198 | return struct { 199 | *wrappedRows 200 | wrappedRowsColumnTypeDatabaseTypeName 201 | wrappedRowsColumnTypeLength 202 | wrappedRowsColumnTypeNullable 203 | }{ 204 | r, 205 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 206 | wrappedRowsColumnTypeLength{r.parent}, 207 | wrappedRowsColumnTypeNullable{r.parent}, 208 | } 209 | } 210 | 211 | // plain driver.Rows 212 | pickRows[15] = func(r *wrappedRows) driver.Rows { 213 | return struct { 214 | *wrappedRows 215 | wrappedRowsNextResultSet 216 | wrappedRowsColumnTypeDatabaseTypeName 217 | wrappedRowsColumnTypeLength 218 | wrappedRowsColumnTypeNullable 219 | }{ 220 | r, 221 | wrappedRowsNextResultSet{r.parent}, 222 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 223 | wrappedRowsColumnTypeLength{r.parent}, 224 | wrappedRowsColumnTypeNullable{r.parent}, 225 | } 226 | } 227 | 228 | // plain driver.Rows 229 | pickRows[16] = func(r *wrappedRows) driver.Rows { 230 | return struct { 231 | *wrappedRows 232 | wrappedRowsColumnTypePrecisionScale 233 | }{ 234 | r, 235 | wrappedRowsColumnTypePrecisionScale{r.parent}, 236 | } 237 | } 238 | 239 | // plain driver.Rows 240 | pickRows[17] = func(r *wrappedRows) driver.Rows { 241 | return struct { 242 | *wrappedRows 243 | 
wrappedRowsNextResultSet 244 | wrappedRowsColumnTypePrecisionScale 245 | }{ 246 | r, 247 | wrappedRowsNextResultSet{r.parent}, 248 | wrappedRowsColumnTypePrecisionScale{r.parent}, 249 | } 250 | } 251 | 252 | // plain driver.Rows 253 | pickRows[18] = func(r *wrappedRows) driver.Rows { 254 | return struct { 255 | *wrappedRows 256 | wrappedRowsColumnTypeDatabaseTypeName 257 | wrappedRowsColumnTypePrecisionScale 258 | }{ 259 | r, 260 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 261 | wrappedRowsColumnTypePrecisionScale{r.parent}, 262 | } 263 | } 264 | 265 | // plain driver.Rows 266 | pickRows[19] = func(r *wrappedRows) driver.Rows { 267 | return struct { 268 | *wrappedRows 269 | wrappedRowsNextResultSet 270 | wrappedRowsColumnTypeDatabaseTypeName 271 | wrappedRowsColumnTypePrecisionScale 272 | }{ 273 | r, 274 | wrappedRowsNextResultSet{r.parent}, 275 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 276 | wrappedRowsColumnTypePrecisionScale{r.parent}, 277 | } 278 | } 279 | 280 | // plain driver.Rows 281 | pickRows[20] = func(r *wrappedRows) driver.Rows { 282 | return struct { 283 | *wrappedRows 284 | wrappedRowsColumnTypeLength 285 | wrappedRowsColumnTypePrecisionScale 286 | }{ 287 | r, 288 | wrappedRowsColumnTypeLength{r.parent}, 289 | wrappedRowsColumnTypePrecisionScale{r.parent}, 290 | } 291 | } 292 | 293 | // plain driver.Rows 294 | pickRows[21] = func(r *wrappedRows) driver.Rows { 295 | return struct { 296 | *wrappedRows 297 | wrappedRowsNextResultSet 298 | wrappedRowsColumnTypeLength 299 | wrappedRowsColumnTypePrecisionScale 300 | }{ 301 | r, 302 | wrappedRowsNextResultSet{r.parent}, 303 | wrappedRowsColumnTypeLength{r.parent}, 304 | wrappedRowsColumnTypePrecisionScale{r.parent}, 305 | } 306 | } 307 | 308 | // plain driver.Rows 309 | pickRows[22] = func(r *wrappedRows) driver.Rows { 310 | return struct { 311 | *wrappedRows 312 | wrappedRowsColumnTypeDatabaseTypeName 313 | wrappedRowsColumnTypeLength 314 | wrappedRowsColumnTypePrecisionScale 315 | }{ 316 | 
r, 317 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 318 | wrappedRowsColumnTypeLength{r.parent}, 319 | wrappedRowsColumnTypePrecisionScale{r.parent}, 320 | } 321 | } 322 | 323 | // plain driver.Rows 324 | pickRows[23] = func(r *wrappedRows) driver.Rows { 325 | return struct { 326 | *wrappedRows 327 | wrappedRowsNextResultSet 328 | wrappedRowsColumnTypeDatabaseTypeName 329 | wrappedRowsColumnTypeLength 330 | wrappedRowsColumnTypePrecisionScale 331 | }{ 332 | r, 333 | wrappedRowsNextResultSet{r.parent}, 334 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 335 | wrappedRowsColumnTypeLength{r.parent}, 336 | wrappedRowsColumnTypePrecisionScale{r.parent}, 337 | } 338 | } 339 | 340 | // plain driver.Rows 341 | pickRows[24] = func(r *wrappedRows) driver.Rows { 342 | return struct { 343 | *wrappedRows 344 | wrappedRowsColumnTypeNullable 345 | wrappedRowsColumnTypePrecisionScale 346 | }{ 347 | r, 348 | wrappedRowsColumnTypeNullable{r.parent}, 349 | wrappedRowsColumnTypePrecisionScale{r.parent}, 350 | } 351 | } 352 | 353 | // plain driver.Rows 354 | pickRows[25] = func(r *wrappedRows) driver.Rows { 355 | return struct { 356 | *wrappedRows 357 | wrappedRowsNextResultSet 358 | wrappedRowsColumnTypeNullable 359 | wrappedRowsColumnTypePrecisionScale 360 | }{ 361 | r, 362 | wrappedRowsNextResultSet{r.parent}, 363 | wrappedRowsColumnTypeNullable{r.parent}, 364 | wrappedRowsColumnTypePrecisionScale{r.parent}, 365 | } 366 | } 367 | 368 | // plain driver.Rows 369 | pickRows[26] = func(r *wrappedRows) driver.Rows { 370 | return struct { 371 | *wrappedRows 372 | wrappedRowsColumnTypeDatabaseTypeName 373 | wrappedRowsColumnTypeNullable 374 | wrappedRowsColumnTypePrecisionScale 375 | }{ 376 | r, 377 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 378 | wrappedRowsColumnTypeNullable{r.parent}, 379 | wrappedRowsColumnTypePrecisionScale{r.parent}, 380 | } 381 | } 382 | 383 | // plain driver.Rows 384 | pickRows[27] = func(r *wrappedRows) driver.Rows { 385 | return struct { 386 | 
*wrappedRows 387 | wrappedRowsNextResultSet 388 | wrappedRowsColumnTypeDatabaseTypeName 389 | wrappedRowsColumnTypeNullable 390 | wrappedRowsColumnTypePrecisionScale 391 | }{ 392 | r, 393 | wrappedRowsNextResultSet{r.parent}, 394 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 395 | wrappedRowsColumnTypeNullable{r.parent}, 396 | wrappedRowsColumnTypePrecisionScale{r.parent}, 397 | } 398 | } 399 | 400 | // plain driver.Rows 401 | pickRows[28] = func(r *wrappedRows) driver.Rows { 402 | return struct { 403 | *wrappedRows 404 | wrappedRowsColumnTypeLength 405 | wrappedRowsColumnTypeNullable 406 | wrappedRowsColumnTypePrecisionScale 407 | }{ 408 | r, 409 | wrappedRowsColumnTypeLength{r.parent}, 410 | wrappedRowsColumnTypeNullable{r.parent}, 411 | wrappedRowsColumnTypePrecisionScale{r.parent}, 412 | } 413 | } 414 | 415 | // plain driver.Rows 416 | pickRows[29] = func(r *wrappedRows) driver.Rows { 417 | return struct { 418 | *wrappedRows 419 | wrappedRowsNextResultSet 420 | wrappedRowsColumnTypeLength 421 | wrappedRowsColumnTypeNullable 422 | wrappedRowsColumnTypePrecisionScale 423 | }{ 424 | r, 425 | wrappedRowsNextResultSet{r.parent}, 426 | wrappedRowsColumnTypeLength{r.parent}, 427 | wrappedRowsColumnTypeNullable{r.parent}, 428 | wrappedRowsColumnTypePrecisionScale{r.parent}, 429 | } 430 | } 431 | 432 | // plain driver.Rows 433 | pickRows[30] = func(r *wrappedRows) driver.Rows { 434 | return struct { 435 | *wrappedRows 436 | wrappedRowsColumnTypeDatabaseTypeName 437 | wrappedRowsColumnTypeLength 438 | wrappedRowsColumnTypeNullable 439 | wrappedRowsColumnTypePrecisionScale 440 | }{ 441 | r, 442 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 443 | wrappedRowsColumnTypeLength{r.parent}, 444 | wrappedRowsColumnTypeNullable{r.parent}, 445 | wrappedRowsColumnTypePrecisionScale{r.parent}, 446 | } 447 | } 448 | 449 | // plain driver.Rows 450 | pickRows[31] = func(r *wrappedRows) driver.Rows { 451 | return struct { 452 | *wrappedRows 453 | wrappedRowsNextResultSet 454 | 
wrappedRowsColumnTypeDatabaseTypeName 455 | wrappedRowsColumnTypeLength 456 | wrappedRowsColumnTypeNullable 457 | wrappedRowsColumnTypePrecisionScale 458 | }{ 459 | r, 460 | wrappedRowsNextResultSet{r.parent}, 461 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 462 | wrappedRowsColumnTypeLength{r.parent}, 463 | wrappedRowsColumnTypeNullable{r.parent}, 464 | wrappedRowsColumnTypePrecisionScale{r.parent}, 465 | } 466 | } 467 | 468 | // plain driver.Rows 469 | pickRows[32] = func(r *wrappedRows) driver.Rows { 470 | return struct { 471 | *wrappedRows 472 | wrappedRowsColumnTypeScanType 473 | }{ 474 | r, 475 | wrappedRowsColumnTypeScanType{r.parent}, 476 | } 477 | } 478 | 479 | // plain driver.Rows 480 | pickRows[33] = func(r *wrappedRows) driver.Rows { 481 | return struct { 482 | *wrappedRows 483 | wrappedRowsNextResultSet 484 | wrappedRowsColumnTypeScanType 485 | }{ 486 | r, 487 | wrappedRowsNextResultSet{r.parent}, 488 | wrappedRowsColumnTypeScanType{r.parent}, 489 | } 490 | } 491 | 492 | // plain driver.Rows 493 | pickRows[34] = func(r *wrappedRows) driver.Rows { 494 | return struct { 495 | *wrappedRows 496 | wrappedRowsColumnTypeDatabaseTypeName 497 | wrappedRowsColumnTypeScanType 498 | }{ 499 | r, 500 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 501 | wrappedRowsColumnTypeScanType{r.parent}, 502 | } 503 | } 504 | 505 | // plain driver.Rows 506 | pickRows[35] = func(r *wrappedRows) driver.Rows { 507 | return struct { 508 | *wrappedRows 509 | wrappedRowsNextResultSet 510 | wrappedRowsColumnTypeDatabaseTypeName 511 | wrappedRowsColumnTypeScanType 512 | }{ 513 | r, 514 | wrappedRowsNextResultSet{r.parent}, 515 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 516 | wrappedRowsColumnTypeScanType{r.parent}, 517 | } 518 | } 519 | 520 | // plain driver.Rows 521 | pickRows[36] = func(r *wrappedRows) driver.Rows { 522 | return struct { 523 | *wrappedRows 524 | wrappedRowsColumnTypeLength 525 | wrappedRowsColumnTypeScanType 526 | }{ 527 | r, 528 | 
wrappedRowsColumnTypeLength{r.parent}, 529 | wrappedRowsColumnTypeScanType{r.parent}, 530 | } 531 | } 532 | 533 | // plain driver.Rows 534 | pickRows[37] = func(r *wrappedRows) driver.Rows { 535 | return struct { 536 | *wrappedRows 537 | wrappedRowsNextResultSet 538 | wrappedRowsColumnTypeLength 539 | wrappedRowsColumnTypeScanType 540 | }{ 541 | r, 542 | wrappedRowsNextResultSet{r.parent}, 543 | wrappedRowsColumnTypeLength{r.parent}, 544 | wrappedRowsColumnTypeScanType{r.parent}, 545 | } 546 | } 547 | 548 | // plain driver.Rows 549 | pickRows[38] = func(r *wrappedRows) driver.Rows { 550 | return struct { 551 | *wrappedRows 552 | wrappedRowsColumnTypeDatabaseTypeName 553 | wrappedRowsColumnTypeLength 554 | wrappedRowsColumnTypeScanType 555 | }{ 556 | r, 557 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 558 | wrappedRowsColumnTypeLength{r.parent}, 559 | wrappedRowsColumnTypeScanType{r.parent}, 560 | } 561 | } 562 | 563 | // plain driver.Rows 564 | pickRows[39] = func(r *wrappedRows) driver.Rows { 565 | return struct { 566 | *wrappedRows 567 | wrappedRowsNextResultSet 568 | wrappedRowsColumnTypeDatabaseTypeName 569 | wrappedRowsColumnTypeLength 570 | wrappedRowsColumnTypeScanType 571 | }{ 572 | r, 573 | wrappedRowsNextResultSet{r.parent}, 574 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 575 | wrappedRowsColumnTypeLength{r.parent}, 576 | wrappedRowsColumnTypeScanType{r.parent}, 577 | } 578 | } 579 | 580 | // plain driver.Rows 581 | pickRows[40] = func(r *wrappedRows) driver.Rows { 582 | return struct { 583 | *wrappedRows 584 | wrappedRowsColumnTypeNullable 585 | wrappedRowsColumnTypeScanType 586 | }{ 587 | r, 588 | wrappedRowsColumnTypeNullable{r.parent}, 589 | wrappedRowsColumnTypeScanType{r.parent}, 590 | } 591 | } 592 | 593 | // plain driver.Rows 594 | pickRows[41] = func(r *wrappedRows) driver.Rows { 595 | return struct { 596 | *wrappedRows 597 | wrappedRowsNextResultSet 598 | wrappedRowsColumnTypeNullable 599 | wrappedRowsColumnTypeScanType 600 | }{ 601 
| r, 602 | wrappedRowsNextResultSet{r.parent}, 603 | wrappedRowsColumnTypeNullable{r.parent}, 604 | wrappedRowsColumnTypeScanType{r.parent}, 605 | } 606 | } 607 | 608 | // plain driver.Rows 609 | pickRows[42] = func(r *wrappedRows) driver.Rows { 610 | return struct { 611 | *wrappedRows 612 | wrappedRowsColumnTypeDatabaseTypeName 613 | wrappedRowsColumnTypeNullable 614 | wrappedRowsColumnTypeScanType 615 | }{ 616 | r, 617 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 618 | wrappedRowsColumnTypeNullable{r.parent}, 619 | wrappedRowsColumnTypeScanType{r.parent}, 620 | } 621 | } 622 | 623 | // plain driver.Rows 624 | pickRows[43] = func(r *wrappedRows) driver.Rows { 625 | return struct { 626 | *wrappedRows 627 | wrappedRowsNextResultSet 628 | wrappedRowsColumnTypeDatabaseTypeName 629 | wrappedRowsColumnTypeNullable 630 | wrappedRowsColumnTypeScanType 631 | }{ 632 | r, 633 | wrappedRowsNextResultSet{r.parent}, 634 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 635 | wrappedRowsColumnTypeNullable{r.parent}, 636 | wrappedRowsColumnTypeScanType{r.parent}, 637 | } 638 | } 639 | 640 | // plain driver.Rows 641 | pickRows[44] = func(r *wrappedRows) driver.Rows { 642 | return struct { 643 | *wrappedRows 644 | wrappedRowsColumnTypeLength 645 | wrappedRowsColumnTypeNullable 646 | wrappedRowsColumnTypeScanType 647 | }{ 648 | r, 649 | wrappedRowsColumnTypeLength{r.parent}, 650 | wrappedRowsColumnTypeNullable{r.parent}, 651 | wrappedRowsColumnTypeScanType{r.parent}, 652 | } 653 | } 654 | 655 | // plain driver.Rows 656 | pickRows[45] = func(r *wrappedRows) driver.Rows { 657 | return struct { 658 | *wrappedRows 659 | wrappedRowsNextResultSet 660 | wrappedRowsColumnTypeLength 661 | wrappedRowsColumnTypeNullable 662 | wrappedRowsColumnTypeScanType 663 | }{ 664 | r, 665 | wrappedRowsNextResultSet{r.parent}, 666 | wrappedRowsColumnTypeLength{r.parent}, 667 | wrappedRowsColumnTypeNullable{r.parent}, 668 | wrappedRowsColumnTypeScanType{r.parent}, 669 | } 670 | } 671 | 672 | // plain 
driver.Rows 673 | pickRows[46] = func(r *wrappedRows) driver.Rows { 674 | return struct { 675 | *wrappedRows 676 | wrappedRowsColumnTypeDatabaseTypeName 677 | wrappedRowsColumnTypeLength 678 | wrappedRowsColumnTypeNullable 679 | wrappedRowsColumnTypeScanType 680 | }{ 681 | r, 682 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 683 | wrappedRowsColumnTypeLength{r.parent}, 684 | wrappedRowsColumnTypeNullable{r.parent}, 685 | wrappedRowsColumnTypeScanType{r.parent}, 686 | } 687 | } 688 | 689 | // plain driver.Rows 690 | pickRows[47] = func(r *wrappedRows) driver.Rows { 691 | return struct { 692 | *wrappedRows 693 | wrappedRowsNextResultSet 694 | wrappedRowsColumnTypeDatabaseTypeName 695 | wrappedRowsColumnTypeLength 696 | wrappedRowsColumnTypeNullable 697 | wrappedRowsColumnTypeScanType 698 | }{ 699 | r, 700 | wrappedRowsNextResultSet{r.parent}, 701 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 702 | wrappedRowsColumnTypeLength{r.parent}, 703 | wrappedRowsColumnTypeNullable{r.parent}, 704 | wrappedRowsColumnTypeScanType{r.parent}, 705 | } 706 | } 707 | 708 | // plain driver.Rows 709 | pickRows[48] = func(r *wrappedRows) driver.Rows { 710 | return struct { 711 | *wrappedRows 712 | wrappedRowsColumnTypePrecisionScale 713 | wrappedRowsColumnTypeScanType 714 | }{ 715 | r, 716 | wrappedRowsColumnTypePrecisionScale{r.parent}, 717 | wrappedRowsColumnTypeScanType{r.parent}, 718 | } 719 | } 720 | 721 | // plain driver.Rows 722 | pickRows[49] = func(r *wrappedRows) driver.Rows { 723 | return struct { 724 | *wrappedRows 725 | wrappedRowsNextResultSet 726 | wrappedRowsColumnTypePrecisionScale 727 | wrappedRowsColumnTypeScanType 728 | }{ 729 | r, 730 | wrappedRowsNextResultSet{r.parent}, 731 | wrappedRowsColumnTypePrecisionScale{r.parent}, 732 | wrappedRowsColumnTypeScanType{r.parent}, 733 | } 734 | } 735 | 736 | // plain driver.Rows 737 | pickRows[50] = func(r *wrappedRows) driver.Rows { 738 | return struct { 739 | *wrappedRows 740 | wrappedRowsColumnTypeDatabaseTypeName 
741 | wrappedRowsColumnTypePrecisionScale 742 | wrappedRowsColumnTypeScanType 743 | }{ 744 | r, 745 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 746 | wrappedRowsColumnTypePrecisionScale{r.parent}, 747 | wrappedRowsColumnTypeScanType{r.parent}, 748 | } 749 | } 750 | 751 | // plain driver.Rows 752 | pickRows[51] = func(r *wrappedRows) driver.Rows { 753 | return struct { 754 | *wrappedRows 755 | wrappedRowsNextResultSet 756 | wrappedRowsColumnTypeDatabaseTypeName 757 | wrappedRowsColumnTypePrecisionScale 758 | wrappedRowsColumnTypeScanType 759 | }{ 760 | r, 761 | wrappedRowsNextResultSet{r.parent}, 762 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 763 | wrappedRowsColumnTypePrecisionScale{r.parent}, 764 | wrappedRowsColumnTypeScanType{r.parent}, 765 | } 766 | } 767 | 768 | // plain driver.Rows 769 | pickRows[52] = func(r *wrappedRows) driver.Rows { 770 | return struct { 771 | *wrappedRows 772 | wrappedRowsColumnTypeLength 773 | wrappedRowsColumnTypePrecisionScale 774 | wrappedRowsColumnTypeScanType 775 | }{ 776 | r, 777 | wrappedRowsColumnTypeLength{r.parent}, 778 | wrappedRowsColumnTypePrecisionScale{r.parent}, 779 | wrappedRowsColumnTypeScanType{r.parent}, 780 | } 781 | } 782 | 783 | // plain driver.Rows 784 | pickRows[53] = func(r *wrappedRows) driver.Rows { 785 | return struct { 786 | *wrappedRows 787 | wrappedRowsNextResultSet 788 | wrappedRowsColumnTypeLength 789 | wrappedRowsColumnTypePrecisionScale 790 | wrappedRowsColumnTypeScanType 791 | }{ 792 | r, 793 | wrappedRowsNextResultSet{r.parent}, 794 | wrappedRowsColumnTypeLength{r.parent}, 795 | wrappedRowsColumnTypePrecisionScale{r.parent}, 796 | wrappedRowsColumnTypeScanType{r.parent}, 797 | } 798 | } 799 | 800 | // plain driver.Rows 801 | pickRows[54] = func(r *wrappedRows) driver.Rows { 802 | return struct { 803 | *wrappedRows 804 | wrappedRowsColumnTypeDatabaseTypeName 805 | wrappedRowsColumnTypeLength 806 | wrappedRowsColumnTypePrecisionScale 807 | wrappedRowsColumnTypeScanType 808 | }{ 809 | r, 
810 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 811 | wrappedRowsColumnTypeLength{r.parent}, 812 | wrappedRowsColumnTypePrecisionScale{r.parent}, 813 | wrappedRowsColumnTypeScanType{r.parent}, 814 | } 815 | } 816 | 817 | // plain driver.Rows 818 | pickRows[55] = func(r *wrappedRows) driver.Rows { 819 | return struct { 820 | *wrappedRows 821 | wrappedRowsNextResultSet 822 | wrappedRowsColumnTypeDatabaseTypeName 823 | wrappedRowsColumnTypeLength 824 | wrappedRowsColumnTypePrecisionScale 825 | wrappedRowsColumnTypeScanType 826 | }{ 827 | r, 828 | wrappedRowsNextResultSet{r.parent}, 829 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 830 | wrappedRowsColumnTypeLength{r.parent}, 831 | wrappedRowsColumnTypePrecisionScale{r.parent}, 832 | wrappedRowsColumnTypeScanType{r.parent}, 833 | } 834 | } 835 | 836 | // plain driver.Rows 837 | pickRows[56] = func(r *wrappedRows) driver.Rows { 838 | return struct { 839 | *wrappedRows 840 | wrappedRowsColumnTypeNullable 841 | wrappedRowsColumnTypePrecisionScale 842 | wrappedRowsColumnTypeScanType 843 | }{ 844 | r, 845 | wrappedRowsColumnTypeNullable{r.parent}, 846 | wrappedRowsColumnTypePrecisionScale{r.parent}, 847 | wrappedRowsColumnTypeScanType{r.parent}, 848 | } 849 | } 850 | 851 | // plain driver.Rows 852 | pickRows[57] = func(r *wrappedRows) driver.Rows { 853 | return struct { 854 | *wrappedRows 855 | wrappedRowsNextResultSet 856 | wrappedRowsColumnTypeNullable 857 | wrappedRowsColumnTypePrecisionScale 858 | wrappedRowsColumnTypeScanType 859 | }{ 860 | r, 861 | wrappedRowsNextResultSet{r.parent}, 862 | wrappedRowsColumnTypeNullable{r.parent}, 863 | wrappedRowsColumnTypePrecisionScale{r.parent}, 864 | wrappedRowsColumnTypeScanType{r.parent}, 865 | } 866 | } 867 | 868 | // plain driver.Rows 869 | pickRows[58] = func(r *wrappedRows) driver.Rows { 870 | return struct { 871 | *wrappedRows 872 | wrappedRowsColumnTypeDatabaseTypeName 873 | wrappedRowsColumnTypeNullable 874 | wrappedRowsColumnTypePrecisionScale 875 | 
wrappedRowsColumnTypeScanType 876 | }{ 877 | r, 878 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 879 | wrappedRowsColumnTypeNullable{r.parent}, 880 | wrappedRowsColumnTypePrecisionScale{r.parent}, 881 | wrappedRowsColumnTypeScanType{r.parent}, 882 | } 883 | } 884 | 885 | // plain driver.Rows 886 | pickRows[59] = func(r *wrappedRows) driver.Rows { 887 | return struct { 888 | *wrappedRows 889 | wrappedRowsNextResultSet 890 | wrappedRowsColumnTypeDatabaseTypeName 891 | wrappedRowsColumnTypeNullable 892 | wrappedRowsColumnTypePrecisionScale 893 | wrappedRowsColumnTypeScanType 894 | }{ 895 | r, 896 | wrappedRowsNextResultSet{r.parent}, 897 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 898 | wrappedRowsColumnTypeNullable{r.parent}, 899 | wrappedRowsColumnTypePrecisionScale{r.parent}, 900 | wrappedRowsColumnTypeScanType{r.parent}, 901 | } 902 | } 903 | 904 | // plain driver.Rows 905 | pickRows[60] = func(r *wrappedRows) driver.Rows { 906 | return struct { 907 | *wrappedRows 908 | wrappedRowsColumnTypeLength 909 | wrappedRowsColumnTypeNullable 910 | wrappedRowsColumnTypePrecisionScale 911 | wrappedRowsColumnTypeScanType 912 | }{ 913 | r, 914 | wrappedRowsColumnTypeLength{r.parent}, 915 | wrappedRowsColumnTypeNullable{r.parent}, 916 | wrappedRowsColumnTypePrecisionScale{r.parent}, 917 | wrappedRowsColumnTypeScanType{r.parent}, 918 | } 919 | } 920 | 921 | // plain driver.Rows 922 | pickRows[61] = func(r *wrappedRows) driver.Rows { 923 | return struct { 924 | *wrappedRows 925 | wrappedRowsNextResultSet 926 | wrappedRowsColumnTypeLength 927 | wrappedRowsColumnTypeNullable 928 | wrappedRowsColumnTypePrecisionScale 929 | wrappedRowsColumnTypeScanType 930 | }{ 931 | r, 932 | wrappedRowsNextResultSet{r.parent}, 933 | wrappedRowsColumnTypeLength{r.parent}, 934 | wrappedRowsColumnTypeNullable{r.parent}, 935 | wrappedRowsColumnTypePrecisionScale{r.parent}, 936 | wrappedRowsColumnTypeScanType{r.parent}, 937 | } 938 | } 939 | 940 | // plain driver.Rows 941 | pickRows[62] = 
func(r *wrappedRows) driver.Rows { 942 | return struct { 943 | *wrappedRows 944 | wrappedRowsColumnTypeDatabaseTypeName 945 | wrappedRowsColumnTypeLength 946 | wrappedRowsColumnTypeNullable 947 | wrappedRowsColumnTypePrecisionScale 948 | wrappedRowsColumnTypeScanType 949 | }{ 950 | r, 951 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 952 | wrappedRowsColumnTypeLength{r.parent}, 953 | wrappedRowsColumnTypeNullable{r.parent}, 954 | wrappedRowsColumnTypePrecisionScale{r.parent}, 955 | wrappedRowsColumnTypeScanType{r.parent}, 956 | } 957 | } 958 | 959 | // plain driver.Rows 960 | pickRows[63] = func(r *wrappedRows) driver.Rows { 961 | return struct { 962 | *wrappedRows 963 | wrappedRowsNextResultSet 964 | wrappedRowsColumnTypeDatabaseTypeName 965 | wrappedRowsColumnTypeLength 966 | wrappedRowsColumnTypeNullable 967 | wrappedRowsColumnTypePrecisionScale 968 | wrappedRowsColumnTypeScanType 969 | }{ 970 | r, 971 | wrappedRowsNextResultSet{r.parent}, 972 | wrappedRowsColumnTypeDatabaseTypeName{r.parent}, 973 | wrappedRowsColumnTypeLength{r.parent}, 974 | wrappedRowsColumnTypeNullable{r.parent}, 975 | wrappedRowsColumnTypePrecisionScale{r.parent}, 976 | wrappedRowsColumnTypeScanType{r.parent}, 977 | } 978 | } 979 | } 980 | 981 | func wrapRows(ctx context.Context, intr Interceptor, r driver.Rows) driver.Rows { 982 | or := r 983 | for { 984 | ur, ok := or.(RowsUnwrapper) 985 | if !ok { 986 | break 987 | } 988 | or = ur.Unwrap() 989 | } 990 | 991 | id := 0 992 | 993 | if _, ok := or.(driver.RowsNextResultSet); ok { 994 | id += rowsNextResultSet 995 | } 996 | if _, ok := or.(driver.RowsColumnTypeDatabaseTypeName); ok { 997 | id += rowsColumnTypeDatabaseTypeName 998 | } 999 | if _, ok := or.(driver.RowsColumnTypeLength); ok { 1000 | id += rowsColumnTypeLength 1001 | } 1002 | if _, ok := or.(driver.RowsColumnTypeNullable); ok { 1003 | id += rowsColumnTypeNullable 1004 | } 1005 | if _, ok := or.(driver.RowsColumnTypePrecisionScale); ok { 1006 | id += 
rowsColumnTypePrecisionScale 1007 | } 1008 | if _, ok := or.(driver.RowsColumnTypeScanType); ok { 1009 | id += rowsColumnTypeScanType 1010 | } 1011 | wr := &wrappedRows{ 1012 | ctx: ctx, 1013 | intr: intr, 1014 | parent: r, 1015 | } 1016 | return pickRows[id](wr) 1017 | } 1018 | -------------------------------------------------------------------------------- /rows_test.go: -------------------------------------------------------------------------------- 1 | package sqlmw 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "database/sql/driver" 7 | "fmt" 8 | "reflect" 9 | "sync/atomic" 10 | "testing" 11 | ) 12 | 13 | var driverCount = int32(0) 14 | // driverName returns a driver name built from the test name and driverCount for use with sql.Register; the load/increment/store below is not one atomic operation, so it must not be called from parallel tests. 15 | func driverName(t *testing.T) string { 16 | c := atomic.LoadInt32(&driverCount) 17 | name := fmt.Sprintf("driver-%s-%d", t.Name(), c) 18 | c++ 19 | atomic.StoreInt32(&driverCount, c) 20 | 21 | return name 22 | } 23 | // rowsCloseInterceptor records whether RowsClose was called and the context it received. 24 | type rowsCloseInterceptor struct { 25 | NullInterceptor 26 | 27 | rowsCloseCalled bool 28 | rowsCloseLastCtx context.Context 29 | } 30 | 31 | func (r *rowsCloseInterceptor) RowsClose(ctx context.Context, rows driver.Rows) error { 32 | r.rowsCloseCalled = true 33 | r.rowsCloseLastCtx = ctx 34 | 35 | return rows.Close() 36 | } 37 | // TestRowsClose checks that the interceptor's RowsClose receives the query context and that the driver's rows are closed. 38 | func TestRowsClose(t *testing.T) { 39 | driverName := driverName(t) 40 | interceptor := rowsCloseInterceptor{} 41 | 42 | con := fakeConn{} 43 | sql.Register(driverName, Driver(&fakeDriver{conn: &con}, &interceptor)) 44 | 45 | db, err := sql.Open(driverName, "") 46 | if err != nil { 47 | t.Fatalf("opening db failed: %s", err) 48 | } 49 | 50 | ctx := context.Background() 51 | ctxKey := "ctxkey" 52 | ctxVal := "1" 53 | 54 | ctx = context.WithValue(ctx, ctxKey, ctxVal) // nolint: staticcheck // not using a custom type for the ctx key is not an issue here 55 | 56 | rows, err := db.QueryContext(ctx, "", "") 57 | if err != nil { 58 | t.Fatalf("db.Query failed: %s", err) 59 | } 60 | 61 | err = rows.Close() 62 | if err != nil { 63 | t.Errorf("rows Close failed: %s", err) 64 |
} 65 | 66 | if !interceptor.rowsCloseCalled { 67 | t.Error("interceptor rows.Close was not called") 68 | } 69 | 70 | if interceptor.rowsCloseLastCtx == nil { 71 | t.Fatal("rows close ctx is nil") 72 | } 73 | 74 | v := interceptor.rowsCloseLastCtx.Value(ctxKey) 75 | if v == nil { 76 | t.Fatalf("ctx is different, missing value for key: %s", ctxKey) 77 | } 78 | 79 | vStr, ok := v.(string) 80 | if !ok { 81 | t.Fatalf("ctx is different, value for key: %s, has type %T, expected string", ctxKey, v) 82 | } 83 | 84 | if ctxVal != vStr { 85 | t.Errorf("ctx is different, value for key: %s, is %q, expected %q", ctxKey, vStr, ctxVal) 86 | } 87 | 88 | if !con.rowsCloseCalled { 89 | t.Fatalf("rows close of driver was not called") 90 | } 91 | } 92 | // rowsNextInterceptor records whether RowsNext was called and the context it received. 93 | type rowsNextInterceptor struct { 94 | NullInterceptor 95 | 96 | rowsNextCalled bool 97 | rowsNextLastCtx context.Context 98 | } 99 | 100 | func (r *rowsNextInterceptor) RowsNext(ctx context.Context, rows driver.Rows, dest []driver.Value) error { 101 | r.rowsNextCalled = true 102 | r.rowsNextLastCtx = ctx 103 | return rows.Next(dest) 104 | } 105 | // TestRowsNext checks that RowsNext is routed through the interceptor and receives the query context. 106 | func TestRowsNext(t *testing.T) { 107 | con := &fakeConn{} 108 | rows := &fakeRows{vals: [][]driver.Value{{"hello", "world"}}, con: con} 109 | stmt := fakeStmt{ 110 | rows: rows, 111 | } 112 | con.stmt = stmt 113 | driverName := driverName(t) 114 | interceptor := rowsNextInterceptor{} 115 | 116 | sql.Register( 117 | driverName, 118 | Driver(&fakeDriver{conn: con}, &interceptor), 119 | ) 120 | 121 | db, err := sql.Open(driverName, "") 122 | if err != nil { 123 | t.Fatalf("opening db failed: %s", err) 124 | } 125 | 126 | ctx := context.Background() 127 | ctxKey := "ctxkey" 128 | ctxVal := "1" 129 | 130 | ctx = context.WithValue(ctx, ctxKey, ctxVal) // nolint: staticcheck // not using a custom type for the ctx key is not an issue here 131 | 132 | rs, err := db.QueryContext(ctx, "", "") 133 | if err != nil { 134 | t.Fatalf("db.Query failed: %s", err) 135 | } 136 | 137 | var id, name
string 138 | for rs.Next() { 139 | err := rs.Scan(&id, &name) 140 | if err != nil { 141 | t.Fatal(err) 142 | } 143 | } 144 | 145 | err = rs.Close() 146 | if err != nil { 147 | t.Errorf("rows Close failed: %s", err) 148 | } 149 | 150 | if !rows.nextCalled { 151 | t.Error("driver rows.Next was not called") 152 | } 153 | 154 | if !interceptor.rowsNextCalled { 155 | t.Error("interceptor rows.Next was not called") 156 | } 157 | 158 | if !interceptor.rowsNextCalled { 159 | t.Error("interceptor rows.Next was not called") 160 | } 161 | 162 | if interceptor.rowsNextLastCtx == nil { 163 | t.Fatal("rows next ctx is nil") 164 | } 165 | 166 | v := interceptor.rowsNextLastCtx.Value(ctxKey) 167 | if v == nil { 168 | t.Fatalf("ctx is different, missing value for key: %s", ctxKey) 169 | } 170 | 171 | vStr, ok := v.(string) 172 | if !ok { 173 | t.Fatalf("ctx is different, value for key: %s, has type %T, expected string", ctxKey, v) 174 | } 175 | 176 | if ctxVal != vStr { 177 | t.Errorf("ctx is different, value for key: %s, is %q, expected %q", ctxKey, vStr, ctxVal) 178 | } 179 | } 180 | // TestRows_LikePGX checks that the column-type methods of pgx-style rows are reached through the wrapper and that RowsNext receives the query context. 181 | func TestRows_LikePGX(t *testing.T) { 182 | strType := reflect.TypeOf("") 183 | con := &fakeConn{} 184 | rs := fakeRows{vals: [][]driver.Value{{"hello", "world"}}, con: con} 185 | rows := &fakeRowsLikePgx{ 186 | fakeRows: rs, 187 | fakeWithColumnTypeDatabaseName: fakeWithColumnTypeDatabaseName{r: &rs, names: []string{"CUSTOMVARCHAR", "CUSTOMVARCHAR"}}, 188 | fakeWithColumnTypeScanType: fakeWithColumnTypeScanType{r: &rs, scanTypes: []reflect.Type{strType, strType}}, 189 | fakeWithColumnTypeNullable: fakeWithColumnTypeNullable{r: &rs, nullables: []bool{false, false}, oks: []bool{true, true}}, 190 | fakeWithColumnTypeLength: fakeWithColumnTypeLength{r: &rs, lengths: []int64{5, 5}, bools: []bool{true, true}}, 191 | fakeWithColumnTypePrecisionScale: fakeWithColumnTypePrecisionScale{ 192 | r: &rs, 193 | precisions: []int64{0, 0}, 194 | scales: []int64{0, 0}, 195 | bools: []bool{false, false}, 196 |
}, 197 | } 198 | 199 | stmt := fakeStmt{ 200 | rows: rows, 201 | } 202 | con.stmt = stmt 203 | driverName := driverName(t) 204 | interceptor := rowsNextInterceptor{} 205 | 206 | sql.Register( 207 | driverName, 208 | Driver(&fakeDriver{conn: con}, &interceptor), 209 | ) 210 | 211 | db, err := sql.Open(driverName, "") 212 | if err != nil { 213 | t.Fatalf("opening db failed: %s", err) 214 | } 215 | 216 | ctx := context.Background() 217 | ctxKey := "ctxkey" 218 | ctxVal := "1" 219 | 220 | ctx = context.WithValue(ctx, ctxKey, ctxVal) // nolint: staticcheck // not using a custom type for the ctx key is not an issue here 221 | 222 | qrs, err := db.QueryContext(ctx, "", "") 223 | if err != nil { 224 | t.Fatalf("db.Query failed: %s", err) 225 | } 226 | 227 | names, err := qrs.Columns() 228 | if err != nil { 229 | t.Errorf("error calling Columns, %v", err) 230 | } 231 | 232 | cts, err := qrs.ColumnTypes() 233 | if err != nil { 234 | t.Errorf("error calling ColumnTypes, %v", err) 235 | } 236 | 237 | if len(names) != 2 || len(names) != len(cts) { 238 | t.Errorf("wrong column name or types count") 239 | } 240 | 241 | // There's no real way these can be called, but we'll test in case the 242 | // test implementation changes 243 | if rs.hasNextResultSetCalled { 244 | t.Errorf("HasNextResultSetCalled, on non-supporting type, %v", err) 245 | } 246 | if rs.nextResultSetCalled { 247 | t.Errorf("NextResultSetCalled, on non-supporting type, %v", err) 248 | } 249 | 250 | if !rs.columnTypeDatabaseNameCalled { 251 | t.Errorf("ColumnTypeDatabaseName not called, %v", err) 252 | } 253 | 254 | if !rs.columnTypeLengthCalled { 255 | t.Errorf("ColumnTypeLength not called, %v", err) 256 | } 257 | 258 | if !rs.columnTypeNullable { 259 | t.Errorf("ColumnTypeNullable not called, %v", err) 260 | } 261 | 262 | if !rs.columnTypePrecisionScaleCalled { 263 | t.Errorf("ColumnTypePrecisionScale not called, %v", err) 264 | } 265 | 266 | var id, name string 267 | for qrs.Next() { 268 | err :=
qrs.Scan(&id, &name) 269 | if err != nil { 270 | t.Fatal(err) 271 | } 272 | } 273 | 274 | err = qrs.Close() 275 | if err != nil { 276 | t.Errorf("rows Close failed: %s", err) 277 | } 278 | 279 | if !rows.nextCalled { 280 | t.Error("driver rows.Next was not called") 281 | } 282 | 283 | if !interceptor.rowsNextCalled { 284 | t.Error("interceptor rows.Next was not called") 285 | } 286 | 287 | if !interceptor.rowsNextCalled { 288 | t.Error("interceptor rows.Next was not called") 289 | } 290 | 291 | if interceptor.rowsNextLastCtx == nil { 292 | t.Fatal("rows next ctx is nil") 293 | } 294 | 295 | v := interceptor.rowsNextLastCtx.Value(ctxKey) 296 | if v == nil { 297 | t.Fatalf("ctx is different, missing value for key: %s", ctxKey) 298 | } 299 | 300 | vStr, ok := v.(string) 301 | if !ok { 302 | t.Fatalf("ctx is different, value for key: %s, has type %T, expected string", ctxKey, v) 303 | } 304 | 305 | if ctxVal != vStr { 306 | t.Errorf("ctx is different, value for key: %s, is %q, expected %q", ctxKey, vStr, ctxVal) 307 | } 308 | } 309 | // TestWrapRows checks that wrapRows returns a value implementing exactly the optional driver.Rows interfaces that the wrapped rows implement. 310 | func TestWrapRows(t *testing.T) { 311 | ctx := context.Background() 312 | tt := []struct { 313 | name string 314 | rows driver.Rows 315 | }{ 316 | { 317 | name: "vanilla", 318 | rows: &fakeRows{}, 319 | }, 320 | { 321 | name: "pgx", 322 | rows: &fakeRowsLikePgx{}, 323 | }, 324 | { 325 | name: "mysql", 326 | rows: &fakeRowsLikeMysql{}, 327 | }, 328 | } 329 | 330 | for _, st := range tt { 331 | st := st 332 | t.Run(st.name, func(t *testing.T) { 333 | rows := st.rows 334 | wr := wrapRows(ctx, nil, rows) 335 | 336 | _, rok := rows.(driver.RowsNextResultSet) 337 | _, wok := wr.(driver.RowsNextResultSet) 338 | if rok != wok { 339 | t.Fatalf("inconsistent support for driver.RowsNextResultSet") 340 | } 341 | 342 | _, rok = rows.(driver.RowsColumnTypeDatabaseTypeName) 343 | _, wok = wr.(driver.RowsColumnTypeDatabaseTypeName) 344 | if rok != wok { 345 | t.Fatalf("inconsistent support for driver.RowsColumnTypeDatabaseTypeName") 346 | } 347 |
348 | _, rok = rows.(driver.RowsColumnTypeLength) 349 | _, wok = wr.(driver.RowsColumnTypeLength) 350 | if rok != wok { 351 | t.Fatalf("inconsistent support for driver.RowsColumnTypeLength") 352 | } 353 | 354 | _, rok = rows.(driver.RowsColumnTypeNullable) 355 | _, wok = wr.(driver.RowsColumnTypeNullable) 356 | if rok != wok { 357 | t.Fatalf("inconsistent support for driver.RowsColumnTypeNullable") 358 | } 359 | 360 | _, rok = rows.(driver.RowsColumnTypeScanType) 361 | _, wok = wr.(driver.RowsColumnTypeScanType) 362 | if rok != wok { 363 | t.Fatalf("inconsistent support for driver.RowsColumnTypeScanType") 364 | } 365 | 366 | _, rok = rows.(driver.RowsColumnTypePrecisionScale) 367 | _, wok = wr.(driver.RowsColumnTypePrecisionScale) 368 | if rok != wok { 369 | t.Fatalf("inconsistent support for driver.RowsColumnTypePrecisionScale") 370 | } 371 | }) 372 | } 373 | } 374 | -------------------------------------------------------------------------------- /stmt.go: -------------------------------------------------------------------------------- 1 | package sqlmw 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | ) 7 | // wrappedStmt wraps a driver.Stmt: Close, ExecContext and QueryContext go through the Interceptor, while the legacy Exec and Query call the parent directly. 8 | type wrappedStmt struct { 9 | intr Interceptor 10 | ctx context.Context 11 | query string 12 | parent driver.Stmt 13 | conn wrappedConn 14 | } 15 | 16 | // Compile time validation that our types implement the expected interfaces 17 | var ( 18 | _ driver.Stmt = wrappedStmt{} 19 | _ driver.StmtExecContext = wrappedStmt{} 20 | _ driver.StmtQueryContext = wrappedStmt{} 21 | _ driver.ColumnConverter = wrappedStmt{} 22 | ) 23 | 24 | func (s wrappedStmt) Close() (err error) { 25 | return s.intr.StmtClose(s.ctx, s.parent) 26 | } 27 | 28 | func (s wrappedStmt) NumInput() int { 29 | return s.parent.NumInput() 30 | } 31 | 32 | func (s wrappedStmt) Exec(args []driver.Value) (res driver.Result, err error) { 33 | res, err = s.parent.Exec(args) 34 | if err != nil { 35 | return nil, err 36 | } 37 | return wrappedResult{intr: s.intr, ctx: s.ctx, parent:
res}, nil 38 | } 39 | 40 | func (s wrappedStmt) Query(args []driver.Value) (rows driver.Rows, err error) { 41 | rows, err = s.parent.Query(args) 42 | if err != nil { 43 | return nil, err 44 | } 45 | return wrapRows(s.ctx, s.intr, rows), nil 46 | } 47 | 48 | func (s wrappedStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) { 49 | wrappedParent := wrappedParentStmt{Stmt: s.parent} 50 | res, err = s.intr.StmtExecContext(ctx, wrappedParent, s.query, args) 51 | if err != nil { 52 | return nil, err 53 | } 54 | return wrappedResult{intr: s.intr, ctx: ctx, parent: res}, nil 55 | } 56 | 57 | func (s wrappedStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) { 58 | wrappedParent := wrappedParentStmt{Stmt: s.parent} 59 | ctx, rows, err = s.intr.StmtQueryContext(ctx, wrappedParent, s.query, args) 60 | if err != nil { 61 | return nil, err 62 | } 63 | return wrapRows(ctx, s.intr, rows), nil 64 | } 65 | // ColumnConverter delegates to the parent's ColumnConverter when it has one, otherwise returns driver.DefaultParameterConverter. 66 | func (s wrappedStmt) ColumnConverter(idx int) driver.ValueConverter { 67 | if converter, ok := s.parent.(driver.ColumnConverter); ok { 68 | return converter.ColumnConverter(idx) 69 | } 70 | 71 | return driver.DefaultParameterConverter 72 | } 73 | // wrappedParentStmt adapts a driver.Stmt to StmtQueryContext/StmtExecContext, falling back to the legacy Query/Exec when the driver lacks the context-aware variants. 74 | type wrappedParentStmt struct { 75 | driver.Stmt 76 | } 77 | 78 | func (s wrappedParentStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) { 79 | if stmtQueryContext, ok := s.Stmt.(driver.StmtQueryContext); ok { 80 | return stmtQueryContext.QueryContext(ctx, args) 81 | } 82 | // Fallback implementation 83 | dargs, err := namedValueToValue(args) 84 | if err != nil { 85 | return nil, err 86 | } 87 | select { 88 | default: 89 | case <-ctx.Done(): 90 | return nil, ctx.Err() 91 | } 92 | return s.Stmt.Query(dargs) 93 | } 94 | 95 | func (s wrappedParentStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) { 96 | if stmtExecContext, ok :=
s.Stmt.(driver.StmtExecContext); ok { 97 | return stmtExecContext.ExecContext(ctx, args) 98 | } 99 | // Fallback implementation 100 | dargs, err := namedValueToValue(args) 101 | if err != nil { 102 | return nil, err 103 | } 104 | select { 105 | default: 106 | case <-ctx.Done(): 107 | return nil, ctx.Err() 108 | } 109 | return s.Stmt.Exec(dargs) 110 | } 111 | -------------------------------------------------------------------------------- /stmt_go19.go: -------------------------------------------------------------------------------- 1 | package sqlmw 2 | 3 | import "database/sql/driver" 4 | 5 | var _ driver.NamedValueChecker = wrappedStmt{} 6 | // CheckNamedValue uses the statement's NamedValueChecker if present, then the connection's, and otherwise defaultCheckNamedValue. 7 | func (s wrappedStmt) CheckNamedValue(v *driver.NamedValue) error { 8 | if checker, ok := s.parent.(driver.NamedValueChecker); ok { 9 | return checker.CheckNamedValue(v) 10 | } 11 | 12 | if checker, ok := s.conn.parent.(driver.NamedValueChecker); ok { 13 | return checker.CheckNamedValue(v) 14 | } 15 | 16 | return defaultCheckNamedValue(v) 17 | } 18 | -------------------------------------------------------------------------------- /stmt_go19_test.go: -------------------------------------------------------------------------------- 1 | package sqlmw 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | // TestDefaultParameterConversion ensures that 11 | // driver.DefaultParameterConverter is used when neither stmt nor con 12 | // implements any value converters.
13 | func TestDefaultParameterConversion(t *testing.T) { 14 | driverName := driverName(t) 15 | 16 | expectVal := int64(1) 17 | con := &fakeConn{} 18 | fakeStmt := &fakeStmt{ 19 | rows: &fakeRows{ 20 | con: con, 21 | vals: [][]driver.Value{{expectVal}}, 22 | }, 23 | } 24 | con.stmt = fakeStmt 25 | 26 | sql.Register( 27 | driverName, 28 | Driver(&fakeDriver{conn: con}, &NullInterceptor{}), 29 | ) 30 | 31 | db, err := sql.Open(driverName, "") 32 | if err != nil { 33 | t.Fatalf("Failed to open: %v", err) 34 | } 35 | 36 | t.Cleanup(func() { 37 | if err := db.Close(); err != nil { 38 | t.Errorf("Failed to close db: %v", err) 39 | } 40 | }) 41 | 42 | stmt, err := db.Prepare("") 43 | if err != nil { 44 | t.Fatalf("Prepare failed: %s", err) 45 | } 46 | 47 | // int32 values are converted by driver.DefaultParameterConverter to 48 | // int64 49 | queryVal := int32(1) 50 | rows, err := stmt.Query(queryVal) 51 | if err != nil { 52 | t.Fatalf("Query failed: %s", err) 53 | } 54 | 55 | count := 0 56 | for rows.Next() { 57 | var v int64 58 | err := rows.Scan(&v) 59 | if err != nil { 60 | t.Fatalf("rows.Scan failed, %v", err) 61 | } 62 | if v != expectVal { 63 | t.Errorf("converted value is %d, passed value to Query was: %d", v, queryVal) 64 | } 65 | count++ 66 | } 67 | 68 | if count != 1 { 69 | t.Fatalf("unexpected row count, expected 1, got %d", count) 70 | } 71 | } 72 | // TestWrappedStmt_CheckNamedValue checks which CheckNamedValue implementation is used: the stmt's first, then the conn's, and neither when both lack one. 73 | func TestWrappedStmt_CheckNamedValue(t *testing.T) { 74 | tests := map[string]struct { 75 | fd *fakeDriver 76 | expected struct { 77 | cc bool // Whether the fakeConn's CheckNamedValue was called 78 | sc bool // Whether the fakeStmt's CheckNamedValue was called 79 | } 80 | }{ 81 | "When both conn and stmt implement CheckNamedValue": { 82 | fd: &fakeDriver{ 83 | conn: &fakeConnWithCheckNamedValue{ 84 | fakeConn: fakeConn{ 85 | stmt: &fakeStmtWithCheckNamedValue{}, 86 | }, 87 | }, 88 | }, 89 | expected: struct { 90 | cc bool 91 | sc bool 92 | }{cc: false, sc: true}, 93 | }, 94 | "When only conn implements CheckNamedValue": {
95 | fd: &fakeDriver{ 96 | conn: &fakeConnWithCheckNamedValue{ 97 | fakeConn: fakeConn{ 98 | stmt: &fakeStmtWithoutCheckNamedValue{}, 99 | }, 100 | }, 101 | }, 102 | expected: struct { 103 | cc bool 104 | sc bool 105 | }{cc: true, sc: false}, 106 | }, 107 | "When only stmt implements CheckNamedValue": { 108 | fd: &fakeDriver{ 109 | conn: &fakeConnWithoutCheckNamedValue{ 110 | fakeConn: fakeConn{ 111 | stmt: &fakeStmtWithCheckNamedValue{}, 112 | }, 113 | }, 114 | }, 115 | expected: struct { 116 | cc bool 117 | sc bool 118 | }{cc: false, sc: true}, 119 | }, 120 | "When neither conn nor stmt implements CheckNamedValue": { 121 | fd: &fakeDriver{ 122 | conn: &fakeConnWithoutCheckNamedValue{ 123 | fakeConn: fakeConn{ 124 | stmt: &fakeStmtWithoutCheckNamedValue{}, 125 | }, 126 | }, 127 | }, 128 | expected: struct { 129 | cc bool 130 | sc bool 131 | }{cc: false, sc: false}, 132 | }, 133 | } 134 | 135 | for name, test := range tests { 136 | t.Run(name, func(t *testing.T) { 137 | driverName := driverName(t) 138 | sql.Register(driverName, Driver(test.fd, &fakeInterceptor{})) 139 | db, err := sql.Open(driverName, "dummy") 140 | if err != nil { 141 | t.Fatalf("Failed to open: %v", err) 142 | } 143 | defer func() { 144 | if err := db.Close(); err != nil { 145 | t.Errorf("Failed to close db: %v", err) 146 | } 147 | }() 148 | 149 | stmt, err := db.Prepare("SELECT foo FROM bar Where 1 = ?") 150 | if err != nil { 151 | t.Fatalf("Failed to prepare: %v", err) 152 | } 153 | 154 | if _, err := stmt.Query(1); err != nil { 155 | t.Errorf("Failed to query: %v", err) 156 | } 157 | 158 | conn := reflect.ValueOf(test.fd.conn).Elem() 159 | sc := conn.FieldByName("stmt").Elem().Elem().FieldByName("called").Bool() 160 | cc := conn.FieldByName("called").Bool() 161 | 162 | if test.expected.sc != sc { 163 | t.Errorf("sc mismatch.\n got: %#v\nwant: %#v", sc, test.expected.sc) 164 | } 165 | 166 | if test.expected.cc != cc { 167 | t.Errorf("cc mismatch.\n got: %#v\nwant: %#v", cc, test.expected.cc) 168 | }
169 | }) 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /stmt_test.go: -------------------------------------------------------------------------------- 1 | package sqlmw 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "database/sql/driver" 7 | "testing" 8 | ) 9 | 10 | type stmtCtxKey string 11 | 12 | const ( 13 | stmtRowContextKey stmtCtxKey = "rowcontext" 14 | stmtRowContextValue string = "rowvalue" 15 | ) 16 | // stmtTestInterceptor adds stmtRowContextKey to the context in StmtQueryContext and records whether RowsNext and RowsClose later observe that value. 17 | type stmtTestInterceptor struct { 18 | T *testing.T 19 | RowsNextValid bool 20 | RowsCloseValid bool 21 | NullInterceptor 22 | } 23 | 24 | func (i *stmtTestInterceptor) StmtQueryContext(ctx context.Context, stmt driver.StmtQueryContext, _ string, args []driver.NamedValue) (context.Context, driver.Rows, error) { 25 | ctx = context.WithValue(ctx, stmtRowContextKey, stmtRowContextValue) 26 | 27 | r, err := stmt.QueryContext(ctx, args) 28 | return ctx, r, err 29 | } 30 | 31 | func (i *stmtTestInterceptor) RowsNext(ctx context.Context, rows driver.Rows, dest []driver.Value) error { 32 | if ctx.Value(stmtRowContextKey) == stmtRowContextValue { 33 | i.RowsNextValid = true 34 | } 35 | 36 | i.T.Log(ctx) 37 | 38 | return rows.Next(dest) 39 | } 40 | 41 | func (i *stmtTestInterceptor) RowsClose(ctx context.Context, rows driver.Rows) error { 42 | if ctx.Value(stmtRowContextKey) == stmtRowContextValue { 43 | i.RowsCloseValid = true 44 | } 45 | 46 | i.T.Log(ctx) 47 | 48 | return rows.Close() 49 | } 50 | // TestStmtQueryContext_PassWrappedRowContext checks that the context returned by StmtQueryContext is the one later passed to RowsNext and RowsClose. 51 | func TestStmtQueryContext_PassWrappedRowContext(t *testing.T) { 52 | driverName := driverName(t) 53 | 54 | con := &fakeConn{} 55 | fakeStmt := &fakeStmt{ 56 | rows: &fakeRows{ 57 | con: con, 58 | vals: [][]driver.Value{{}}, 59 | }, 60 | } 61 | con.stmt = fakeStmt 62 | 63 | ti := &stmtTestInterceptor{T: t} 64 | 65 | sql.Register( 66 | driverName, 67 | Driver(&fakeDriver{conn: con}, ti), 68 | ) 69 | 70 | db, err := sql.Open(driverName, "") 71 | if err != nil { 72 | t.Fatalf("Failed to open: %v", err) 73
| } 74 | 75 | t.Cleanup(func() { 76 | if err := db.Close(); err != nil { 77 | t.Errorf("Failed to close db: %v", err) 78 | } 79 | }) 80 | 81 | stmt, err := db.PrepareContext(context.Background(), "") 82 | if err != nil { 83 | t.Fatalf("Prepare failed: %s", err) 84 | } 85 | 86 | rows, err := stmt.Query("") 87 | if err != nil { 88 | t.Fatalf("Stmt query failed: %s", err) 89 | } 90 | 91 | rows.Next() 92 | rows.Close() 93 | stmt.Close() 94 | 95 | if !ti.RowsNextValid { 96 | t.Error("RowsNext context not valid") 97 | } 98 | if !ti.RowsCloseValid { 99 | t.Error("RowsClose context not valid") 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /tools/rows_picker_gen.go: -------------------------------------------------------------------------------- 1 | // +build ignore 2 | 3 | package main 4 | 5 | import ( 6 | "flag" 7 | "fmt" 8 | "io" 9 | "log" 10 | "os" 11 | "time" 12 | ) 13 | // main writes the generated rows picker to the file named by -o, or to stdout when -o is empty. 14 | func main() { 15 | var err error 16 | fn := flag.String("o", "", "output file") 17 | flag.Parse() 18 | 19 | out := os.Stdout 20 | if *fn != "" { 21 | out, err = os.Create(*fn) 22 | if err != nil { 23 | log.Fatalf("could not create file %q, %v", *fn, err) 24 | } 25 | } 26 | 27 | intfs := []string{ 28 | "NextResultSet", 29 | "ColumnTypeDatabaseTypeName", 30 | "ColumnTypeLength", 31 | "ColumnTypeNullable", 32 | "ColumnTypePrecisionScale", 33 | "ColumnTypeScanType", 34 | } 35 | 36 | genComment(out) 37 | fmt.Fprintln(out, "package sqlmw") 38 | 39 | fmt.Fprintln(out, "") 40 | fmt.Fprintln(out, "import (") 41 | fmt.Fprintln(out, "\t\"context\"") 42 | fmt.Fprintln(out, "\t\"database/sql/driver\"") 43 | fmt.Fprintln(out, ")") 44 | 45 | fmt.Fprintln(out, "") 46 | genConst(out, intfs) 47 | 48 | fmt.Fprintln(out, "") 49 | genPickerTable(out, intfs) 50 | 51 | fmt.Fprintln(out, "") 52 | genWrapRows(out, intfs) 53 | 54 | err = out.Close() 55 | if err != nil { 56 | log.Fatalf("could not close file, %v", err) 57 | } 58 | } 59 | // genComment writes the generated-code header line followed by the generation timestamp. 60 | func genComment(w io.Writer)
{ 61 | str := time.Now().Format(time.Stamp) 62 | fmt.Fprintln(w, "// Code generated using tools/rows_picker_gen.go DO NOT EDIT.") 63 | fmt.Fprintf(w, "// Date: %s\n", str) 64 | fmt.Fprintln(w, "") 65 | } 66 | // genConst writes one rows<Name> constant per interface name, the first set to 1 << iota so each gets its own bit. 67 | func genConst(w io.Writer, intfs []string) { 68 | fmt.Fprintln(w, "const (") 69 | for i, n := range intfs { 70 | suf := "" 71 | if i == 0 { 72 | suf = " = 1 << iota" 73 | } 74 | fmt.Fprintf(w, "\trows%s%s\n", n, suf) 75 | } 76 | fmt.Fprintln(w, ")") 77 | } 78 | // forEachBit calls f(n, intfs[i]) for every index i whose bit (1 << i) is set in n. 79 | func forEachBit(n int, intfs []string, f func(n int, intf string)) { 80 | for i := 0; i < len(intfs); i++ { 81 | b := 1 << i 82 | if b&n == b { 83 | f(n, intfs[i]) 84 | } 85 | } 86 | } 87 | // genPickerTable writes pickRows with 1 << len(intfs) entries; entry i embeds wrappedRows plus a wrapper for each interface whose bit is set in i. 88 | func genPickerTable(w io.Writer, intfs []string) { 89 | tlen := 1 << len(intfs) 90 | fmt.Fprintf(w, "var pickRows = make([]func(*wrappedRows) driver.Rows, %d)\n\n", tlen) 91 | 92 | fmt.Fprintln(w, "func init() {") 93 | defer fmt.Fprintln(w, "}") 94 | 95 | fmt.Fprintln(w, ` 96 | // plain driver.Rows 97 | pickRows[0] = func(r *wrappedRows) driver.Rows { 98 | return r 99 | }`) 100 | 101 | for i := 1; i < tlen; i++ { 102 | fmt.Fprintf(w, ` 103 | // driver.Rows plus the optional interfaces selected by the index bits 104 | pickRows[%d] = func(r *wrappedRows) driver.Rows { 105 | return struct { 106 | *wrappedRows`, i) 107 | fmt.Fprintln(w, "") 108 | forEachBit(i, intfs, func(_ int, intf string) { 109 | fmt.Fprintf(w, "\t\t\twrappedRows%s\n", intf) 110 | }) 111 | fmt.Fprintln(w, "\t\t}{\n\t\t\tr,") 112 | forEachBit(i, intfs, func(_ int, intf string) { 113 | fmt.Fprintf(w, "\t\t\twrappedRows%s{r.parent},\n", intf) 114 | }) 115 | fmt.Fprintln(w, "\t\t}") 116 | fmt.Fprintln(w, "\t}") 117 | } 118 | } 119 | // genWrapRows writes wrapRows, which first unwraps r through RowsUnwrapper and then computes the picker index. 120 | func genWrapRows(w io.Writer, intfs []string) { 121 | fmt.Fprintln(w, "func wrapRows(ctx context.Context, intr Interceptor, r driver.Rows) driver.Rows {") 122 | fmt.Fprintln(w, ` or := r 123 | for { 124 | ur, ok := or.(RowsUnwrapper) 125 | if !ok { 126 | break 127 | } 128 | or = ur.Unwrap() 129 | } 130 | 131 | id := 0`) 132 | 133 | defer fmt.Fprintln(w, `
134 | wr := &wrappedRows{ 135 | ctx: ctx, 136 | intr: intr, 137 | parent: r, 138 | } 139 | return pickRows[id](wr) 140 | }`) 141 | 142 | for _, n := range intfs { 143 | fmt.Fprintf(w, ` 144 | if _, ok := or.(driver.Rows%s); ok { 145 | id += rows%[1]s 146 | }`, n) 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /tx.go: -------------------------------------------------------------------------------- 1 | package sqlmw 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | ) 7 | 8 | type wrappedTx struct { 9 | intr Interceptor 10 | ctx context.Context 11 | parent driver.Tx 12 | } 13 | 14 | // Compile time validation that our types implement the expected interfaces 15 | var ( 16 | _ driver.Tx = wrappedTx{} 17 | ) 18 | 19 | func (t wrappedTx) Commit() (err error) { 20 | return t.intr.TxCommit(t.ctx, t.parent) 21 | } 22 | 23 | func (t wrappedTx) Rollback() (err error) { 24 | return t.intr.TxRollback(t.ctx, t.parent) 25 | } 26 | --------------------------------------------------------------------------------