├── go.mod ├── .gitignore ├── LICENSE ├── sqlrange_test.go ├── sqlrange.go ├── README.md └── fakedb_test.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/achille-roussel/sqlrange 2 | 3 | go 1.23 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Achille 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | ================================================================================ 24 | 25 | Copyright (c) 2009 The Go Authors. All rights reserved. 26 | 27 | Redistribution and use in source and binary forms, with or without 28 | modification, are permitted provided that the following conditions are 29 | met: 30 | 31 | * Redistributions of source code must retain the above copyright 32 | notice, this list of conditions and the following disclaimer. 33 | * Redistributions in binary form must reproduce the above 34 | copyright notice, this list of conditions and the following disclaimer 35 | in the documentation and/or other materials provided with the 36 | distribution. 37 | * Neither the name of Google Inc. nor the names of its 38 | contributors may be used to endorse or promote products derived from 39 | this software without specific prior written permission. 40 | 41 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 42 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 43 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 44 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 45 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 46 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 47 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 48 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 49 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 50 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 51 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 52 | -------------------------------------------------------------------------------- /sqlrange_test.go: -------------------------------------------------------------------------------- 1 | package sqlrange_test 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "slices" 7 | "testing" 8 | "time" 9 | 10 | "github.com/achille-roussel/sqlrange" 11 | ) 12 | 13 | func ExampleExec() { 14 | type Row struct { 15 | Age int `sql:"age"` 16 | Name string `sql:"name"` 17 | } 18 | 19 | db := newTestDB(new(testing.T), "people") 20 | defer db.Close() 21 | 22 | for res, err := range sqlrange.Exec(db, `INSERT|people|name=?,age=?`, 23 | func(yield func(Row, error) bool) { 24 | _ = yield(Row{Age: 19, Name: "Luke"}, nil) && 25 | yield(Row{Age: 42, Name: "Hitchhiker"}, nil) 26 | }, 27 | sqlrange.ExecArgsFields[Row]("name", "age"), 28 | ) { 29 | if err != nil { 30 | log.Fatal(err) 31 | } 32 | rowsAffected, err := res.RowsAffected() 33 | if err != nil { 34 | log.Fatal(err) 35 | } 36 | fmt.Println(rowsAffected) 37 | } 38 | 39 | // Output: 40 | // 1 41 | // 1 42 | } 43 | 44 | func ExampleQuery() { 45 | type Row struct { 46 | Age int `sql:"age"` 47 | Name string `sql:"name"` 48 | } 49 | 50 | db := newTestDB(new(testing.T), "people") 51 | defer db.Close() 52 | 53 | for row, err := range sqlrange.Query[Row](db, `SELECT|people|age,name|`) { 54 | if err != nil { 55 | log.Fatal(err) 56 | } 57 | fmt.Println(row) 58 | } 59 | 60 | // Output: 61 | // {1 Alice} 62 | // {2
Bob} 63 | // {3 Chris} 64 | } 65 | 66 | type person struct { 67 | Age int `sql:"age"` 68 | Name string `sql:"name"` 69 | BirthDate time.Time `sql:"bdate"` 70 | } 71 | 72 | func TestExec(t *testing.T) { 73 | db := newTestDB(t, "people") 74 | defer db.Close() 75 | 76 | for res, err := range sqlrange.Exec(db, `INSERT|people|name=?,age=?`, 77 | func(yield func(person, error) bool) { 78 | for _, p := range []person{ 79 | {Age: 19, Name: "Luke"}, 80 | {Age: 42, Name: "Hitchhiker"}, 81 | } { 82 | if !yield(p, nil) { 83 | return 84 | } 85 | } 86 | }, 87 | sqlrange.ExecArgsFields[person]("name", "age"), 88 | ) { 89 | if err != nil { 90 | t.Fatal(err) 91 | } 92 | if n, err := res.RowsAffected(); err != nil { 93 | t.Fatal(err) 94 | } else if n != 1 { 95 | t.Errorf("expect 1, got %d", n) 96 | } 97 | } 98 | } 99 | 100 | func TestQuery(t *testing.T) { 101 | db := newTestDB(t, "people") 102 | defer db.Close() 103 | 104 | var people []person 105 | for p, err := range sqlrange.Query[person](db, `SELECT|people|age,name|`) { 106 | if err != nil { 107 | t.Fatal(err) 108 | } 109 | people = append(people, p) 110 | } 111 | 112 | expect := []person{ 113 | {Age: 1, Name: "Alice"}, 114 | {Age: 2, Name: "Bob"}, 115 | {Age: 3, Name: "Chris"}, 116 | } 117 | 118 | if !slices.Equal(people, expect) { 119 | t.Errorf("expect %v, got %v", expect, people) 120 | } 121 | } 122 | 123 | func BenchmarkQuery100Rows(b *testing.B) { 124 | const N = 500 125 | 126 | db := newTestDB(b, "people") 127 | defer db.Close() 128 | 129 | for _, err := range sqlrange.Exec(db, `INSERT|people|age=?,name=?,bdate=?`, func(yield func(person, error) bool) { 130 | for i := range N { 131 | if !yield(person{ 132 | Age: i, 133 | Name: fmt.Sprintf("Person %d", i), 134 | }, nil) { 135 | break 136 | } 137 | } 138 | }) { 139 | if err != nil { 140 | b.Fatal(err) 141 | } 142 | } 143 | 144 | for n := b.N; n > 0; { 145 | for _, err := range sqlrange.Query[person](db, `SELECT|people|age|`) { 146 | if err != nil { 147 | b.Fatal(err) 148 | 
} 149 | if n--; n == 0 { 150 | break 151 | } 152 | } 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /sqlrange.go: -------------------------------------------------------------------------------- 1 | // Package sqlrange integrates database/sql with Go 1.23 range functions. 2 | package sqlrange 3 | 4 | import ( 5 | "context" 6 | "database/sql" 7 | "fmt" 8 | "iter" 9 | "reflect" 10 | "slices" 11 | "sync/atomic" 12 | ) 13 | 14 | // ExecOption is a functional option type to configure the [Exec] and [ExecContext] 15 | // functions. 16 | type ExecOption[Row any] func(*execOptions[Row]) 17 | 18 | // ExecArgsFields constructs an option that specifies the fields to include in 19 | // the query arguments from a list of column names. A column name may be listed 20 | // more than once; ExecArgsFields panics if a name matches no "sql" tagged field. 21 | // 22 | // This option is useful when the query only needs a subset of the row fields, 23 | // or when the query arguments are in a different order than the struct fields. 24 | func ExecArgsFields[Row any](columnNames ...string) ExecOption[Row] { 25 | fieldIndexByColumn := make(map[string][]int) 26 | for columnName, structField := range Fields(reflect.TypeOf(new(Row)).Elem()) { 27 | fieldIndexByColumn[columnName] = structField.Index 28 | } 29 | // Resolve every position independently so that a column name can be repeated. 30 | structFieldIndexes := make([][]int, len(columnNames)) 31 | for i, columnName := range columnNames { 32 | structFieldIndex, ok := fieldIndexByColumn[columnName] 33 | if !ok { 34 | panic(fmt.Errorf("column %q not found", columnName)) 35 | } 36 | structFieldIndexes[i] = structFieldIndex 37 | } 38 | 39 | return ExecArgs(func(args []any, row Row) []any { 40 | rowValue := reflect.ValueOf(row) 41 | for _, structFieldIndex := range structFieldIndexes { 42 | args = append(args, rowValue.FieldByIndex(structFieldIndex).Interface()) 43 | } 44 | return args 45 | }) 46 | } 47 | 48 | // ExecArgs is an option that specifies the function being called to generate 49 | // the list of arguments passed when executing a query.
50 | // 51 | // By default, the Row value is converted to a list of arguments by taking, in 52 | // declaration order, its exported fields that have a "sql" struct tag, including 53 | // the fields of embedded struct values (see [Fields]). 54 | // 55 | // The function must append the arguments to the slice it receives and return 56 | // the resulting slice. 57 | func ExecArgs[Row any](fn func([]any, Row) []any) ExecOption[Row] { 58 | return func(o *execOptions[Row]) { o.args = fn } 59 | } 60 | 61 | // ExecQuery is an option that sets the function used to compute the query 62 | // executed for each Row value. 63 | // 64 | // The function is given the query string originally passed to [Exec] or 65 | // [ExecContext] along with the row, and returns the query to execute. 66 | // 67 | // This is useful when parts of the query depend on the row being executed, 68 | // for example to generate the placeholders of a batched insert. 69 | func ExecQuery[Row any](fn func(string, Row) string) ExecOption[Row] { 70 | return func(o *execOptions[Row]) { o.query = fn } 71 | } 72 | 73 | type execOptions[Row any] struct { 74 | query func(string, Row) string // computes the query executed for a row 75 | args func([]any, Row) []any // appends the query arguments for a row 76 | } 77 | 78 | // Executable is implemented by [sql.DB], [sql.Conn], and [sql.Tx] (but not [sql.Stmt], whose ExecContext takes no query). 79 | type Executable interface { 80 | ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) 81 | } 82 | 83 | // Exec is a shorthand for [ExecContext] called with [context.Background]. 84 | func Exec[Row any](e Executable, query string, seq iter.Seq2[Row, error], opts ...ExecOption[Row]) iter.Seq2[sql.Result, error] { 85 | return ExecContext(context.Background(), e, query, seq, opts...) 86 | } 87 | 88 | // ExecContext executes a query for each row in the sequence.
89 | // 90 | // To ensure that the query is executed atomically, it is usually useful to 91 | // call ExecContext on a transaction ([sql.Tx]), for example: 92 | // 93 | // tx, err := db.BeginTx(ctx, nil) 94 | // if err != nil { 95 | // ... 96 | // } 97 | // defer tx.Rollback() 98 | // for r, err := range sqlrange.ExecContext[RowType](ctx, tx, query, rows) { 99 | // if err != nil { 100 | // ... 101 | // } 102 | // ... 103 | // } 104 | // if err := tx.Commit(); err != nil { 105 | // ... 106 | // } 107 | // 108 | // Since the function makes one query execution for each row read from the 109 | // sequence, the latency of those executions quickly adds up. In some cases, 110 | // such as inserting values in a database, the program can amortize the cost of 111 | // query latency by batching the rows being inserted, for example: 112 | // 113 | // for r, err := range sqlrange.ExecContext(ctx, tx, 114 | // `insert into table (col1, col2, col3) values `, 115 | // // yield groups of rows to be inserted in bulk 116 | // func(yield func([]RowType, error) bool) { 117 | // ... 118 | // }, 119 | // // append values for the insert query 120 | // sqlrange.ExecArgs(func(args []any, rows []RowType) []any { 121 | // for _, row := range rows { 122 | // args = append(args, row.Col1, row.Col2, row.Col3) 123 | // } 124 | // return args 125 | // }), 126 | // // generate comma-separated placeholders for the insert query 127 | // sqlrange.ExecQuery(func(query string, rows []RowType) string { 128 | // return query + strings.TrimSuffix(strings.Repeat(`(?, ?, ?), `, len(rows)), `, `) 129 | // }), 130 | // ) { 131 | // ... 132 | // } 133 | // 134 | // Batching operations this way is necessary to achieve high throughput when 135 | // inserting values into a database.
136 | func ExecContext[Row any](ctx context.Context, e Executable, query string, seq iter.Seq2[Row, error], opts ...ExecOption[Row]) iter.Seq2[sql.Result, error] { 137 | return func(yield func(sql.Result, error) bool) { 138 | options := new(execOptions[Row]) // options are resolved each time the returned sequence is ranged over 139 | for _, opt := range opts { 140 | opt(options) 141 | } 142 | 143 | if options.args == nil { // default: pass the "sql" tagged fields, in the order yielded by Fields 144 | row := new(Row) // scratch copy of the current row; NOTE(review): presumably reused so each call avoids boxing `in` into a fresh reflect.Value 145 | val := reflect.ValueOf(row).Elem() 146 | fields := Fields(val.Type()) 147 | options.args = func(args []any, in Row) []any { 148 | *row = in 149 | for _, structField := range fields { 150 | args = append(args, val.FieldByIndex(structField.Index).Interface()) // Interface copies the field value out of *row 151 | } 152 | return args 153 | } 154 | } 155 | 156 | if options.query == nil { 157 | options.query = func(query string, _ Row) string { return query } // default: execute the query unchanged 158 | } 159 | 160 | var execArgs []any // backing array reused across executions; NOTE(review): relies on e.ExecContext not retaining args after it returns 161 | var execQuery string 162 | for r, err := range seq { 163 | if err != nil { 164 | yield(nil, err) // the input sequence failed: report it and stop 165 | return 166 | } 167 | execArgs = options.args(execArgs[:0], r) 168 | execQuery = options.query(query, r) 169 | 170 | res, err := e.ExecContext(ctx, execQuery, execArgs...) 171 | if !yield(res, err) { // the consumer stopped ranging 172 | return 173 | } 174 | if err != nil { // stop after the first failed execution, once it has been reported 175 | return 176 | } 177 | } 178 | } 179 | } 180 | 181 | // Queryable is an interface implemented by types that can send SQL queries, 182 | // such as [sql.DB], [sql.Conn], or [sql.Tx]. 183 | type Queryable interface { 184 | QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) 185 | } 186 | 187 | // Query is like [QueryContext] but it uses the background context. 188 | func Query[Row any](q Queryable, query string, args ...any) iter.Seq2[Row, error] { 189 | return QueryContext[Row](context.Background(), q, query, args...) 190 | } 191 | 192 | // QueryContext returns the results of the query as a sequence of rows. 193 | // 194 | // The returned function automatically closes the underlying [sql.Rows] value when 195 | // it completes its iteration.
196 | // 197 | // A typical use of QueryContext is: 198 | // 199 | // for row, err := range sqlrange.QueryContext[RowType](ctx, db, query, args...) { 200 | // if err != nil { 201 | // ... 202 | // } 203 | // ... 204 | // } 205 | // 206 | // The q parameter is any value able to send queries, such as [sql.DB], 207 | // [sql.Conn], or [sql.Tx]. 208 | // 209 | // Rows are decoded into values of type Row as described in the documentation 210 | // of [Scan]. 211 | func QueryContext[Row any](ctx context.Context, q Queryable, query string, args ...any) iter.Seq2[Row, error] { 212 | return func(yield func(Row, error) bool) { 213 | rows, err := q.QueryContext(ctx, query, args...) 214 | if err == nil { 215 | scan[Row](yield, rows) // scan takes ownership of rows and closes them 216 | return 217 | } 218 | yield(*new(Row), err) // the query itself failed: report it with a zero Row 219 | } 220 | } 221 | 222 | // Scan returns a sequence of rows from a [sql.Rows] value. 223 | // 224 | // The returned function closes the rows passed as argument once its iteration 225 | // ends, including when the range loop is exited early. 226 | // 227 | // A typical use of Scan is: 228 | // 229 | // rows, err := db.QueryContext(ctx, query, args...) 230 | // if err != nil { 231 | // ... 232 | // } 233 | // for row, err := range sqlrange.Scan[RowType](rows) { 234 | // if err != nil { 235 | // ... 236 | // } 237 | // ... 238 | // } 239 | // 240 | // Scan uses reflection to map the columns of the rows to the fields of the 241 | // struct type Row: each column is stored in the field whose "sql" tag equals 242 | // the column name. Reading a row that has a column with no matching field 243 | // makes the sequence yield an error. For example: 244 | // 245 | // type Row struct { 246 | // ID int64 `sql:"id"` 247 | // Name string `sql:"name"` 248 | // } 249 | // 250 | // Fields of the struct that have no "sql" tag are ignored and stay at their zero value. 251 | // 252 | // Ranging over the returned sequence panics if the type parameter Row is not a 253 | // struct.
254 | func Scan[Row any](rows *sql.Rows) iter.Seq2[Row, error] { 255 | return func(yield func(Row, error) bool) { scan(yield, rows) } 256 | } 257 | 258 | func scan[Row any](yield func(Row, error) bool, rows *sql.Rows) { // decodes every row of rows into a Row and passes it to yield 259 | defer rows.Close() // runs on every exit path, including when the consumer stops early 260 | var zero Row 261 | 262 | columns, err := rows.Columns() 263 | if err != nil { 264 | yield(zero, err) 265 | return 266 | } 267 | 268 | scanArgs := make([]any, len(columns)) // one destination per column; a column with no matching field keeps a nil destination, which rows.Scan rejects with an error 269 | row := new(Row) // single decode target reused for every row 270 | val := reflect.ValueOf(row).Elem() 271 | 272 | for columnName, structField := range Fields(val.Type()) { 273 | if columnIndex := slices.Index(columns, columnName); columnIndex >= 0 { // NOTE(review): only the first column with this name is mapped; repeated column names (e.g. from joins) stay unmapped 274 | scanArgs[columnIndex] = val.FieldByIndex(structField.Index).Addr().Interface() // pointer into *row, so rows.Scan writes straight into it 275 | } 276 | } 277 | 278 | for rows.Next() { 279 | if err := rows.Scan(scanArgs...); err != nil { 280 | yield(zero, err) 281 | return 282 | } 283 | if !yield(*row, nil) { // yields a copy of the decoded row; false means the consumer stopped ranging 284 | return 285 | } 286 | *row = zero // NOTE(review): presumably clears the row so that no state from it (e.g. buffers reused by custom sql.Scanner types) carries into the next one 287 | } 288 | 289 | if err := rows.Err(); err != nil { // reports an error that ended the rows.Next loop early 290 | yield(zero, err) 291 | } 292 | } 293 | 294 | // Fields returns a sequence of (column name, struct field) pairs for the exported 295 | // fields of struct type t that have a "sql" tag, including those of embedded struct values. Iterating the sequence panics if t is not a struct type.
296 | func Fields(t reflect.Type) iter.Seq2[string, reflect.StructField] { 297 | return func(yield func(string, reflect.StructField) bool) { 298 | // The cache is a copy-on-write map: readers load it without locking and a 299 | // miss publishes an updated copy. Concurrent misses may overwrite each 300 | // other's copy and drop an entry, which only costs a later recomputation. 301 | cache, _ := cachedFields.Load().(map[reflect.Type][]field) 302 | 303 | fields, ok := cache[t] 304 | if !ok { 305 | fields = appendFields(nil, t, nil) 306 | 307 | newCache := make(map[reflect.Type][]field, len(cache)+1) 308 | for k, v := range cache { 309 | newCache[k] = v 310 | } 311 | newCache[t] = fields 312 | cachedFields.Store(newCache) 313 | } 314 | 315 | // The yielded StructField values share their Index slices with the 316 | // cache, so callers must treat them as read-only. 317 | for _, f := range fields { 318 | if !yield(f.name, f.field) { 319 | return 320 | } 321 | } 322 | } 323 | } 324 | 325 | // field pairs a column name (the value of the "sql" tag) with its struct field. 326 | type field struct { 327 | name string 328 | field reflect.StructField 329 | } 330 | 331 | var cachedFields atomic.Value // map[reflect.Type][]field, replaced wholesale on update 332 | 333 | // appendFields appends to fields the exported fields of struct type t that have 334 | // a "sql" tag, with index prepended to their Index. Exported embedded struct 335 | // values are flattened recursively in declaration order; other embedded fields 336 | // (e.g. pointers to structs) are skipped. 337 | func appendFields(fields []field, t reflect.Type, index []int) []field { 338 | for i, n := 0, t.NumField(); i < n; i++ { 339 | f := t.Field(i) 340 | if !f.IsExported() { 341 | continue 342 | } 343 | if len(index) > 0 { 344 | // Always build a fresh slice: append(index, f.Index...) would write 345 | // into a backing array shared by every sibling field whenever index 346 | // has spare capacity (as happens after a few levels of embedding), 347 | // leaving all of them with the Index of the last sibling. 348 | f.Index = slices.Concat(index, f.Index) 349 | } 350 | if f.Anonymous { 351 | if f.Type.Kind() == reflect.Struct { 352 | fields = appendFields(fields, f.Type, f.Index) 353 | } 354 | } else if s, ok := f.Tag.Lookup("sql"); ok { 355 | fields = append(fields, field{s, f}) 356 | } 357 | } 358 | return fields 359 | } 360 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sqlrange [![Go Reference](https://pkg.go.dev/badge/github.com/achille-roussel/sqlrange.svg)](https://pkg.go.dev/github.com/achille-roussel/sqlrange) 2 | 3 | Library using the `database/sql` package and Go 1.23 range functions to execute 4 | queries against SQL databases.
5 | 6 | ## Installation 7 | 8 | This package is intended to be used as a library and installed with: 9 | ```sh 10 | go get github.com/achille-roussel/sqlrange 11 | ``` 12 | 13 | :warning: The package requires Go 1.23 or later for range function support. 14 | 15 | ## Usage 16 | 17 | The `sqlrange` package contains two kinds of functions called **Exec** and 18 | **Query** which wrap the standard library's `database/sql` methods with the 19 | same names. The package adds stronger type safety and the ability to use 20 | range functions as iterators to pass values to the queries or retrieve results. 21 | 22 | Note that `sqlrange` **IS NOT AN ORM**; it is a lightweight package which does 23 | not hide any of the details and simply provides library functions to structure 24 | applications that stream values in and out of databases. 25 | 26 | ### Query 27 | 28 | The **Query** functions are used to read streams of values from databases, 29 | in the same way that `sql.(*DB).Query` does, but using range functions to 30 | simplify the code constructs, and type parameters to automatically decode 31 | SQL results into Go struct values. 32 | 33 | The type parameter must be a struct with fields containing "sql" struct tags 34 | to define the names of columns that the fields are mapped to: 35 | ```go 36 | type Point struct { 37 | X float64 `sql:"x"` 38 | Y float64 `sql:"y"` 39 | } 40 | ``` 41 | ```go 42 | for p, err := range sqlrange.Query[Point](db, `select x, y from points`) { 43 | if err != nil { 44 | ... 45 | } 46 | ... 47 | } 48 | ``` 49 | Note that resource management here is automated by the range function 50 | returned by calling **Query**: the underlying `*sql.Rows` value is automatically 51 | closed when the program exits the range loop consuming the rows.
52 | 53 | ### Exec 54 | 55 | The **Exec** functions are used to execute insert, update, or delete queries 56 | against databases, accepting a stream of parameters as arguments (in the form of 57 | a range function), and returning a stream of results. 58 | 59 | Since the function sends one query per value to the database, it is often 60 | preferable to apply it to a transaction (`*sql.Tx`) to ensure atomicity of the 61 | operation; prepared statements (`*sql.Stmt`) cannot be used, as their `ExecContext` method takes no query. 62 | 63 | ```go 64 | tx, err := db.Begin() 65 | if err != nil { 66 | ... 67 | } 68 | defer tx.Rollback() 69 | 70 | for r, err := range sqlrange.Exec(tx, 71 | `insert into table (col1, col2, col3) values (?, ?, ?)`, 72 | // This range function yields the values that will be inserted into 73 | // the database by executing the query above. 74 | func(yield func(RowType, error) bool) { 75 | ... 76 | }, 77 | // Inject the arguments for the SQL query being executed. 78 | // The function is called for each value yielded by the range 79 | // function above. 80 | sqlrange.ExecArgs(func(args []any, row RowType) []any { 81 | return append(args, row.Col1, row.Col2, row.Col3) 82 | }), 83 | ) { 84 | // The results of each execution are streamed and must be consumed 85 | // by the program to drive the operation. 86 | if err != nil { 87 | ... 88 | } 89 | ... 90 | } 91 | 92 | if err := tx.Commit(); err != nil { 93 | ... 94 | } 95 | ``` 96 | 97 | ### Context 98 | 99 | Mirroring methods of the `sql.DB` type, functions of the `sqlrange` package have 100 | variants that take a `context.Context` as first argument to support asynchronous 101 | cancellation or timing out the operations.
102 | 103 | Reusing the example above, we could set a 10-second time limit for the query 104 | using **QueryContext** instead of **Query**: 105 | 106 | ```go 107 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 108 | defer cancel() 109 | 110 | for p, err := range sqlrange.QueryContext[Point](ctx, db, `select x, y from points`) { 111 | if err != nil { 112 | ... 113 | } 114 | ... 115 | } 116 | ``` 117 | 118 | The context is propagated to the `sql.(*DB).QueryContext` method, which then 119 | passes it to the underlying SQL driver. 120 | 121 | ## Performance 122 | 123 | Functions in this package are optimized to have a minimal compute and memory 124 | footprint. Applications should not observe any performance degradation from 125 | using it, compared to using the `database/sql` package directly. This is an 126 | important property of the package since it means that the type safety, resource 127 | lifecycle management, and expressiveness do not have to be a trade-off. 128 | 129 | This is a use case where the use of range functions really shines: because all 130 | the code points where range functions are created get inlined, the compiler's 131 | escape analysis can place most of the values on the stack, keeping the memory 132 | and garbage collection overhead to a minimum. 133 | 134 | Most of the escaping heap allocations in this package come from the use of 135 | reflection to convert SQL rows into Go values, which are optimized using two 136 | different approaches: 137 | 138 | - **Caching:** internally, the package caches the `reflect.StructField` values 139 | that it needs. This is necessary to remove some of the allocations caused by 140 | the `reflect` package allocating the [`Index`][structField] on the heap. 141 | See https://github.com/golang/go/issues/2320 for more details.
142 | 143 | - **Amortization:** since the intended use case is to select ranges of rows, 144 | or execute batch queries, the functions can reuse the local state maintained 145 | to read values. The more rows are involved in the query, the more the cost of 146 | allocating those values gets amortized, to the point that it quickly becomes 147 | insignificant. 148 | 149 | To illustrate, we can look at the memory profiles for the package benchmarks. 150 | 151 | **objects allocated on the heap** 152 | ``` 153 | File: sqlrange.test 154 | Type: alloc_objects 155 | Time: Jan 15, 2024 at 8:32am (PST) 156 | Showing nodes accounting for 23444929, 97.50% of 24046152 total 157 | Dropped 43 nodes (cum <= 120230) 158 | flat flat% sum% cum cum% 159 | 21408835 89.03% 89.03% 21408835 89.03% github.com/achille-roussel/sqlrange_test.(*fakeStmt).QueryContext /go/src/github.com/achille-roussel/sqlrange/fakedb_test.go:1040 160 | 1769499 7.36% 96.39% 1769499 7.36% strconv.formatBits /sdk/go1.22rc1/src/strconv/itoa.go:199 161 | 217443 0.9% 97.30% 217443 0.9% github.com/achille-roussel/sqlrange_test.(*fakeStmt).QueryContext /go/src/github.com/achille-roussel/sqlrange/fakedb_test.go:1044 162 | 32768 0.14% 97.43% 21926303 91.18% database/sql.(*DB).query /sdk/go1.22rc1/src/database/sql/sql.go:1754 163 | 16384 0.068% 97.50% 23925181 99.50% github.com/achille-roussel/sqlrange_test.BenchmarkQuery100Rows /go/src/github.com/achille-roussel/sqlrange/sqlrange_test.go:145 164 | 0 0% 97.50% 21926303 91.18% database/sql.(*DB).QueryContext /sdk/go1.22rc1/src/database/sql/sql.go:1731 165 | 0 0% 97.50% 21926303 91.18% database/sql.(*DB).QueryContext.func1 /sdk/go1.22rc1/src/database/sql/sql.go:1732 166 | 0 0% 97.50% 21746433 90.44% database/sql.(*DB).queryDC /sdk/go1.22rc1/src/database/sql/sql.go:1806 167 | 0 0% 97.50% 22039082 91.65% database/sql.(*DB).retry /sdk/go1.22rc1/src/database/sql/sql.go:1566 168 | 0 0% 97.50% 1769499 7.36% database/sql.(*Rows).Scan /sdk/go1.22rc1/src/database/sql/sql.go:3354
169 | 0 0% 97.50% 1769499 7.36% database/sql.asString /sdk/go1.22rc1/src/database/sql/convert.go:499 170 | 0 0% 97.50% 1769499 7.36% database/sql.convertAssignRows /sdk/go1.22rc1/src/database/sql/convert.go:433 171 | 0 0% 97.50% 169852 0.71% database/sql.ctxDriverPrepare /sdk/go1.22rc1/src/database/sql/ctxutil.go:15 172 | 0 0% 97.50% 21746433 90.44% database/sql.ctxDriverStmtQuery /sdk/go1.22rc1/src/database/sql/ctxutil.go:82 173 | 0 0% 97.50% 21746433 90.44% database/sql.rowsiFromStatement /sdk/go1.22rc1/src/database/sql/sql.go:2836 174 | 0 0% 97.50% 202620 0.84% database/sql.withLock /sdk/go1.22rc1/src/database/sql/sql.go:3530 175 | 0 0% 97.50% 21926303 91.18% github.com/achille-roussel/sqlrange.QueryContext[go.shape.struct { Age int "sql:\"age\""; Name string "sql:\"name\""; BirthDate time.Time "sql:\"bdate\"" }] /go/src/github.com/achille-roussel/sqlrange/sqlrange.go:213 176 | 0 0% 97.50% 1769499 7.36% github.com/achille-roussel/sqlrange.QueryContext[go.shape.struct { Age int "sql:\"age\""; Name string "sql:\"name\""; BirthDate time.Time "sql:\"bdate\"" }].Scan[go.shape.struct { Age int "sql:\"age\""; Name string "sql:\"name\""; BirthDate time.Time "sql:\"bdate\"" }].func2 /go/src/github.com/achille-roussel/sqlrange/sqlrange.go:278 177 | 0 0% 97.50% 21926303 91.18% github.com/achille-roussel/sqlrange.Query[go.shape.struct { Age int "sql:\"age\""; Name string "sql:\"name\""; BirthDate time.Time "sql:\"bdate\"" }] /go/src/github.com/achille-roussel/sqlrange/sqlrange.go:189 (inline) 178 | 0 0% 97.50% 120971 0.5% github.com/achille-roussel/sqlrange_test.BenchmarkQuery100Rows /go/src/github.com/achille-roussel/sqlrange/sqlrange_test.go:129 179 | 0 0% 97.50% 120971 0.5% github.com/achille-roussel/sqlrange_test.BenchmarkQuery100Rows.Exec[go.shape.struct { Age int "sql:\"age\""; Name string "sql:\"name\""; BirthDate time.Time "sql:\"bdate\"" }].ExecContext[go.shape.struct { Age int "sql:\"age\""; Name string "sql:\"name\""; BirthDate time.Time "sql:\"bdate\"" }].func4 
/go/src/github.com/achille-roussel/sqlrange/sqlrange.go:162 180 | 0 0% 97.50% 120971 0.5% github.com/achille-roussel/sqlrange_test.BenchmarkQuery100Rows.func1 /go/src/github.com/achille-roussel/sqlrange/sqlrange_test.go:131 181 | 0 0% 97.50% 1769499 7.36% strconv.FormatInt /sdk/go1.22rc1/src/strconv/itoa.go:29 182 | 0 0% 97.50% 24033864 99.95% testing.(*B).launch /sdk/go1.22rc1/src/testing/benchmark.go:316 183 | 0 0% 97.50% 24046152 100% testing.(*B).runN /sdk/go1.22rc1/src/testing/benchmark.go:193 184 | ``` 185 | 186 | **memory allocated on the heap** 187 | ``` 188 | File: sqlrange.test 189 | Type: alloc_space 190 | Time: Jan 15, 2024 at 8:32am (PST) 191 | Showing nodes accounting for 626.05MB, 97.66% of 641.05MB total 192 | Dropped 33 nodes (cum <= 3.21MB) 193 | flat flat% sum% cum cum% 194 | 408.51MB 63.72% 63.72% 408.51MB 63.72% github.com/achille-roussel/sqlrange_test.(*fakeStmt).QueryContext /go/src/github.com/achille-roussel/sqlrange/fakedb_test.go:1040 195 | 174.04MB 27.15% 90.87% 174.04MB 27.15% github.com/achille-roussel/sqlrange_test.(*fakeStmt).QueryContext /go/src/github.com/achille-roussel/sqlrange/fakedb_test.go:1044 196 | 27MB 4.21% 95.09% 27MB 4.21% strconv.formatBits /sdk/go1.22rc1/src/strconv/itoa.go:199 197 | 5.50MB 0.86% 95.94% 5.50MB 0.86% github.com/achille-roussel/sqlrange_test.(*fakeStmt).QueryContext /go/src/github.com/achille-roussel/sqlrange/fakedb_test.go:1064 198 | 5.50MB 0.86% 96.80% 5.50MB 0.86% database/sql.(*DB).queryDC /sdk/go1.22rc1/src/database/sql/sql.go:1815 199 | 4.50MB 0.7% 97.50% 4.50MB 0.7% strings.genSplit /sdk/go1.22rc1/src/strings/strings.go:249 200 | 0.50MB 0.078% 97.58% 635.05MB 99.06% github.com/achille-roussel/sqlrange_test.BenchmarkQuery100Rows /go/src/github.com/achille-roussel/sqlrange/sqlrange_test.go:145 201 | 0.50MB 0.078% 97.66% 602.05MB 93.92% database/sql.(*DB).query /sdk/go1.22rc1/src/database/sql/sql.go:1754 202 | 0 0% 97.66% 5.50MB 0.86% database/sql.(*DB).ExecContext 
/sdk/go1.22rc1/src/database/sql/sql.go:1661 203 | 0 0% 97.66% 5.50MB 0.86% database/sql.(*DB).ExecContext.func1 /sdk/go1.22rc1/src/database/sql/sql.go:1662 204 | 0 0% 97.66% 602.05MB 93.92% database/sql.(*DB).QueryContext /sdk/go1.22rc1/src/database/sql/sql.go:1731 205 | 0 0% 97.66% 602.05MB 93.92% database/sql.(*DB).QueryContext.func1 /sdk/go1.22rc1/src/database/sql/sql.go:1732 206 | 0 0% 97.66% 5.50MB 0.86% database/sql.(*DB).exec /sdk/go1.22rc1/src/database/sql/sql.go:1683 207 | 0 0% 97.66% 5MB 0.78% database/sql.(*DB).queryDC /sdk/go1.22rc1/src/database/sql/sql.go:1797 208 | 0 0% 97.66% 589.55MB 91.97% database/sql.(*DB).queryDC /sdk/go1.22rc1/src/database/sql/sql.go:1806 209 | 0 0% 97.66% 5MB 0.78% database/sql.(*DB).queryDC.func2 /sdk/go1.22rc1/src/database/sql/sql.go:1798 210 | 0 0% 97.66% 607.55MB 94.77% database/sql.(*DB).retry /sdk/go1.22rc1/src/database/sql/sql.go:1566 211 | 0 0% 97.66% 27MB 4.21% database/sql.(*Rows).Scan /sdk/go1.22rc1/src/database/sql/sql.go:3354 212 | 0 0% 97.66% 27MB 4.21% database/sql.asString /sdk/go1.22rc1/src/database/sql/convert.go:499 213 | 0 0% 97.66% 27MB 4.21% database/sql.convertAssignRows /sdk/go1.22rc1/src/database/sql/convert.go:433 214 | 0 0% 97.66% 8MB 1.25% database/sql.ctxDriverPrepare /sdk/go1.22rc1/src/database/sql/ctxutil.go:15 215 | 0 0% 97.66% 589.55MB 91.97% database/sql.ctxDriverStmtQuery /sdk/go1.22rc1/src/database/sql/ctxutil.go:82 216 | 0 0% 97.66% 589.55MB 91.97% database/sql.rowsiFromStatement /sdk/go1.22rc1/src/database/sql/sql.go:2836 217 | 0 0% 97.66% 8.50MB 1.33% database/sql.withLock /sdk/go1.22rc1/src/database/sql/sql.go:3530 218 | 0 0% 97.66% 602.05MB 93.92% github.com/achille-roussel/sqlrange.QueryContext[go.shape.struct { Age int "sql:\"age\""; Name string "sql:\"name\""; BirthDate time.Time "sql:\"bdate\"" }] /go/src/github.com/achille-roussel/sqlrange/sqlrange.go:213 219 | 0 0% 97.66% 27MB 4.21% github.com/achille-roussel/sqlrange.QueryContext[go.shape.struct { Age int "sql:\"age\""; Name 
string "sql:\"name\""; BirthDate time.Time "sql:\"bdate\"" }].Scan[go.shape.struct { Age int "sql:\"age\""; Name string "sql:\"name\""; BirthDate time.Time "sql:\"bdate\"" }].func2 /go/src/github.com/achille-roussel/sqlrange/sqlrange.go:278 220 | 0 0% 97.66% 602.05MB 93.92% github.com/achille-roussel/sqlrange.Query[go.shape.struct { Age int "sql:\"age\""; Name string "sql:\"name\""; BirthDate time.Time "sql:\"bdate\"" }] /go/src/github.com/achille-roussel/sqlrange/sqlrange.go:189 (inline) 221 | 0 0% 97.66% 6MB 0.94% github.com/achille-roussel/sqlrange_test.BenchmarkQuery100Rows /go/src/github.com/achille-roussel/sqlrange/sqlrange_test.go:129 222 | 0 0% 97.66% 6MB 0.94% github.com/achille-roussel/sqlrange_test.BenchmarkQuery100Rows.Exec[go.shape.struct { Age int "sql:\"age\""; Name string "sql:\"name\""; BirthDate time.Time "sql:\"bdate\"" }].ExecContext[go.shape.struct { Age int "sql:\"age\""; Name string "sql:\"name\""; BirthDate time.Time "sql:\"bdate\"" }].func4 /go/src/github.com/achille-roussel/sqlrange/sqlrange.go:162 223 | 0 0% 97.66% 5.50MB 0.86% github.com/achille-roussel/sqlrange_test.BenchmarkQuery100Rows.Exec[go.shape.struct { Age int "sql:\"age\""; Name string "sql:\"name\""; BirthDate time.Time "sql:\"bdate\"" }].ExecContext[go.shape.struct { Age int "sql:\"age\""; Name string "sql:\"name\""; BirthDate time.Time "sql:\"bdate\"" }].func4.3 /go/src/github.com/achille-roussel/sqlrange/sqlrange.go:170 224 | 0 0% 97.66% 6MB 0.94% github.com/achille-roussel/sqlrange_test.BenchmarkQuery100Rows.func1 /go/src/github.com/achille-roussel/sqlrange/sqlrange_test.go:131 225 | 0 0% 97.66% 27MB 4.21% strconv.FormatInt /sdk/go1.22rc1/src/strconv/itoa.go:29 226 | 0 0% 97.66% 4.50MB 0.7% strings.Split /sdk/go1.22rc1/src/strings/strings.go:307 (inline) 227 | 0 0% 97.66% 640.05MB 99.84% testing.(*B).launch /sdk/go1.22rc1/src/testing/benchmark.go:316 228 | 0 0% 97.66% 641.05MB 100% testing.(*B).runN /sdk/go1.22rc1/src/testing/benchmark.go:193 229 | ``` 230 | 231 | Almost 
all the memory allocated on the heap is done in the SQL driver. 232 | The fake driver employed for tests isn't very efficient, but it still shows 233 | that the package does not contribute to the majority of memory usage. 234 | Programs that use SQL drivers for production databases like MySQL or Postgres 235 | will have performance characteristics dictated by the driver and won't suffer 236 | from utilizing the `sqlrange` package abstractions. 237 | 238 | [structField]: https://pkg.go.dev/reflect#StructField 239 | -------------------------------------------------------------------------------- /fakedb_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2011 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // package sql 6 | package sqlrange_test 7 | 8 | import ( 9 | "context" 10 | "database/sql" 11 | "database/sql/driver" 12 | "errors" 13 | "fmt" 14 | "io" 15 | "reflect" 16 | "strconv" 17 | "strings" 18 | "sync" 19 | "sync/atomic" 20 | "testing" 21 | "time" 22 | ) 23 | 24 | const fakeDBName = "foo" 25 | 26 | var chrisBirthday = time.Unix(123456789, 0) 27 | // newTestDB returns a *sql.DB on the fake database, wiped and seeded according to name (see newTestDBConnector). 28 | func newTestDB(t testing.TB, name string) *sql.DB { 29 | return newTestDBConnector(t, &fakeConnector{name: fakeDBName}, name) 30 | } 31 | // newTestDBConnector opens a *sql.DB through fc (whose name is reset to fakeDBName), sends WIPE, and seeds the "people", "magicquery" or "tx_status" table when name matches; any setup error fails the test. 32 | func newTestDBConnector(t testing.TB, fc *fakeConnector, name string) *sql.DB { 33 | fc.name = fakeDBName 34 | db := sql.OpenDB(fc) 35 | if _, err := db.Exec("WIPE"); err != nil { 36 | t.Fatalf("exec wipe: %v", err) 37 | } 38 | if name == "people" { 39 | exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime") 40 | exec(t, db, "INSERT|people|name=Alice,age=?,photo=APHOTO", 1) 41 | exec(t, db, "INSERT|people|name=Bob,age=?,photo=BPHOTO", 2) 42 | exec(t, db, "INSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday) 43 | } 44 | if name == "magicquery" { 45 | // Magic
table name and column, known by fakedb_test.go. 46 | exec(t, db, "CREATE|magicquery|op=string,millis=int32") 47 | exec(t, db, "INSERT|magicquery|op=sleep,millis=10") 48 | } 49 | if name == "tx_status" { 50 | // Magic table name and column, known by fakedb_test.go. 51 | exec(t, db, "CREATE|tx_status|tx_status=string") 52 | exec(t, db, "INSERT|tx_status|tx_status=invalid") 53 | } 54 | return db 55 | } 56 | // exec runs query with args on db and fails the test if it returns an error. 57 | func exec(t testing.TB, db *sql.DB, query string, args ...any) { 58 | t.Helper() 59 | _, err := db.Exec(query, args...) 60 | if err != nil { 61 | t.Fatalf("Exec of %q: %v", query, err) 62 | } 63 | } 64 | 65 | // fakeDriver is a fake database that implements Go's driver.Driver 66 | // interface, just for testing. 67 | // 68 | // It speaks a query language that's semantically similar to but 69 | // syntactically different and simpler than SQL. The syntax is as 70 | // follows (tablename, colN, typeN, methodname and duration are placeholders): 71 | // 72 | // WIPE 73 | // CREATE|tablename|col1=type1,col2=type2,... 74 | // where types are: "string", [u]int{8,16,32,64}, "bool" (newTestDBConnector also uses "blob" and "datetime") 75 | // INSERT|tablename|col=val,col2=val2,col3=? 76 | // SELECT|tablename|projectcol1,projectcol2|filtercol=?,filtercol2=? 77 | // SELECT|tablename|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2 78 | // 79 | // Any of these can be preceded by PANIC|methodname|, to cause the 80 | // named method on fakeStmt to panic. 81 | // 82 | // Any of these can be preceded by WAIT|duration|, to make PrepareContext 83 | // and ExecContext sleep for the specified duration. 84 | // 85 | // Multiple of these can be combined when separated with a semicolon. 86 | // 87 | // When opening a fakeDriver's database, it starts empty with no 88 | // tables. All tables and data are stored in memory only.
89 | type fakeDriver struct { 90 | mu sync.Mutex // guards openCount, closeCount and dbs (waitCh/waitingCh are used in Open without it) 91 | openCount int // conn opens 92 | closeCount int // conn closes 93 | waitCh chan struct{} 94 | waitingCh chan struct{} 95 | dbs map[string]*fakeDB 96 | } 97 | // fakeConnector is the driver.Connector handed to sql.OpenDB; it opens connections to the fake database called name. 98 | type fakeConnector struct { 99 | name string 100 | 101 | waiter func(context.Context) 102 | closed bool 103 | } 104 | // Connect opens a fakeConn for c.name through fdriver and installs c.waiter on it. 105 | func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { 106 | conn, err := fdriver.Open(c.name) 107 | if fc, ok := conn.(*fakeConn); ok { fc.waiter = c.waiter } // conn is nil when Open fails (e.g. via hookOpenErr); a bare assertion would panic 108 | return conn, err 109 | } 110 | 111 | func (c *fakeConnector) Driver() driver.Driver { 112 | return fdriver 113 | } 114 | 115 | func (c *fakeConnector) Close() error { 116 | if c.closed { 117 | return errors.New("fakedb: connector is closed") 118 | } 119 | c.closed = true 120 | return nil 121 | } 122 | 123 | type fakeDriverCtx struct { 124 | fakeDriver 125 | } 126 | 127 | var _ driver.DriverContext = &fakeDriverCtx{} 128 | 129 | func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { 130 | return &fakeConnector{name: name}, nil 131 | } 132 | 133 | type fakeDB struct { 134 | name string 135 | 136 | useRawBytes atomic.Bool 137 | 138 | mu sync.Mutex 139 | tables map[string]*table 140 | badConn bool 141 | allowAny bool 142 | } 143 | // fakeError is an error with a message and an optional wrapped cause, typically driver.ErrBadConn. 144 | type fakeError struct { 145 | Message string 146 | Wrapped error 147 | } 148 | 149 | func (err fakeError) Error() string { 150 | return err.Message 151 | } 152 | 153 | func (err fakeError) Unwrap() error { 154 | return err.Wrapped 155 | } 156 | 157 | type table struct { 158 | mu sync.Mutex 159 | colname []string 160 | coltype []string 161 | rows []*row 162 | } 163 | 164 | func (t *table) columnIndex(name string) int { 165 | for n, nname := range t.colname { 166 | if name == nname { 167 | return n 168 | } 169 | } 170 | return -1 171 | } 172 | 173 | type row struct { 174 | cols []any // must be same size as its table colname + coltype 175 | } 176 | 177 | type memToucher interface { 178 | //
touchMem reads & writes some memory, to help find data races. 179 | touchMem() 180 | } 181 | 182 | type fakeConn struct { 183 | db *fakeDB // where to return ourselves to 184 | 185 | currTx *fakeTx 186 | 187 | // Every operation writes to line to enable the race detector 188 | // check for data races. 189 | line int64 190 | 191 | // Stats for tests: 192 | mu sync.Mutex 193 | stmtsMade int 194 | stmtsClosed int 195 | numPrepare int 196 | 197 | // bad connection tests; see isBad() 198 | bad bool 199 | stickyBad bool 200 | 201 | skipDirtySession bool // tests that use Conn should set this to true. 202 | 203 | // dirtySession tests ResetSession, true if a query has executed 204 | // until ResetSession is called. 205 | dirtySession bool 206 | 207 | // The waiter is called by PrepareContext before each statement is prepared. May be used in place of the "WAIT" 208 | // directive. 209 | waiter func(context.Context) 210 | } 211 | 212 | func (c *fakeConn) touchMem() { 213 | c.line++ 214 | } 215 | 216 | func (c *fakeConn) incrStat(v *int) { 217 | c.mu.Lock() 218 | *v++ 219 | c.mu.Unlock() 220 | } 221 | 222 | type fakeTx struct { 223 | c *fakeConn 224 | } 225 | 226 | type boundCol struct { 227 | Column string 228 | Placeholder string 229 | Ordinal int 230 | } 231 | 232 | type fakeStmt struct { 233 | memToucher 234 | c *fakeConn 235 | q string // just for debugging 236 | 237 | cmd string 238 | table string 239 | panic string 240 | wait time.Duration 241 | 242 | next *fakeStmt // used for returning multiple results. 243 | 244 | closed bool 245 | 246 | colName []string // used by CREATE, INSERT, SELECT (selected columns) 247 | colType []string // used by CREATE 248 | colValue []any // used by INSERT (mix of strings and "?" for bound params) 249 | placeholders int // used by INSERT/SELECT: number of ?
params 250 | 251 | whereCol []boundCol // used by SELECT (all placeholders) 252 | 253 | placeholderConverter []driver.ValueConverter // used by INSERT 254 | } 255 | 256 | var fdriver driver.Driver = &fakeDriver{} 257 | 258 | func init() { 259 | sql.Register("test", fdriver) 260 | } 261 | 262 | func contains(list []string, y string) bool { 263 | for _, x := range list { 264 | if x == y { 265 | return true 266 | } 267 | } 268 | return false 269 | } 270 | 271 | // hook to simulate connection failures 272 | var hookOpenErr struct { 273 | sync.Mutex 274 | fn func() error 275 | } 276 | 277 | func setHookOpenErr(fn func() error) { 278 | hookOpenErr.Lock() 279 | defer hookOpenErr.Unlock() 280 | hookOpenErr.fn = fn 281 | } 282 | 283 | // Open supports dsn forms: 284 | // 285 | // dbname 286 | // dbname;opts (only currently supported option is `badConn`, 287 | // which makes isBad report a bad connection on every other call, 288 | // e.g. from Begin, ResetSession or IsValid) 289 | func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { 290 | hookOpenErr.Lock() 291 | fn := hookOpenErr.fn 292 | hookOpenErr.Unlock() 293 | if fn != nil { 294 | if err := fn(); err != nil { 295 | return nil, err 296 | } 297 | } 298 | parts := strings.Split(dsn, ";") 299 | if len(parts) < 1 { // NOTE(review): unreachable; strings.Split with a non-empty separator always returns at least one element 300 | return nil, errors.New("fakedb: no database name") 301 | } 302 | name := parts[0] 303 | 304 | db := d.getDB(name) 305 | 306 | d.mu.Lock() 307 | d.openCount++ 308 | d.mu.Unlock() 309 | conn := &fakeConn{db: db} 310 | 311 | if len(parts) >= 2 && parts[1] == "badConn" { 312 | conn.bad = true 313 | } 314 | if d.waitCh != nil { 315 | d.waitingCh <- struct{}{} 316 | <-d.waitCh 317 | d.waitCh = nil 318 | d.waitingCh = nil 319 | } 320 | return conn, nil 321 | } 322 | 323 | func (d *fakeDriver) getDB(name string) *fakeDB { 324 | d.mu.Lock() 325 | defer d.mu.Unlock() 326 | if d.dbs == nil { 327 | d.dbs = make(map[string]*fakeDB) 328 | } 329 | db, ok := d.dbs[name] 330 | if !ok { 331 | db = &fakeDB{name: name} 332 | d.dbs[name] = db 333 | } 334 | return db 335
| } 336 | 337 | func (db *fakeDB) wipe() { 338 | db.mu.Lock() 339 | defer db.mu.Unlock() 340 | db.tables = nil 341 | } 342 | 343 | func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { 344 | db.mu.Lock() 345 | defer db.mu.Unlock() 346 | if db.tables == nil { 347 | db.tables = make(map[string]*table) 348 | } 349 | if _, exist := db.tables[name]; exist { 350 | return fmt.Errorf("fakedb: table %q already exists", name) 351 | } 352 | if len(columnNames) != len(columnTypes) { 353 | return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d", 354 | name, len(columnNames), len(columnTypes)) 355 | } 356 | db.tables[name] = &table{colname: columnNames, coltype: columnTypes} 357 | return nil 358 | } 359 | 360 | // must be called with db.mu lock held 361 | func (db *fakeDB) table(table string) (*table, bool) { 362 | if db.tables == nil { 363 | return nil, false 364 | } 365 | t, ok := db.tables[table] 366 | return t, ok 367 | } 368 | 369 | func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { 370 | db.mu.Lock() 371 | defer db.mu.Unlock() 372 | t, ok := db.table(table) 373 | if !ok { 374 | return 375 | } 376 | for n, cname := range t.colname { 377 | if cname == column { 378 | return t.coltype[n], true 379 | } 380 | } 381 | return "", false 382 | } 383 | // isBad reports whether c should act as a broken connection: always when stickyBad, and on every other call when bad (the toggle lives on the shared fakeDB). 384 | func (c *fakeConn) isBad() bool { 385 | if c.stickyBad { 386 | return true 387 | } else if c.bad { 388 | if c.db == nil { 389 | return false 390 | } 391 | // alternate between bad conn and not bad conn 392 | c.db.badConn = !c.db.badConn 393 | return c.db.badConn 394 | } else { 395 | return false 396 | } 397 | } 398 | 399 | func (c *fakeConn) isDirtyAndMark() bool { 400 | if c.skipDirtySession { 401 | return false 402 | } 403 | if c.currTx != nil { 404 | c.dirtySession = true 405 | return false 406 | } 407 | if c.dirtySession { 408 | return true 409 | } 410 | c.dirtySession = true 411 | return false 412 | } 413 | 414 | func (c *fakeConn) Begin()
(driver.Tx, error) { 415 | if c.isBad() { 416 | return nil, fakeError{Message: "Begin: bad conn", Wrapped: driver.ErrBadConn} 417 | } 418 | if c.currTx != nil { 419 | return nil, errors.New("fakedb: already in a transaction") 420 | } 421 | c.touchMem() 422 | c.currTx = &fakeTx{c: c} 423 | return c.currTx, nil 424 | } 425 | 426 | var hookPostCloseConn struct { 427 | sync.Mutex 428 | fn func(*fakeConn, error) 429 | } 430 | 431 | func setHookpostCloseConn(fn func(*fakeConn, error)) { 432 | hookPostCloseConn.Lock() 433 | defer hookPostCloseConn.Unlock() 434 | hookPostCloseConn.fn = fn 435 | } 436 | 437 | var testStrictClose *testing.T 438 | 439 | // setStrictFakeConnClose sets the test on which Errorf is called when 440 | // fakeConn.Close fails. If t is nil, the check is disabled. 441 | func setStrictFakeConnClose(t *testing.T) { 442 | testStrictClose = t 443 | } 444 | 445 | func (c *fakeConn) ResetSession(ctx context.Context) error { 446 | c.dirtySession = false 447 | c.currTx = nil 448 | if c.isBad() { 449 | return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn} 450 | } 451 | return nil 452 | } 453 | 454 | var _ driver.Validator = (*fakeConn)(nil) 455 | 456 | func (c *fakeConn) IsValid() bool { 457 | return !c.isBad() 458 | } 459 | 460 | func (c *fakeConn) Close() (err error) { 461 | drv := fdriver.(*fakeDriver) 462 | defer func() { 463 | if err != nil && testStrictClose != nil { 464 | testStrictClose.Errorf("failed to close a test fakeConn: %v", err) 465 | } 466 | hookPostCloseConn.Lock() 467 | fn := hookPostCloseConn.fn 468 | hookPostCloseConn.Unlock() 469 | if fn != nil { 470 | fn(c, err) 471 | } 472 | if err == nil { 473 | drv.mu.Lock() 474 | drv.closeCount++ 475 | drv.mu.Unlock() 476 | } 477 | }() 478 | c.touchMem() 479 | if c.currTx != nil { 480 | return errors.New("fakedb: can't close fakeConn; in a Transaction") 481 | } 482 | if c.db == nil { 483 | return errors.New("fakedb: can't close fakeConn; already closed") 484 | } 485 | if c.stmtsMade > c.stmtsClosed { 486 | return
errors.New("fakedb: can't close; dangling statement(s)") 487 | } 488 | c.db = nil 489 | return nil 490 | } 491 | // checkSubsetTypes returns an error for any arg whose value is outside the driver.Value subset, unless allowAny is set. 492 | func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { 493 | for _, arg := range args { 494 | switch arg.Value.(type) { 495 | case int64, float64, bool, nil, []byte, string, time.Time: 496 | default: 497 | if !allowAny { 498 | return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) 499 | } 500 | } 501 | } 502 | return nil 503 | } 504 | 505 | func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { 506 | // Ensure that ExecContext is called if available. 507 | panic("ExecContext was not called.") 508 | } 509 | 510 | func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 511 | // This is an optional interface, but it's implemented here 512 | // just to check that all the args are of the proper types. 513 | // ErrSkip is returned so the caller acts as if we didn't 514 | // implement this at all. 515 | err := checkSubsetTypes(c.db.allowAny, args) 516 | if err != nil { 517 | return nil, err 518 | } 519 | return nil, driver.ErrSkip 520 | } 521 | 522 | func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { 523 | // Ensure that QueryContext is called if available. 524 | panic("QueryContext was not called.") 525 | } 526 | 527 | func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 528 | // This is an optional interface, but it's implemented here 529 | // just to check that all the args are of the proper types. 530 | // ErrSkip is returned so the caller acts as if we didn't 531 | // implement this at all.
532 | err := checkSubsetTypes(c.db.allowAny, args) 533 | if err != nil { 534 | return nil, err 535 | } 536 | return nil, driver.ErrSkip 537 | } 538 | 539 | func errf(msg string, args ...any) error { 540 | return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) 541 | } 542 | 543 | // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? 544 | // (note that where columns must always contain ? marks, 545 | // just a limitation for fakedb) 546 | func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 547 | if len(parts) != 3 { 548 | stmt.Close() 549 | return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) 550 | } 551 | stmt.table = parts[0] 552 | 553 | stmt.colName = strings.Split(parts[1], ",") 554 | for n, colspec := range strings.Split(parts[2], ",") { 555 | if colspec == "" { 556 | continue 557 | } 558 | nameVal := strings.Split(colspec, "=") 559 | if len(nameVal) != 2 { 560 | stmt.Close() 561 | return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 562 | } 563 | column, value := nameVal[0], nameVal[1] 564 | _, ok := c.db.columnType(stmt.table, column) 565 | if !ok { 566 | stmt.Close() 567 | return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) 568 | } 569 | if !strings.HasPrefix(value, "?") { 570 | stmt.Close() 571 | return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", 572 | stmt.table, column) 573 | } 574 | stmt.placeholders++ 575 | stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) 576 | } 577 | return stmt, nil 578 | } 579 | 580 | // parts are table|col=type,col2=type2 581 | func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { 582 | if len(parts) != 2 { 583 | stmt.Close() 584 | return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) 585 | } 586 | stmt.table
= parts[0] 587 | for n, colspec := range strings.Split(parts[1], ",") { 588 | nameType := strings.Split(colspec, "=") 589 | if len(nameType) != 2 { 590 | stmt.Close() 591 | return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 592 | } 593 | stmt.colName = append(stmt.colName, nameType[0]) 594 | stmt.colType = append(stmt.colType, nameType[1]) 595 | } 596 | return stmt, nil 597 | } 598 | 599 | // parts are table|col=?,col2=val 600 | func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) { 601 | if len(parts) != 2 { 602 | stmt.Close() 603 | return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) 604 | } 605 | stmt.table = parts[0] 606 | for n, colspec := range strings.Split(parts[1], ",") { 607 | nameVal := strings.Split(colspec, "=") 608 | if len(nameVal) != 2 { 609 | stmt.Close() 610 | return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) 611 | } 612 | column, value := nameVal[0], nameVal[1] 613 | ctype, ok := c.db.columnType(stmt.table, column) 614 | if !ok { 615 | stmt.Close() 616 | return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) 617 | } 618 | stmt.colName = append(stmt.colName, column) 619 | 620 | if !strings.HasPrefix(value, "?") { 621 | var subsetVal any 622 | // Convert to driver subset type 623 | switch ctype { 624 | case "string": 625 | subsetVal = []byte(value) 626 | case "blob": 627 | subsetVal = []byte(value) 628 | case "int32": 629 | i, err := strconv.Atoi(value) 630 | if err != nil { 631 | stmt.Close() 632 | return nil, errf("invalid conversion to int32 from %q", value) 633 | } 634 | subsetVal = int64(i) // int64 is a subset type, but not int32 635 | case "table": // For testing cursor reads.
636 | c.skipDirtySession = true 637 | vparts := strings.Split(value, "!") 638 | // value is "srctable!col1!col2...": the rows of a sub-SELECT over srctable become this column's value. 639 | substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ","))) 640 | if err != nil { stmt.Close() // balance stmtsMade, as every other error path here does 641 | return nil, err 642 | } 643 | cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{}) 644 | substmt.Close() 645 | if err != nil { stmt.Close() // balance stmtsMade, as every other error path here does 646 | return nil, err 647 | } 648 | subsetVal = cursor 649 | default: 650 | stmt.Close() 651 | return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) 652 | } 653 | stmt.colValue = append(stmt.colValue, subsetVal) 654 | } else { 655 | stmt.placeholders++ 656 | stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) 657 | stmt.colValue = append(stmt.colValue, value) 658 | } 659 | } 660 | return stmt, nil 661 | } 662 | 663 | // hook to simulate broken connections 664 | var hookPrepareBadConn func() bool 665 | 666 | func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { 667 | panic("use PrepareContext") 668 | } 669 | // PrepareContext parses query, which may hold several ';'-separated statements linked through next, and returns the first one. 670 | func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 671 | c.numPrepare++ 672 | if c.db == nil { 673 | panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) 674 | } 675 | 676 | if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { 677 | return nil, fakeError{Message: "Prepare: Sticky Bad", Wrapped: driver.ErrBadConn} 678 | } 679 | 680 | c.touchMem() 681 | var firstStmt, prev *fakeStmt 682 | for _, query := range strings.Split(query, ";") { 683 | parts := strings.Split(query, "|") 684 | if len(parts) < 1 { 685 | return nil, errf("empty query") 686 | } 687 | stmt := &fakeStmt{q: query, c: c, memToucher: c} 688 | if firstStmt == nil { 689 | firstStmt = stmt 690 | } 691 | if len(parts) >= 3 { 692 | switch parts[0] { 693 | case "PANIC": 694 | stmt.panic = parts[1] 695 | parts = parts[2:] 696 | case "WAIT": 697 | wait, err :=
time.ParseDuration(parts[1]) 698 | if err != nil { 699 | return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) 700 | } 701 | parts = parts[2:] 702 | stmt.wait = wait 703 | } 704 | } 705 | cmd := parts[0] 706 | stmt.cmd = cmd 707 | parts = parts[1:] 708 | 709 | if c.waiter != nil { 710 | c.waiter(ctx) 711 | if err := ctx.Err(); err != nil { 712 | return nil, err 713 | } 714 | } 715 | 716 | if stmt.wait > 0 { 717 | wait := time.NewTimer(stmt.wait) 718 | select { 719 | case <-wait.C: 720 | case <-ctx.Done(): 721 | wait.Stop() 722 | return nil, ctx.Err() 723 | } 724 | } 725 | 726 | c.incrStat(&c.stmtsMade) 727 | var err error 728 | switch cmd { 729 | case "WIPE": 730 | // Nothing 731 | case "USE_RAWBYTES": 732 | c.db.useRawBytes.Store(true) 733 | case "SELECT": 734 | stmt, err = c.prepareSelect(stmt, parts) 735 | case "CREATE": 736 | stmt, err = c.prepareCreate(stmt, parts) 737 | case "INSERT": 738 | stmt, err = c.prepareInsert(ctx, stmt, parts) 739 | case "NOSERT": 740 | // Do all the prep-work like for an INSERT but don't actually insert the row. 741 | // Used for some of the concurrent tests.
742 | stmt, err = c.prepareInsert(ctx, stmt, parts) 743 | default: 744 | stmt.Close() 745 | return nil, errf("unsupported command type %q", cmd) 746 | } 747 | if err != nil { 748 | return nil, err 749 | } 750 | if prev != nil { 751 | prev.next = stmt 752 | } 753 | prev = stmt 754 | } 755 | return firstStmt, nil 756 | } 757 | // ColumnConverter returns the converter recorded for placeholder idx, or driver.DefaultParameterConverter when none were recorded. 758 | func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { 759 | if s.panic == "ColumnConverter" { 760 | panic(s.panic) 761 | } 762 | if len(s.placeholderConverter) == 0 { 763 | return driver.DefaultParameterConverter 764 | } 765 | return s.placeholderConverter[idx] 766 | } 767 | // Close marks s and any chained statements closed, counting each one once in stmtsClosed. 768 | func (s *fakeStmt) Close() error { 769 | if s.panic == "Close" { 770 | panic(s.panic) 771 | } 772 | if s.c == nil { 773 | panic("nil conn in fakeStmt.Close") 774 | } 775 | if s.c.db == nil { 776 | panic("in fakeStmt.Close, conn's db is nil (already closed)") 777 | } 778 | s.touchMem() 779 | if !s.closed { 780 | s.c.incrStat(&s.c.stmtsClosed) 781 | s.closed = true 782 | } 783 | if s.next != nil { 784 | s.next.Close() 785 | } 786 | return nil 787 | } 788 | 789 | var errClosed = errors.New("fakedb: statement has been closed") 790 | 791 | // hook to simulate broken connections 792 | var hookExecBadConn func() bool 793 | 794 | func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { 795 | panic("Using ExecContext") 796 | } 797 | 798 | var errFakeConnSessionDirty = errors.New("fakedb: session is dirty") 799 | // ExecContext runs a WIPE, USE_RAWBYTES, CREATE, INSERT or NOSERT statement against the in-memory database. 800 | func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 801 | if s.panic == "Exec" { 802 | panic(s.panic) 803 | } 804 | if s.closed { 805 | return nil, errClosed 806 | } 807 | 808 | if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { 809 | return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn} 810 | } 811 | if s.c.isDirtyAndMark() { 812 | return nil, errFakeConnSessionDirty 813 | } 814 | 815 | err := checkSubsetTypes(s.c.db.allowAny, args) 816 |
if err != nil { 817 | return nil, err 818 | } 819 | s.touchMem() 820 | 821 | if s.wait > 0 { 822 | time.Sleep(s.wait) 823 | } 824 | 825 | select { 826 | default: 827 | case <-ctx.Done(): 828 | return nil, ctx.Err() 829 | } 830 | 831 | db := s.c.db 832 | switch s.cmd { 833 | case "WIPE": 834 | db.wipe() 835 | return driver.ResultNoRows, nil 836 | case "USE_RAWBYTES": 837 | s.c.db.useRawBytes.Store(true) 838 | return driver.ResultNoRows, nil 839 | case "CREATE": 840 | if err := db.createTable(s.table, s.colName, s.colType); err != nil { 841 | return nil, err 842 | } 843 | return driver.ResultNoRows, nil 844 | case "INSERT": 845 | return s.execInsert(args, true) 846 | case "NOSERT": 847 | // Do all the prep-work like for an INSERT but don't actually insert the row. 848 | // Used for some of the concurrent tests. 849 | return s.execInsert(args, false) 850 | } 851 | return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd) 852 | } 853 | 854 | // execInsert executes an INSERT-style statement. When doInsert is true, add the row to the table. 855 | // When doInsert is false do prep-work and error checking, but don't 856 | // actually add the row to the table.
857 | func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { 858 | db := s.c.db 859 | if len(args) != s.placeholders { 860 | panic("error in pkg db; should only get here if size is correct") 861 | } 862 | db.mu.Lock() 863 | t, ok := db.table(s.table) 864 | db.mu.Unlock() 865 | if !ok { 866 | return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 867 | } 868 | 869 | t.mu.Lock() 870 | defer t.mu.Unlock() 871 | 872 | var cols []any 873 | if doInsert { 874 | cols = make([]any, len(t.colname)) 875 | } 876 | argPos := 0 877 | for n, colname := range s.colName { 878 | colidx := t.columnIndex(colname) 879 | if colidx == -1 { 880 | return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) 881 | } 882 | var val any 883 | if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { 884 | if strvalue == "?" { 885 | val = args[argPos].Value 886 | } else { 887 | // Assign value from argument placeholder name. 
888 | for _, a := range args { 889 | if a.Name == strvalue[1:] { 890 | val = a.Value 891 | break 892 | } 893 | } 894 | } 895 | argPos++ 896 | } else { 897 | val = s.colValue[n] 898 | } 899 | if doInsert { 900 | cols[colidx] = val 901 | } 902 | } 903 | 904 | if doInsert { 905 | t.rows = append(t.rows, &row{cols: cols}) 906 | } 907 | return driver.RowsAffected(1), nil 908 | } 909 | 910 | // hook to simulate broken connections 911 | var hookQueryBadConn func() bool 912 | 913 | func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { 914 | panic("Use QueryContext") 915 | } 916 | 917 | func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 918 | if s.panic == "Query" { 919 | panic(s.panic) 920 | } 921 | if s.closed { 922 | return nil, errClosed 923 | } 924 | 925 | if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { 926 | return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn} 927 | } 928 | if s.c.isDirtyAndMark() { 929 | return nil, errFakeConnSessionDirty 930 | } 931 | 932 | err := checkSubsetTypes(s.c.db.allowAny, args) 933 | if err != nil { 934 | return nil, err 935 | } 936 | 937 | s.touchMem() 938 | db := s.c.db 939 | if len(args) != s.placeholders { 940 | panic("error in pkg db; should only get here if size is correct") 941 | } 942 | 943 | setMRows := make([][]*row, 0, 1) 944 | setColumns := make([][]string, 0, 1) 945 | setColType := make([][]string, 0, 1) 946 | 947 | for { 948 | db.mu.Lock() 949 | t, ok := db.table(s.table) 950 | db.mu.Unlock() 951 | if !ok { 952 | return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) 953 | } 954 | 955 | if s.table == "magicquery" { 956 | if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { 957 | if args[0].Value == "sleep" { 958 | time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) 959 | } 960 | } 961 | } 962 | if s.table == "tx_status" && s.colName[0] == 
"tx_status" { 963 | txStatus := "autocommit" 964 | if s.c.currTx != nil { 965 | txStatus = "transaction" 966 | } 967 | cursor := &rowsCursor{ 968 | db: s.c.db, 969 | parentMem: s.c, 970 | posRow: -1, 971 | rows: [][]*row{ 972 | { 973 | { 974 | cols: []any{ 975 | txStatus, 976 | }, 977 | }, 978 | }, 979 | }, 980 | cols: [][]string{ 981 | { 982 | "tx_status", 983 | }, 984 | }, 985 | colType: [][]string{ 986 | { 987 | "string", 988 | }, 989 | }, 990 | errPos: -1, 991 | } 992 | return cursor, nil 993 | } 994 | 995 | t.mu.Lock() 996 | 997 | colIdx := make(map[string]int) // select column name -> column index in table 998 | for _, name := range s.colName { 999 | idx := t.columnIndex(name) 1000 | if idx == -1 { 1001 | t.mu.Unlock() 1002 | return nil, fmt.Errorf("fakedb: unknown column name %q", name) 1003 | } 1004 | colIdx[name] = idx 1005 | } 1006 | 1007 | mrows := []*row{} 1008 | rows: 1009 | for _, trow := range t.rows { 1010 | // Process the where clause, skipping non-match rows. This is lazy 1011 | // and just uses fmt.Sprintf("%v") to test equality. Good enough 1012 | // for test code. 1013 | for _, wcol := range s.whereCol { 1014 | idx := t.columnIndex(wcol.Column) 1015 | if idx == -1 { 1016 | t.mu.Unlock() 1017 | return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol) 1018 | } 1019 | tcol := trow.cols[idx] 1020 | if bs, ok := tcol.([]byte); ok { 1021 | // lazy hack to avoid sprintf %v on a []byte 1022 | tcol = string(bs) 1023 | } 1024 | var argValue any 1025 | if wcol.Placeholder == "?" { 1026 | argValue = args[wcol.Ordinal-1].Value 1027 | } else { 1028 | // Assign arg value from placeholder name. 
1029 | for _, a := range args { 1030 | if a.Name == wcol.Placeholder[1:] { 1031 | argValue = a.Value 1032 | break 1033 | } 1034 | } 1035 | } 1036 | if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { 1037 | continue rows 1038 | } 1039 | } 1040 | mrow := &row{cols: make([]any, len(s.colName))} 1041 | for seli, name := range s.colName { 1042 | mrow.cols[seli] = trow.cols[colIdx[name]] 1043 | } 1044 | mrows = append(mrows, mrow) 1045 | } 1046 | 1047 | var colType []string 1048 | for _, column := range s.colName { 1049 | colType = append(colType, t.coltype[t.columnIndex(column)]) 1050 | } 1051 | 1052 | t.mu.Unlock() 1053 | 1054 | setMRows = append(setMRows, mrows) 1055 | setColumns = append(setColumns, s.colName) 1056 | setColType = append(setColType, colType) 1057 | 1058 | if s.next == nil { 1059 | break 1060 | } 1061 | s = s.next 1062 | } 1063 | 1064 | cursor := &rowsCursor{ 1065 | db: s.c.db, 1066 | parentMem: s.c, 1067 | posRow: -1, 1068 | rows: setMRows, 1069 | cols: setColumns, 1070 | colType: setColType, 1071 | errPos: -1, 1072 | } 1073 | return cursor, nil 1074 | } 1075 | 1076 | func (s *fakeStmt) NumInput() int { 1077 | if s.panic == "NumInput" { 1078 | panic(s.panic) 1079 | } 1080 | return s.placeholders 1081 | } 1082 | 1083 | // hook to simulate broken connections 1084 | var hookCommitBadConn func() bool 1085 | 1086 | func (tx *fakeTx) Commit() error { 1087 | tx.c.currTx = nil 1088 | if hookCommitBadConn != nil && hookCommitBadConn() { 1089 | return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn} 1090 | } 1091 | tx.c.touchMem() 1092 | return nil 1093 | } 1094 | 1095 | // hook to simulate broken connections 1096 | var hookRollbackBadConn func() bool 1097 | 1098 | func (tx *fakeTx) Rollback() error { 1099 | tx.c.currTx = nil 1100 | if hookRollbackBadConn != nil && hookRollbackBadConn() { 1101 | return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn} 1102 | } 1103 | tx.c.touchMem() 1104 | return nil 1105 
| } 1106 | 1107 | type rowsCursor struct { 1108 | db *fakeDB 1109 | parentMem memToucher 1110 | cols [][]string 1111 | colType [][]string 1112 | posSet int 1113 | posRow int 1114 | rows [][]*row 1115 | closed bool 1116 | 1117 | // errPos and err are for making Next return early with error. 1118 | errPos int 1119 | err error 1120 | 1121 | // a clone of slices to give out to clients, indexed by the 1122 | // original slice's first byte address. we clone them 1123 | // just so we're able to corrupt them on close. 1124 | bytesClone map[*byte][]byte 1125 | 1126 | // Every operation writes to line to enable the race detector 1127 | // check for data races. 1128 | // This is separate from the fakeConn.line to allow for drivers that 1129 | // can start multiple queries on the same transaction at the same time. 1130 | line int64 1131 | 1132 | // closeErr is returned when rowsCursor.Close 1133 | closeErr error 1134 | } 1135 | 1136 | func (rc *rowsCursor) touchMem() { 1137 | rc.parentMem.touchMem() 1138 | rc.line++ 1139 | } 1140 | 1141 | func (rc *rowsCursor) Close() error { 1142 | rc.touchMem() 1143 | rc.parentMem.touchMem() 1144 | rc.closed = true 1145 | return rc.closeErr 1146 | } 1147 | 1148 | func (rc *rowsCursor) Columns() []string { 1149 | return rc.cols[rc.posSet] 1150 | } 1151 | 1152 | func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { 1153 | return colTypeToReflectType(rc.colType[rc.posSet][index]) 1154 | } 1155 | 1156 | var rowsCursorNextHook func(dest []driver.Value) error 1157 | 1158 | func (rc *rowsCursor) Next(dest []driver.Value) error { 1159 | if rowsCursorNextHook != nil { 1160 | return rowsCursorNextHook(dest) 1161 | } 1162 | 1163 | if rc.closed { 1164 | return errors.New("fakedb: cursor is closed") 1165 | } 1166 | rc.touchMem() 1167 | rc.posRow++ 1168 | if rc.posRow == rc.errPos { 1169 | return rc.err 1170 | } 1171 | if rc.posRow >= len(rc.rows[rc.posSet]) { 1172 | return io.EOF // per interface spec 1173 | } 1174 | for i, v := range 
rc.rows[rc.posSet][rc.posRow].cols { 1175 | // TODO(bradfitz): convert to subset types? naah, I 1176 | // think the subset types should only be input to 1177 | // driver, but the sql package should be able to handle 1178 | // a wider range of types coming out of drivers. all 1179 | // for ease of drivers, and to prevent drivers from 1180 | // messing up conversions or doing them differently. 1181 | dest[i] = v 1182 | 1183 | if bs, ok := v.([]byte); ok && !rc.db.useRawBytes.Load() { 1184 | if rc.bytesClone == nil { 1185 | rc.bytesClone = make(map[*byte][]byte) 1186 | } 1187 | clone, ok := rc.bytesClone[&bs[0]] 1188 | if !ok { 1189 | clone = make([]byte, len(bs)) 1190 | copy(clone, bs) 1191 | rc.bytesClone[&bs[0]] = clone 1192 | } 1193 | dest[i] = clone 1194 | } 1195 | } 1196 | return nil 1197 | } 1198 | 1199 | func (rc *rowsCursor) HasNextResultSet() bool { 1200 | rc.touchMem() 1201 | return rc.posSet < len(rc.rows)-1 1202 | } 1203 | 1204 | func (rc *rowsCursor) NextResultSet() error { 1205 | rc.touchMem() 1206 | if rc.HasNextResultSet() { 1207 | rc.posSet++ 1208 | rc.posRow = -1 1209 | return nil 1210 | } 1211 | return io.EOF // Per interface spec. 1212 | } 1213 | 1214 | // fakeDriverString is like driver.String, but indirects pointers like 1215 | // DefaultValueConverter. 1216 | // 1217 | // This could be surprising behavior to retroactively apply to 1218 | // driver.String now that Go1 is out, but this is convenient for 1219 | // our TestPointerParamsAndScans. 
1220 | type fakeDriverString struct{} 1221 | 1222 | func (fakeDriverString) ConvertValue(v any) (driver.Value, error) { 1223 | switch c := v.(type) { 1224 | case string, []byte: 1225 | return v, nil 1226 | case *string: 1227 | if c == nil { 1228 | return nil, nil 1229 | } 1230 | return *c, nil 1231 | } 1232 | return fmt.Sprintf("%v", v), nil 1233 | } 1234 | 1235 | type anyTypeConverter struct{} 1236 | 1237 | func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) { 1238 | return v, nil 1239 | } 1240 | 1241 | func converterForType(typ string) driver.ValueConverter { 1242 | switch typ { 1243 | case "bool": 1244 | return driver.Bool 1245 | case "nullbool": 1246 | return driver.Null{Converter: driver.Bool} 1247 | case "byte", "int16": 1248 | return driver.NotNull{Converter: driver.DefaultParameterConverter} 1249 | case "int32": 1250 | return driver.Int32 1251 | case "nullbyte", "nullint32", "nullint16": 1252 | return driver.Null{Converter: driver.DefaultParameterConverter} 1253 | case "string": 1254 | return driver.NotNull{Converter: fakeDriverString{}} 1255 | case "nullstring": 1256 | return driver.Null{Converter: fakeDriverString{}} 1257 | case "int64": 1258 | // TODO(coopernurse): add type-specific converter 1259 | return driver.NotNull{Converter: driver.DefaultParameterConverter} 1260 | case "nullint64": 1261 | // TODO(coopernurse): add type-specific converter 1262 | return driver.Null{Converter: driver.DefaultParameterConverter} 1263 | case "float64": 1264 | // TODO(coopernurse): add type-specific converter 1265 | return driver.NotNull{Converter: driver.DefaultParameterConverter} 1266 | case "nullfloat64": 1267 | // TODO(coopernurse): add type-specific converter 1268 | return driver.Null{Converter: driver.DefaultParameterConverter} 1269 | case "datetime": 1270 | return driver.NotNull{Converter: driver.DefaultParameterConverter} 1271 | case "nulldatetime": 1272 | return driver.Null{Converter: driver.DefaultParameterConverter} 1273 | case "any": 1274 | 
return anyTypeConverter{} 1275 | } 1276 | panic("invalid fakedb column type of " + typ) 1277 | } 1278 | 1279 | func colTypeToReflectType(typ string) reflect.Type { 1280 | switch typ { 1281 | case "bool": 1282 | return reflect.TypeOf(false) 1283 | case "nullbool": 1284 | return reflect.TypeOf(sql.NullBool{}) 1285 | case "int16": 1286 | return reflect.TypeOf(int16(0)) 1287 | case "nullint16": 1288 | return reflect.TypeOf(sql.NullInt16{}) 1289 | case "int32": 1290 | return reflect.TypeOf(int32(0)) 1291 | case "nullint32": 1292 | return reflect.TypeOf(sql.NullInt32{}) 1293 | case "string": 1294 | return reflect.TypeOf("") 1295 | case "nullstring": 1296 | return reflect.TypeOf(sql.NullString{}) 1297 | case "int64": 1298 | return reflect.TypeOf(int64(0)) 1299 | case "nullint64": 1300 | return reflect.TypeOf(sql.NullInt64{}) 1301 | case "float64": 1302 | return reflect.TypeOf(float64(0)) 1303 | case "nullfloat64": 1304 | return reflect.TypeOf(sql.NullFloat64{}) 1305 | case "datetime": 1306 | return reflect.TypeOf(time.Time{}) 1307 | case "any": 1308 | return reflect.TypeOf(new(any)).Elem() 1309 | } 1310 | panic("invalid fakedb column type of " + typ) 1311 | } 1312 | --------------------------------------------------------------------------------