├── .github └── workflows │ ├── lint.yml │ └── test.yml ├── .gitignore ├── LICENCE ├── README.md ├── bench_test.go ├── cursor.go ├── debug.go ├── debug_test.go ├── exec.go ├── exec_test.go ├── go.mod ├── go.sum ├── helpers.go ├── helpers_test.go ├── interfaces.go ├── mapper.go ├── mapper_struct.go ├── mapper_test.go ├── mod.go ├── pgxscan └── pgxscan.go ├── row.go ├── source.go └── stdscan └── stdscan.go /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: golangci-lint 2 | on: 3 | push: 4 | pull_request: 5 | permissions: 6 | contents: read 7 | # Optional: allow read access to pull request. Use with `only-new-issues` option. 8 | # pull-requests: read 9 | jobs: 10 | golangci: 11 | name: lint 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/setup-go@v3 15 | with: 16 | go-version: 1.18 17 | - uses: actions/checkout@v3 18 | - name: golangci-lint 19 | uses: golangci/golangci-lint-action@v3 20 | with: 21 | version: latest 22 | args: --verbose 23 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | on: [push, pull_request] 3 | jobs: 4 | test: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - name: Checkout Repo 8 | uses: actions/checkout@v3 9 | 10 | - name: Setup Go 11 | uses: actions/setup-go@v3 12 | with: 13 | go-version: 1.18 14 | 15 | - name: Install Dependencies 16 | run: go mod download 17 | 18 | - name: Run tests 19 | run: go test -race -covermode atomic -coverprofile=covprofile.out ./... 
20 | 21 | - name: Send coverage 22 | uses: shogo82148/actions-goveralls@v1 23 | with: 24 | path-to-profile: covprofile.out 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # profile files 2 | *.prof 3 | # Test executable compiled when generating profiles 4 | scan.test 5 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Stephen Afam-Osemene 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Scan 2 | 3 | [![Test Status](https://github.com/stephenafamo/scan/actions/workflows/test.yml/badge.svg)](https://github.com/stephenafamo/scan/actions/workflows/test.yml) 4 | ![GitHub go.mod Go version](https://img.shields.io/github/go-mod/go-version/stephenafamo/scan) 5 | [![Go Reference](https://pkg.go.dev/badge/github.com/stephenafamo/scan.svg)](https://pkg.go.dev/github.com/stephenafamo/scan) 6 | [![Go Report Card](https://goreportcard.com/badge/github.com/stephenafamo/scan)](https://goreportcard.com/report/github.com/stephenafamo/scan) 7 | ![GitHub tag (latest SemVer)](https://img.shields.io/github/v/tag/stephenafamo/scan) 8 | [![Coverage Status](https://coveralls.io/repos/github/stephenafamo/scan/badge.svg)](https://coveralls.io/github/stephenafamo/scan) 9 | 10 | Scan provides the ability to scan the rows returned by `database/sql` (and other database drivers) directly into any defined structure. 11 | 12 | ## Reference 13 | 14 | - Standard library scan package. For use with `database/sql`. [Link](https://pkg.go.dev/github.com/stephenafamo/scan/stdscan) 15 | - PGX library scan package. For use with `github.com/jackc/pgx/v5`. [Link](https://pkg.go.dev/github.com/stephenafamo/scan/pgxscan) 16 | - Base scan package. For use with any implementation of [`scan.Queryer`](https://pkg.go.dev/github.com/stephenafamo/scan#Queryer).
[Link](https://pkg.go.dev/github.com/stephenafamo/scan) 17 | 18 | ## Using with `database/sql` 19 | 20 | ```go 21 | package main 22 | 23 | import ( 24 | "context" 25 | "database/sql" 26 | 27 | "github.com/stephenafamo/scan" 28 | "github.com/stephenafamo/scan/stdscan" 29 | ) 30 | 31 | type User struct { 32 | ID int 33 | Name string 34 | Email string 35 | Age int 36 | } 37 | 38 | func main() { 39 | ctx := context.Background() 40 | db, _ := sql.Open("postgres", "example-connection-url") 41 | 42 | // count: 5 43 | count, _ := stdscan.One(ctx, db, scan.SingleColumnMapper[int], "SELECT COUNT(*) FROM users") 44 | // []int{1, 2, 3, 4, 5} 45 | userIDs, _ := stdscan.All(ctx, db, scan.SingleColumnMapper[int], "SELECT id FROM users") 46 | // []User{...} 47 | users, _ := stdscan.All(ctx, db, scan.StructMapper[User](), `SELECT id, name, email, age FROM users`) 48 | } 49 | 50 | func collectIDandEmail(_ context.Context, c cols) any { 51 | return func(v *Values) (int, string, error) { 52 | return Value[int](v, "id"), Value[string](v, "email"), nil 53 | } 54 | } 55 | ``` 56 | 57 | And many more!! 58 | 59 | ## Using with [pgx](https://github.com/jackc/pgx) 60 | 61 | ```go 62 | ctx := context.Background() 63 | db, _ := pgxpool.New(ctx, "example-connection-url") 64 | 65 | // []User{...} 66 | users, _ := pgxscan.All(ctx, db, scan.StructMapper[User](), `SELECT id, name, email, age FROM users`) 67 | ``` 68 | 69 | ## Using with other DB packages 70 | 71 | Instead of `github.com/stephenafamo/scan/stdscan`, use the base package `github.com/stephenafamo/scan`, which only needs an executor that implements the right interface. 72 | Both `stdscan` and `pgxscan` are based on this. 73 | 74 | ## How it works 75 | 76 | ### Scanning Functions 77 | 78 | #### `One()` 79 | 80 | Use `One()` to scan and return **a single** row.
81 | 82 | ```go 83 | // User{...} 84 | user, _ := stdscan.One(ctx, db, scan.StructMapper[User](), `SELECT id, name, email, age FROM users`) 85 | ``` 86 | 87 | #### `All()` 88 | 89 | Use `All()` to scan and return **all** rows. 90 | 91 | ```go 92 | // []User{...} 93 | users, _ := stdscan.All(ctx, db, scan.StructMapper[User](), `SELECT id, name, email, age FROM users`) 94 | ``` 95 | 96 | #### `Each()` 97 | 98 | Use `Each()` to iterate over the rows of a query using range. 99 | 100 | It works with the [range-over-func](https://tip.golang.org/blog/range-functions) syntax (Go 1.23+). 101 | 102 | ```go 103 | // user is a User{...} on each iteration 104 | for user, err := range scan.Each(ctx, db, scan.StructMapper[User](), `SELECT id, name, email, age FROM users`) { 105 | if err != nil { 106 | return err 107 | } 108 | // do something with user 109 | } 110 | ``` 111 | 112 | #### `Cursor()` 113 | 114 | Use `Cursor()` to scan each row on demand. This is useful when retrieving large results. 115 | 116 | ```go 117 | c, _ := stdscan.Cursor(ctx, db, scan.StructMapper[User](), `SELECT id, name, email, age FROM users`) 118 | defer c.Close() 119 | 120 | for c.Next() { 121 | // User{...} 122 | user, err := c.Get() 123 | } 124 | ``` 125 | 126 | ### Mappers 127 | 128 | Each of these functions takes a `Mapper` to indicate how each row should be scanned. 129 | The `Mapper` has the signature: 130 | 131 | ```go 132 | type Mapper[T any] func(context.Context, cols) (before BeforeFunc, after func(any) (T, error)) 133 | 134 | type BeforeFunc = func(*Row) (link any, err error) 135 | ``` 136 | 137 | A mapper returns 2 functions 138 | 139 | - **before**: This is called before scanning the row. The mapper should schedule scans using the `ScheduleScan` or `ScheduleScanx` methods of the `Row`. The return value of the **before** function is passed to the **after** function after scanning values from the database. 140 | - **after**: This is called after the scan operation.
The mapper should then convert the link value back to the desired concrete type. 141 | 142 | There are some built-in mappers for common cases: 143 | 144 | #### `ColumnMapper[T any](name string)` 145 | 146 | Maps the value of a single column to the given type. The name of the column must be specified. 147 | 148 | ```go 149 | // []string{"user1@example.com", "user2@example.com", "user3@example.com", ...} 150 | emails, _ := stdscan.All(ctx, db, scan.ColumnMapper[string]("email"), `SELECT id, name, email FROM users`) 151 | ``` 152 | 153 | #### `SingleColumnMapper[T any]` 154 | 155 | For queries that return only one column. Since only one column is returned, there is no need to specify the column name. 156 | This is why it returns an error if the query returns more than one column. 157 | 158 | ```go 159 | // []string{"user1@example.com", "user2@example.com", "user3@example.com", ...} 160 | emails, _ := stdscan.All(ctx, db, scan.SingleColumnMapper[string], `SELECT email FROM users`) 161 | ``` 162 | 163 | #### `SliceMapper[T any]` 164 | 165 | Maps a row into a slice of values `[]T`. Unless all the columns are of the same type, it will likely be used to map the row to `[]any`. 166 | 167 | ```go 168 | // [][]any{ 169 | // []any{1, "John Doe", "john@example.com"}, 170 | // []any{2, "Jane Doe", "jane@example.com"}, 171 | // ... 172 | // } 173 | users, _ := stdscan.All(ctx, db, scan.SliceMapper[any], `SELECT id, name, email FROM users`) 174 | ``` 175 | 176 | #### `MapMapper[T any]` 177 | 178 | Maps a row into a map of values `map[string]T`. The keys of the map are the column names. Unless all columns are of the same type, it will likely be used to map to `map[string]any`. 179 | 180 | ```go 181 | // []map[string]any{ 182 | // map[string]any{"id": 1, "name": "John Doe", "email": "john@example.com"}, 183 | // map[string]any{"id": 2, "name": "Jane Doe", "email": "jane@example.com"}, 184 | // ...
185 | // } 186 | users, _ := stdscan.All(ctx, db, scan.MapMapper[any], `SELECT id, name, email FROM users`) 187 | ``` 188 | 189 | #### `StructMapper[T any](...MappingOption)` 190 | 191 | This is the most advanced mapper. Scans column values into the fields of the struct. 192 | 193 | ```go 194 | type User struct { 195 | ID int `db:"id"` 196 | Name string `db:"name"` 197 | Email string `db:"email"` 198 | Age int `db:"age"` 199 | } 200 | 201 | // []User{...} 202 | users, _ := stdscan.All(ctx, db, scan.StructMapper[User](), `SELECT id, name, email, age FROM users`) 203 | ``` 204 | 205 | The default behaviour of `StructMapper` is often good enough. For more advanced use cases, some options can be passed to `StructMapper`. 206 | 207 | - **WithStructTagPrefix**: Use this when every column from the database has a prefix. 208 | 209 | ```go 210 | users, _ := stdscan.All(ctx, db, scan.StructMapper[User](scan.WithStructTagPrefix("user-")), 211 | `SELECT id AS "user-id", name AS "user-name" FROM users`, 212 | ) 213 | ``` 214 | 215 | - **WithRowValidator**: If the `StructMapper` has a row validator, the values will be sent to it before scanning. If the row is invalid (i.e. it returns false), then scanning is skipped and the zero value of the row-type is returned. 216 | 217 | - **WithTypeConverter**: If the `StructMapper` has a type converter, all fields of the struct are converted to a new type using the `ConvertType` method. After scanning, the values are restored into the struct using the `OriginalValue` method. 218 | 219 | #### `CustomStructMapper[T any](MapperSource, ...MappingSourceOption)` 220 | 221 | Uses a custom struct mapping source which should have been created with [NewStructMapperSource](https://pkg.go.dev/github.com/stephenafamo/scan#NewStructMapperSource). 222 | 223 | This works the same way as `StructMapper`, but instead of using the default mapping source, it uses a custom one.
224 | 225 | In the example below, we want to use `scan` as the struct tag key instead of `db`. 226 | 227 | ```go 228 | type User struct { 229 | ID int `scan:"id"` 230 | Name string `scan:"name"` 231 | Email string `scan:"email"` 232 | Age int `scan:"age"` 233 | } 234 | 235 | src, _ := scan.NewStructMapperSource(scan.WithStructTagKey("scan")) 236 | // []User{...} 237 | users, _ := stdscan.All(ctx, db, scan.CustomStructMapper[User](src), `SELECT id, name, email, age FROM users`) 238 | ``` 239 | 240 | These are the options that can be passed to `NewStructMapperSource`: 241 | 242 | - **WithStructTagKey**: Change the struct tag used to map columns to struct fields. Default: **db** 243 | - **WithColumnSeparator**: Change the separator for column names of nested struct fields. Default: **.** 244 | - **WithFieldNameMapper**: Change how struct field names are mapped to column names when there are no struct tags. Default: **snake_case** (i.e. `CreatedAt` is mapped to `created_at`). 245 | - **WithScannableTypes**: Pass a list of interfaces; a field whose type implements any of them can be scanned directly by the executor. Such a field is treated as a single value, and its nested fields are not inspected. Default: `*sql.Scanner`. 246 | -------------------------------------------------------------------------------- /bench_test.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "os" 8 | "os/signal" 9 | "syscall" 10 | "testing" 11 | "time" 12 | 13 | _ "github.com/stephenafamo/fakedb" 14 | ) 15 | 16 | // Benchmark Command: go test -v -run=XXX -cpu 1,2,4 -benchmem -bench=.
-memprofile mem.prof 17 | 18 | var ( 19 | db *sql.DB 20 | dataSize = 100 21 | ) 22 | 23 | func TestMain(m *testing.M) { 24 | var err error 25 | 26 | ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) 27 | defer cancel() 28 | 29 | db, err = sql.Open("test", "bench") 30 | if err != nil { 31 | panic(fmt.Errorf("Error opening testdb %w", err)) 32 | } 33 | defer db.Close() 34 | 35 | err = prepareData(ctx) 36 | if err != nil { 37 | panic(err) 38 | } 39 | 40 | exitVal := m.Run() 41 | 42 | os.Exit(exitVal) 43 | } 44 | 45 | func BenchmarkScanAll(b *testing.B) { 46 | b.StopTimer() 47 | ctx := context.Background() 48 | 49 | for i := 0; i < b.N; i++ { 50 | b.StopTimer() 51 | rows, err := db.Query("SELECT|user||") 52 | if err != nil { 53 | panic(err) 54 | } 55 | b.StartTimer() 56 | if _, err := AllFromRows(ctx, StructMapper[Userss](), rows); err != nil { 57 | panic(err) 58 | } 59 | rows.Close() 60 | } 61 | } 62 | 63 | func BenchmarkScanOne(b *testing.B) { 64 | b.StopTimer() 65 | ctx := context.Background() 66 | 67 | for i := 0; i < b.N; i++ { 68 | b.StopTimer() 69 | rows, err := db.Query("SELECT|user||") 70 | if err != nil { 71 | panic(err) 72 | } 73 | b.StartTimer() 74 | if _, err := OneFromRows(ctx, StructMapper[Userss](), rows); err != nil { 75 | panic(err) 76 | } 77 | rows.Close() 78 | } 79 | } 80 | 81 | func prepareData(ctx context.Context) error { 82 | create := "CREATE|user|id=int64,username=string,password=string" 83 | create += ",email=string,mobile_phone=string,company=string,avatar_url=string" 84 | create += ",role=int16,last_online_at=int64,create_at=datetime,update_at=datetime" 85 | 86 | if _, err := db.ExecContext(ctx, create); err != nil { 87 | return err 88 | } 89 | 90 | for i := 0; i < dataSize; i++ { 91 | userName := fmt.Sprintf("user%d", i+1) 92 | password := fmt.Sprintf("password%d", i+1) 93 | email := fmt.Sprintf("user%d@sqlscan.com", i+1) 94 | mobilePhone := fmt.Sprintf("%d", 10000*(i+1)) 95 | company := 
fmt.Sprintf("company%d", i+1) 96 | avatarURL := fmt.Sprintf("http://sqlscan.com/avatar/%d", i+1) 97 | role := i % 3 98 | lastOnlineAt := time.Now().Unix() + int64(i) 99 | createAt := time.Now().UTC() 100 | updateAt := time.Now().UTC() 101 | _, err := db.Exec(`INSERT|user|id=?,username=?,password=?,email=?,mobile_phone=?,company=?,avatar_url=?,role=?,last_online_at=?,create_at=?,update_at=?`, 102 | i, userName, password, email, mobilePhone, company, avatarURL, role, lastOnlineAt, createAt, updateAt) 103 | if err != nil { 104 | return err 105 | } 106 | } 107 | 108 | return nil 109 | } 110 | 111 | type Userss struct { 112 | ID int `db:"id"` 113 | UserName string `db:"username"` 114 | Password string `db:"password"` 115 | Email string `db:"email"` 116 | MobilePhone string `db:"mobile_phone"` 117 | Company string `db:"company"` 118 | AvatarURL string `db:"avatar_url"` 119 | Role int `db:"role"` 120 | LastOnlineAt int64 `db:"last_online_at"` 121 | CreateAt time.Time `db:"create_at"` 122 | UpdateAt time.Time `db:"update_at"` 123 | } 124 | 125 | func TestPrepare(t *testing.T) { 126 | cnt := 0 127 | rows, err := db.Query("SELECT|user||") 128 | if err != nil { 129 | t.Fatal(err) 130 | } 131 | defer rows.Close() 132 | for rows.Next() { 133 | cnt++ 134 | } 135 | if cnt != dataSize { 136 | t.Error("wrong cnt") 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /cursor.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | type ICursor[T any] interface { 4 | // Close the underlying rows 5 | Close() error 6 | // Prepare the next row 7 | Next() bool 8 | // Get the values of the current row 9 | Get() (T, error) 10 | // Return any error with the underlying rows 11 | Err() error 12 | } 13 | 14 | type cursor[T any] struct { 15 | v *Row 16 | before func(*Row) (any, error) 17 | after func(any) (T, error) 18 | } 19 | 20 | func (c *cursor[T]) Close() error { 21 | return c.v.r.Close() 
22 | } 23 | 24 | func (c *cursor[T]) Err() error { 25 | return c.v.r.Err() 26 | } 27 | 28 | func (c *cursor[T]) Next() bool { 29 | return c.v.r.Next() 30 | } 31 | 32 | func (c *cursor[T]) Get() (T, error) { 33 | return scanOneRow(c.v, c.before, c.after) 34 | } 35 | -------------------------------------------------------------------------------- /debug.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "os" 8 | ) 9 | 10 | func Debug(q Queryer, w io.Writer) Queryer { 11 | if w == nil { 12 | w = os.Stdout 13 | } 14 | 15 | return debugQueryer{w: w, q: q} 16 | } 17 | 18 | type debugQueryer struct { 19 | w io.Writer 20 | q Queryer 21 | } 22 | 23 | func (d debugQueryer) QueryContext(ctx context.Context, query string, args ...any) (Rows, error) { 24 | fmt.Fprintln(d.w, query) 25 | fmt.Fprintln(d.w, []any(args)) 26 | return d.q.QueryContext(ctx, query, args...) 27 | } 28 | -------------------------------------------------------------------------------- /debug_test.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "os" 8 | "strings" 9 | "testing" 10 | ) 11 | 12 | type NoopQueryer struct{} 13 | 14 | func (n NoopQueryer) QueryContext(ctx context.Context, query string, args ...any) (Rows, error) { 15 | return nil, nil 16 | } 17 | 18 | func TestDebugQueryerDefaultWriter(t *testing.T) { 19 | d, ok := Debug(NoopQueryer{}, nil).(debugQueryer) 20 | if !ok { 21 | t.Fatal("DebugQueryer does not return an instance of debugQueryer") 22 | } 23 | 24 | debugFile, ok := d.w.(*os.File) 25 | if !ok { 26 | t.Fatal("writer for debugQueryer is not an *os.File") 27 | } 28 | 29 | if debugFile != os.Stdout { 30 | t.Fatal("writer for debugQueryer is not os.Stdout") 31 | } 32 | } 33 | 34 | func TestDebugQueryer(t *testing.T) { 35 | dest := &bytes.Buffer{} 36 | exec := Debug(NoopQueryer{}, 
dest) 37 | 38 | sql := "A QUERY" 39 | args := []any{"arg1", "arg2", "arg3"} 40 | 41 | _, err := exec.QueryContext(context.Background(), sql, args...) 42 | if err != nil { 43 | t.Fatal("error running QueryContext") 44 | } 45 | 46 | debugsql, debugArgsStr, found := strings.Cut(dest.String(), "\n") 47 | if !found { 48 | t.Fatalf("arg delimiter not found in\n%s", dest.String()) 49 | } 50 | 51 | if strings.TrimSpace(debugsql) != sql { 52 | t.Fatalf("wrong debug sql.\nExpected: %s\nGot: %s", sql, strings.TrimSpace(debugsql)) 53 | } 54 | 55 | debugArgsStr = strings.TrimSpace(debugArgsStr) 56 | debugArgsStr = strings.Trim(debugArgsStr, "[]") 57 | 58 | var debugArgs []string //nolint:prealloc 59 | for _, s := range strings.Fields(debugArgsStr) { 60 | s := strings.TrimSpace(s) 61 | if s == "" { 62 | continue 63 | } 64 | 65 | debugArgs = append(debugArgs, s) 66 | } 67 | 68 | if len(debugArgs) != len(args) { 69 | t.Fatalf("wrong length of debug args.\nExpected: %d\nGot: %d\n\n%s", len(args), len(debugArgs), debugArgs) 70 | } 71 | 72 | for i := range args { 73 | argStr := strings.TrimSpace(fmt.Sprint(args[i])) 74 | debugStr := strings.TrimSpace(debugArgs[i]) 75 | if argStr != debugStr { 76 | t.Fatalf("wrong debug arg %d.\nExpected: %s\nGot: %s", i, argStr, debugStr) 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /exec.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | ) 7 | 8 | // One scans a single row from the query and maps it to T using a [Queryer] 9 | func One[T any](ctx context.Context, exec Queryer, m Mapper[T], query string, args ...any) (T, error) { 10 | var t T 11 | 12 | rows, err := exec.QueryContext(ctx, query, args...) 
13 | if err != nil { 14 | return t, err 15 | } 16 | defer rows.Close() 17 | 18 | return OneFromRows(ctx, m, rows) 19 | } 20 | 21 | // OneFromRows scans a single row from the given [Rows] result and maps it to T using a [Queryer] 22 | func OneFromRows[T any](ctx context.Context, m Mapper[T], rows Rows) (T, error) { 23 | var t T 24 | 25 | allowUnknown, _ := ctx.Value(CtxKeyAllowUnknownColumns).(bool) 26 | v, err := wrapRows(rows, allowUnknown) 27 | if err != nil { 28 | return t, err 29 | } 30 | 31 | before, after := m(ctx, v.columnsCopy()) 32 | 33 | if !rows.Next() { 34 | if err = rows.Err(); err != nil { 35 | return t, err 36 | } 37 | return t, sql.ErrNoRows 38 | } 39 | 40 | t, err = scanOneRow(v, before, after) 41 | if err != nil { 42 | return t, err 43 | } 44 | 45 | return t, rows.Err() 46 | } 47 | 48 | // All scans all rows from the query and returns a slice []T of all rows using a [Queryer] 49 | func All[T any](ctx context.Context, exec Queryer, m Mapper[T], query string, args ...any) ([]T, error) { 50 | rows, err := exec.QueryContext(ctx, query, args...) 
51 | if err != nil { 52 | return nil, err 53 | } 54 | defer rows.Close() 55 | 56 | return AllFromRows(ctx, m, rows) 57 | } 58 | 59 | // AllFromRows scans all rows from the given [Rows] and returns a slice []T of all rows using a [Queryer] 60 | func AllFromRows[T any](ctx context.Context, m Mapper[T], rows Rows) ([]T, error) { 61 | allowUnknown, _ := ctx.Value(CtxKeyAllowUnknownColumns).(bool) 62 | v, err := wrapRows(rows, allowUnknown) 63 | if err != nil { 64 | return nil, err 65 | } 66 | 67 | before, after := m(ctx, v.columnsCopy()) 68 | 69 | var results []T 70 | for rows.Next() { 71 | one, err := scanOneRow(v, before, after) 72 | if err != nil { 73 | return nil, err 74 | } 75 | 76 | results = append(results, one) 77 | } 78 | 79 | return results, rows.Err() 80 | } 81 | 82 | // Cursor runs a query and returns a cursor that works similar to *sql.Rows 83 | func Cursor[T any](ctx context.Context, exec Queryer, m Mapper[T], query string, args ...any) (ICursor[T], error) { 84 | rows, err := exec.QueryContext(ctx, query, args...) 85 | if err != nil { 86 | return nil, err 87 | } 88 | 89 | return CursorFromRows(ctx, m, rows) 90 | } 91 | 92 | // Each returns a function that can be used to iterate over the rows of a query 93 | // this function works with range-over-func so it is possible to do 94 | // 95 | // for val, err := range scan.Each(ctx, exec, m, query, args...) { 96 | // if err != nil { 97 | // return err 98 | // } 99 | // // do something with val 100 | // } 101 | func Each[T any](ctx context.Context, exec Queryer, m Mapper[T], query string, args ...any) func(func(T, error) bool) { 102 | rows, err := exec.QueryContext(ctx, query, args...) 
103 | if err != nil { 104 | return func(yield func(T, error) bool) { yield(*new(T), err) } 105 | } 106 | 107 | allowUnknown, _ := ctx.Value(CtxKeyAllowUnknownColumns).(bool) 108 | wrapped, err := wrapRows(rows, allowUnknown) 109 | if err != nil { 110 | rows.Close() 111 | return func(yield func(T, error) bool) { yield(*new(T), err) } 112 | } 113 | 114 | before, after := m(ctx, wrapped.columnsCopy()) 115 | 116 | return func(yield func(T, error) bool) { 117 | defer rows.Close() 118 | 119 | for rows.Next() { 120 | val, err := scanOneRow(wrapped, before, after) 121 | if !yield(val, err) { 122 | return 123 | } 124 | } 125 | } 126 | } 127 | 128 | // CursorFromRows returns a cursor from [Rows] that works similar to *sql.Rows 129 | func CursorFromRows[T any](ctx context.Context, m Mapper[T], rows Rows) (ICursor[T], error) { 130 | allowUnknown, _ := ctx.Value(CtxKeyAllowUnknownColumns).(bool) 131 | v, err := wrapRows(rows, allowUnknown) 132 | if err != nil { 133 | return nil, err 134 | } 135 | 136 | before, after := m(ctx, v.columnsCopy()) 137 | 138 | return &cursor[T]{ 139 | v: v, 140 | before: before, 141 | after: after, 142 | }, nil 143 | } 144 | 145 | func scanOneRow[T any](v *Row, before func(*Row) (any, error), after func(any) (T, error)) (T, error) { 146 | val, err := before(v) 147 | if err != nil { 148 | var t T 149 | return t, err 150 | } 151 | 152 | err = v.scanCurrentRow() 153 | if err != nil { 154 | var t T 155 | return t, err 156 | } 157 | 158 | return after(val) 159 | } 160 | -------------------------------------------------------------------------------- /exec_test.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "strings" 8 | "testing" 9 | "time" 10 | 11 | "github.com/google/go-cmp/cmp" 12 | _ "github.com/stephenafamo/fakedb" 13 | ) 14 | 15 | func createDB(tb testing.TB, cols [][2]string) (*sql.DB, func()) { 16 | tb.Helper() 17 | db, err := 
sql.Open("test", "foo") 18 | if err != nil { 19 | tb.Fatalf("Error opening testdb %v", err) 20 | } 21 | 22 | first := true 23 | b := &strings.Builder{} 24 | fmt.Fprintf(b, "CREATE|%s|", tb.Name()) 25 | 26 | for _, def := range cols { 27 | if !first { 28 | b.WriteString(",") 29 | } else { 30 | first = false 31 | } 32 | 33 | fmt.Fprintf(b, "%s=%s", def[0], def[1]) 34 | } 35 | 36 | exec(tb, db, b.String()) 37 | return db, func() { 38 | exec(tb, db, fmt.Sprintf("DROP|%s", tb.Name())) 39 | } 40 | } 41 | 42 | func exec(tb testing.TB, exec *sql.DB, query string, args ...interface{}) sql.Result { 43 | tb.Helper() 44 | result, err := exec.ExecContext(context.Background(), query, args...) 45 | if err != nil { 46 | tb.Fatalf("Exec of %q: %v", query, err) 47 | } 48 | 49 | return result 50 | } 51 | 52 | func insert(tb testing.TB, ex *sql.DB, cols []string, vals ...[]any) { 53 | tb.Helper() 54 | query := fmt.Sprintf("INSERT|%s|%s=?", tb.Name(), strings.Join(cols, "=?,")) 55 | for _, val := range vals { 56 | exec(tb, ex, query, val...) 57 | } 58 | } 59 | 60 | func createQuery(tb testing.TB, cols []string) string { 61 | tb.Helper() 62 | return fmt.Sprintf("SELECT|%s|%s|", tb.Name(), strings.Join(cols, ",")) 63 | } 64 | 65 | type queryCase[T any] struct { 66 | ctx context.Context 67 | columns strstr 68 | rows rows 69 | query []string // columns to select 70 | mapper Mapper[T] 71 | expectOne T 72 | expectAll []T 73 | expectedErr error 74 | } 75 | 76 | func testQuery[T any](t *testing.T, name string, tc queryCase[T]) { 77 | t.Helper() 78 | 79 | t.Run(name, func(t *testing.T) { 80 | ctx := tc.ctx 81 | if ctx == nil { 82 | ctx = context.Background() 83 | } 84 | 85 | ex, clean := createDB(t, tc.columns) 86 | defer clean() 87 | 88 | insert(t, ex, colSliceFromMap(tc.columns), tc.rows...) 
89 | query := createQuery(t, tc.query) 90 | 91 | queryer := stdQ{ex} 92 | t.Run("one", func(t *testing.T) { 93 | one, err := One(ctx, queryer, tc.mapper, query) 94 | if diff := diffErr(tc.expectedErr, err); diff != "" { 95 | t.Fatalf("diff: %s", diff) 96 | } 97 | 98 | if diff := cmp.Diff(tc.expectOne, one); diff != "" { 99 | t.Fatalf("diff: %s", diff) 100 | } 101 | }) 102 | 103 | t.Run("all", func(t *testing.T) { 104 | all, err := All(ctx, queryer, tc.mapper, query) 105 | if diff := diffErr(tc.expectedErr, err); diff != "" { 106 | t.Fatalf("diff: %s", diff) 107 | } 108 | 109 | if diff := cmp.Diff(tc.expectAll, all); diff != "" { 110 | t.Fatalf("diff: %s", diff) 111 | } 112 | }) 113 | 114 | t.Run("each", func(t *testing.T) { 115 | var i int 116 | Each(ctx, queryer, tc.mapper, query)(func(val T, err error) bool { 117 | if diff := diffErr(tc.expectedErr, err); diff != "" { 118 | t.Fatalf("diff: %s", diff) 119 | } 120 | 121 | if err != nil { 122 | return false 123 | } 124 | 125 | if diff := cmp.Diff(tc.expectAll[i], val); diff != "" { 126 | t.Fatalf("diff: %s", diff) 127 | } 128 | i++ 129 | return true 130 | }) 131 | if i != len(tc.expectAll) { 132 | t.Fatalf("Should have %d rows, but each only scanned %d", len(tc.expectAll), i) 133 | } 134 | }) 135 | 136 | t.Run("cursor", func(t *testing.T) { 137 | c, err := Cursor(ctx, queryer, tc.mapper, query) 138 | if err != nil { 139 | t.Fatalf("error getting cursor: %v", err) 140 | return 141 | } 142 | defer c.Close() 143 | 144 | var i int 145 | for c.Next() { 146 | v, err := c.Get() 147 | if diff := diffErr(tc.expectedErr, err); diff != "" { 148 | t.Fatalf("diff: %s", diff) 149 | } 150 | 151 | if err != nil { 152 | return 153 | } 154 | 155 | if diff := cmp.Diff(tc.expectAll[i], v); diff != "" { 156 | t.Fatalf("diff: %s", diff) 157 | } 158 | 159 | i++ 160 | } 161 | 162 | if i != len(tc.expectAll) { 163 | t.Fatalf("Should have %d rows, but cursor only scanned %d", len(tc.expectAll), i) 164 | } 165 | 166 | if diff := 
diffErr(tc.expectedErr, c.Err()); diff != "" {
				t.Fatalf("diff: %s", diff)
			}
		})
	})
}

// TestSingleValue runs SingleColumnMapper through testQuery (the one, all,
// each and cursor paths) for int, string and datetime columns.
func TestSingleValue(t *testing.T) {
	testQuery(t, "int", queryCase[int]{
		columns:   strstr{{"id", "int64"}},
		rows:      singleRows(1, 2, 3, 5, 8, 13, 21),
		query:     []string{"id"},
		mapper:    SingleColumnMapper[int],
		expectOne: 1,
		expectAll: []int{1, 2, 3, 5, 8, 13, 21},
	})

	testQuery(t, "string", queryCase[string]{
		columns:   strstr{{"name", "string"}},
		rows:      singleRows("first", "second", "third"),
		query:     []string{"name"},
		mapper:    SingleColumnMapper[string],
		expectOne: "first",
		expectAll: []string{"first", "second", "third"},
	})

	// Random whole-second timestamps (see randate).
	time1 := randate()
	time2 := randate()
	time3 := randate()
	testQuery(t, "datetime", queryCase[time.Time]{
		columns:   strstr{{"when", "datetime"}},
		rows:      singleRows(time1, time2, time3),
		query:     []string{"when"},
		mapper:    SingleColumnMapper[time.Time],
		expectOne: time1,
		expectAll: []time.Time{time1, time2, time3},
	})
}

// TestColumnValue runs ColumnMapper through testQuery. The second case asks
// for a column that is not part of the query and expects a MappingError whose
// metadata is "unknown_column".
func TestColumnValue(t *testing.T) {
	testQuery(t, "int", queryCase[int]{
		columns:   strstr{{"id", "int64"}},
		rows:      singleRows(1, 2, 3, 5, 8, 13, 21),
		query:     []string{"id"},
		mapper:    ColumnMapper[int]("id"),
		expectOne: 1,
		expectAll: []int{1, 2, 3, 5, 8, 13, 21},
	})

	testQuery(t, "unknown", queryCase[int]{
		columns:     strstr{{"id", "int64"}},
		rows:        singleRows(1, 2, 3, 5, 8, 13, 21),
		query:       []string{"id"},
		mapper:      ColumnMapper[int]("unknown_column"),
		expectedErr: createError(nil, "unknown_column"),
	})
}

// TestMap runs MapMapper[any] through testQuery.
func TestMap(t *testing.T) {
	// NOTE(review): ids are expected as int64, presumably because that is what
	// the fake driver yields for "int64" columns — confirm.
	user1 := map[string]any{"id": int64(1), "name": "foo"}
	user2 := map[string]any{"id": int64(2), "name": "bar"}

	testQuery(t, "user", queryCase[map[string]any]{
		columns:   strstr{{"id", "int64"}, {"name", "string"}},
		rows:      rows{[]any{1, "foo"}, []any{2, "bar"}},
		query:     []string{"id", "name"},
		mapper:    MapMapper[any],
		expectOne: user1,
		expectAll: []map[string]any{user1, user2},
	})
}

// TestStruct runs the struct mappers through testQuery: a flat struct, the
// same struct with a TypeConverter, a query with columns that have no field
// (expecting a "no destination" MappingError), a pointer type with the userMod
// mapper mod, and a struct with an embedded User and *Timestamps.
func TestStruct(t *testing.T) {
	user1 := User{ID: 1, Name: "foo"}
	user2 := User{ID: 2, Name: "bar"}

	testQuery(t, "user", queryCase[User]{
		columns:   strstr{{"id", "int64"}, {"name", "string"}},
		rows:      rows{[]any{1, "foo"}, []any{2, "bar"}},
		query:     []string{"id", "name"},
		mapper:    StructMapper[User](),
		expectOne: user1,
		expectAll: []User{user1, user2},
	})

	testQuery(t, "user with type converter", queryCase[User]{
		columns:   strstr{{"id", "int64"}, {"name", "string"}},
		rows:      rows{[]any{1, "foo"}, []any{2, "bar"}},
		query:     []string{"id", "name"},
		mapper:    StructMapper[User](WithTypeConverter(typeConverter{})),
		expectOne: user1,
		expectAll: []User{user1, user2},
	})

	// The first column without a destination field is the one reported.
	testQuery(t, "user with unknown column", queryCase[User]{
		columns: strstr{
			{"id", "int64"},
			{"name", "string"},
			{"missing1", "int64"},
			{"missing2", "string"},
		},
		rows:        rows{[]any{1, "foo", "100", "foobar"}, []any{2, "bar", "200", "barfoo"}},
		query:       []string{"id", "name", "missing1", "missing2"},
		mapper:      StructMapper[User](),
		expectedErr: createError(nil, "no destination", "missing1"),
	})

	// userMod multiplies ID by 200 and appends " modified" to Name.
	testQuery(t, "userWithMod", queryCase[*User]{
		columns:   strstr{{"id", "int64"}, {"name", "string"}},
		rows:      rows{[]any{1, "foo"}, []any{2, "bar"}},
		query:     []string{"id", "name"},
		mapper:    CustomStructMapper[*User](defaultStructMapper, WithMapperMods(userMod)),
		expectOne: &User{ID: 200, Name: "foo modified"},
		expectAll: []*User{
			{ID: 200, Name: "foo modified"},
			{ID: 400, Name: "bar modified"},
		},
	})

	createdAt1 := randate()
	createdAt2 := randate()
	updatedAt1 := randate()
	updatedAt2 := randate()
	timestamp1 := &Timestamps{CreatedAt: createdAt1, UpdatedAt: updatedAt1}
	timestamp2 := &Timestamps{CreatedAt: createdAt2, UpdatedAt: updatedAt2}

	// Columns of the embedded User and *Timestamps are mapped without a prefix;
	// the *Timestamps pointer is allocated by the mapper.
	testQuery(t, "userwithtimestamps", queryCase[UserWithTimestamps]{
		columns: strstr{
			{"id", "int64"},
			{"name", "string"},
			{"created_at", "datetime"},
			{"updated_at", "datetime"},
		},
		rows: rows{
			[]any{1, "foo", createdAt1, updatedAt1},
			[]any{2, "bar", createdAt2, updatedAt2},
		},
		query:     []string{"id", "name", "created_at", "updated_at"},
		mapper:    StructMapper[UserWithTimestamps](),
		expectOne: UserWithTimestamps{User: user1, Timestamps: timestamp1},
		expectAll: []UserWithTimestamps{
			{User: user1, Timestamps: timestamp1},
			{User: user2, Timestamps: timestamp2},
		},
	})
}

// TestAllowUnknownColumns checks that a column without a destination field
// fails by default, and is skipped when the context sets
// CtxKeyAllowUnknownColumns to true.
func TestAllowUnknownColumns(t *testing.T) {
	type testStruct struct {
		ID  int64
		Int int64
	}

	// fails when context does not have CtxKeyAllowUnknownColumns set to true
	testQuery(t, "unknowncolumnsnotallowed", queryCase[testStruct]{
		columns:     strstr{{"id", "int64"}, {"ignored_int", "int64"}, {"int", "int64"}},
		rows:        rows{{1, 10, 1}, {2, 20, 2}},
		query:       []string{"id", "ignored_int", "int"},
		mapper:      StructMapper[testStruct](),
		expectedErr: createError(fmt.Errorf("No destination for column ignored_int"), "no destination", "ignored_int"),
	})

	// succeeds when context has CtxKeyAllowUnknownColumns set to true
	testQuery(t, "unknowncolumnsallowed", queryCase[testStruct]{
		ctx:       context.WithValue(context.Background(), CtxKeyAllowUnknownColumns, true),
		columns:   strstr{{"id", "int64"}, {"ignored_int", "int64"}, {"int", "int64"}},
		rows:      rows{{1, 10, 1}, {2, 20, 2}},
		query:     []string{"id", "ignored_int", "int"},
		mapper:    StructMapper[testStruct](),
		expectOne: testStruct{ID: 1, Int: 1},
		expectAll: []testStruct{{ID: 1, Int: 1}, {ID: 2, Int: 2}},
	})
}

--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
module github.com/stephenafamo/scan

go 1.18

require (
	github.com/aarondl/opt v0.0.0-20221129170750-3d40c96d9bb8
	github.com/google/go-cmp v0.5.8
	github.com/jackc/pgx/v5 v5.2.0
	github.com/stephenafamo/fakedb v0.0.0-20221230081958-0b86f816ed97
)

require (
	github.com/aarondl/json v0.0.0-20221020222930-8b0db17ef1bf // indirect
	github.com/jackc/pgpassfile v1.0.0 // indirect
	github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
	golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 // indirect
	golang.org/x/text v0.3.8 // indirect
)

--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
github.com/aarondl/json v0.0.0-20221020222930-8b0db17ef1bf h1:+edM69bH/X6JpYPmJYBRLanAMe1V5yRXYU3hHUovGcE=
github.com/aarondl/json v0.0.0-20221020222930-8b0db17ef1bf/go.mod h1:FZqLhJSj2tg0ZN48GB1zvj00+ZYcHPqgsC7yzcgCq6k=
github.com/aarondl/opt v0.0.0-20221129170750-3d40c96d9bb8 h1:pAJut2Ye6sxwIS8zQvu1BhX87B+9MwUKmzjdEkwPWg4=
github.com/aarondl/opt v0.0.0-20221129170750-3d40c96d9bb8/go.mod h1:l4/5NZtYd/SIohsFhaJQQe+sPOTG22furpZ5FvcYOzk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/jackc/pgpassfile
v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= 10 | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 11 | github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= 12 | github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= 13 | github.com/jackc/pgx/v5 v5.2.0 h1:NdPpngX0Y6z6XDFKqmFQaE+bCtkqzvQIOt1wvBlAqs8= 14 | github.com/jackc/pgx/v5 v5.2.0/go.mod h1:Ptn7zmohNsWEsdxRawMzk3gaKma2obW+NWTnKa0S4nk= 15 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 16 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 17 | github.com/stephenafamo/fakedb v0.0.0-20221230081958-0b86f816ed97 h1:XItoZNmhOih06TC02jK7l3wlpZ0XT/sPQYutDcGOQjg= 18 | github.com/stephenafamo/fakedb v0.0.0-20221230081958-0b86f816ed97/go.mod h1:bM3Vmw1IakoaXocHmMIGgJFYob0vuK+CFWiJHQvz0jQ= 19 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 20 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 21 | github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= 22 | github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= 23 | golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 h1:Y/gsMcFOcR+6S6f3YeMKl5g+dZMEWqcz5Czj/GWYbkM= 24 | golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= 25 | golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= 26 | golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= 27 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 28 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 29 | gopkg.in/yaml.v3 v3.0.1 
h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 30 | -------------------------------------------------------------------------------- /helpers.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import "reflect" 4 | 5 | func typeOf[T any]() reflect.Type { 6 | return reflect.TypeOf((*T)(nil)).Elem() 7 | } 8 | -------------------------------------------------------------------------------- /helpers_test.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | "fmt" 8 | "math/rand" 9 | "reflect" 10 | "strconv" 11 | "strings" 12 | "time" 13 | 14 | "github.com/aarondl/opt" 15 | "github.com/google/go-cmp/cmp" 16 | ) 17 | 18 | type ( 19 | strstr = [][2]string 20 | rows = [][]any 21 | ) 22 | 23 | var ( 24 | now = time.Now() 25 | goodSlice = []any{ 26 | now, 27 | 100, 28 | "A string", 29 | sql.NullString{Valid: false}, 30 | "another string", 31 | []byte("interesting"), 32 | } 33 | ) 34 | 35 | type stdQ struct { 36 | *sql.DB 37 | } 38 | 39 | func (s stdQ) QueryContext(ctx context.Context, query string, args ...any) (Rows, error) { 40 | return s.DB.QueryContext(ctx, query, args...) 
41 | } 42 | 43 | type Timestamps struct { 44 | CreatedAt time.Time 45 | UpdatedAt time.Time 46 | } 47 | 48 | type PtrTimestamps struct { 49 | CreatedAt *time.Time 50 | UpdatedAt *time.Time 51 | } 52 | 53 | type User struct { 54 | ID int 55 | Name string 56 | } 57 | 58 | type PtrUser1 struct { 59 | ID *int 60 | Name string 61 | PtrTimestamps 62 | } 63 | 64 | type PtrUser2 struct { 65 | ID int 66 | Name *string 67 | *PtrTimestamps 68 | } 69 | 70 | type UserWithTimestamps struct { 71 | User 72 | *Timestamps 73 | Blog *Blog 74 | } 75 | 76 | type Blog struct { 77 | ID int 78 | User UserWithTimestamps 79 | } 80 | 81 | type Tagged struct { 82 | ID int `db:"tag_id" custom:"custom_id"` 83 | Name string `db:"tag_name" custom:"custom_name"` 84 | Email string `db:"EMAIL"` 85 | Exclude int `db:"-" custom:"-"` 86 | } 87 | 88 | type ScannableUser struct { 89 | ID int 90 | Name string 91 | } 92 | 93 | func (s ScannableUser) Scan() { 94 | } 95 | 96 | type wrapper struct { 97 | V any 98 | } 99 | 100 | // Scan implements the sql.Scanner interface. If the wrapped type implements 101 | // sql.Scanner then it will call that. 
102 | func (v *wrapper) Scan(value any) error { 103 | if scanner, ok := v.V.(sql.Scanner); ok { 104 | return scanner.Scan(value) 105 | } 106 | 107 | if err := opt.ConvertAssign(v.V, value); err != nil { 108 | return fmt.Errorf("convert assign err: %w", err) 109 | } 110 | 111 | return nil 112 | } 113 | 114 | type typeConverter struct{} 115 | 116 | func (d typeConverter) TypeToDestination(typ reflect.Type) reflect.Value { 117 | val := reflect.ValueOf(&wrapper{ 118 | V: reflect.New(typ).Interface(), 119 | }) 120 | 121 | return val 122 | } 123 | 124 | func (d typeConverter) ValueFromDestination(val reflect.Value) reflect.Value { 125 | return val.Elem().FieldByName("V").Elem().Elem() 126 | } 127 | 128 | func toPtr[T any](v T) *T { 129 | return &v 130 | } 131 | 132 | // To quickly generate column definition for tests 133 | // make it in the form {"1": 1, "2": 2} 134 | func columns(n int) []string { 135 | m := make([]string, n) 136 | for i := 0; i < n; i++ { 137 | m[i] = strconv.Itoa(i) 138 | } 139 | 140 | return m 141 | } 142 | 143 | func columnNames(names ...string) []string { 144 | return names 145 | } 146 | 147 | func colSliceFromMap(c [][2]string) []string { 148 | s := make([]string, 0, len(c)) 149 | for _, def := range c { 150 | s = append(s, def[0]) 151 | } 152 | return s 153 | } 154 | 155 | func singleRows[T any](vals ...T) rows { 156 | r := make(rows, len(vals)) 157 | for k, v := range vals { 158 | r[k] = []any{v} 159 | } 160 | 161 | return r 162 | } 163 | 164 | func mapToVals[T any](vals []any) map[string]T { 165 | m := make(map[string]T, len(vals)) 166 | for i, v := range vals { 167 | m[strconv.Itoa(i)] = v.(T) 168 | } 169 | 170 | return m 171 | } 172 | 173 | func randate() time.Time { 174 | min := time.Date(1970, 1, 0, 0, 0, 0, 0, time.UTC).Unix() 175 | max := time.Date(2070, 1, 0, 0, 0, 0, 0, time.UTC).Unix() 176 | delta := max - min 177 | 178 | sec := rand.Int63n(delta) + min 179 | return time.Unix(sec, 0) 180 | } 181 | 182 | func userMod(ctx 
context.Context, c cols) (BeforeFunc, AfterMod) { 183 | return func(v *Row) (any, error) { 184 | return nil, nil 185 | }, func(link, retrieved any) error { 186 | u, ok := retrieved.(*User) 187 | if !ok { 188 | return errors.New("wrong retrieved type") 189 | } 190 | if u == nil { 191 | return nil 192 | } 193 | u.ID *= 200 194 | u.Name += " modified" 195 | 196 | return nil 197 | } 198 | } 199 | 200 | func convertMappingError(m *MappingError) string { 201 | return strings.Join(m.meta, " ") 202 | } 203 | 204 | func diffErr(expected, got error) string { 205 | return cmp.Diff(expected, got, cmp.Transformer("convertMappingErr", convertMappingError), equateErrors()) 206 | } 207 | 208 | // equateErrors returns a Comparer option that determines errors to be equal 209 | // if errors.Is reports them to match. The AnyError error can be used to 210 | // match any non-nil error. 211 | func equateErrors() cmp.Option { 212 | return cmp.FilterValues(nonMappingErrors, cmp.Comparer(compareErrors)) 213 | } 214 | 215 | // nonMappingErrors reports whether x and y are types that implement error. 216 | // The input types are deliberately of the interface{} type rather than the 217 | // error type so that we can handle situations where the current type is an 218 | // interface{}, but the underlying concrete types both happen to implement 219 | // the error interface. 
220 | func nonMappingErrors(x, y error) bool { 221 | var me *MappingError 222 | ok1 := errors.As(x, &me) 223 | ok2 := errors.As(y, &me) 224 | return !(ok1 && ok2) 225 | } 226 | 227 | func compareErrors(xe, ye error) bool { 228 | if errors.Is(xe, ye) || errors.Is(ye, xe) { 229 | return true 230 | } 231 | 232 | return xe.Error() == ye.Error() 233 | } 234 | -------------------------------------------------------------------------------- /interfaces.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "context" 5 | "reflect" 6 | ) 7 | 8 | type contextKey string 9 | 10 | // Queryer is the main interface used in this package 11 | // it is expected to run the query and args and return a set of Rows 12 | type Queryer interface { 13 | QueryContext(ctx context.Context, query string, args ...any) (Rows, error) 14 | } 15 | 16 | // Rows is an interface that is expected to be returned as the result of a query 17 | type Rows interface { 18 | Scan(...any) error 19 | Columns() ([]string, error) 20 | Next() bool 21 | Close() error 22 | Err() error 23 | } 24 | 25 | type TypeConverter interface { 26 | // TypeToDestination is called with the expected type of the column 27 | // it is expected to return a pointer to the desired value to scan into 28 | // the returned destination is directly scanned into 29 | TypeToDestination(reflect.Type) reflect.Value 30 | 31 | // ValueFromDestination retrieves the original value from the destination 32 | // the returned value is set back to the appropriate struct field 33 | ValueFromDestination(reflect.Value) reflect.Value 34 | } 35 | 36 | // RowValidator is called with pointer to all the values from a row 37 | // to determine if the row is valid 38 | // if it is not, the zero type for that row is returned 39 | type RowValidator = func(cols []string, vals []reflect.Value) bool 40 | 41 | type StructMapperSource interface { 42 | getMapping(reflect.Type) (mapping, error) 43 | } 44 | 
-------------------------------------------------------------------------------- /mapper.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "reflect" 7 | "strconv" 8 | ) 9 | 10 | type ( 11 | cols = []string 12 | visited map[reflect.Type]int 13 | ) 14 | 15 | func (v visited) copy() visited { 16 | v2 := make(visited, len(v)) 17 | for t, c := range v { 18 | v2[t] = c 19 | } 20 | 21 | return v2 22 | } 23 | 24 | type mapinfo struct { 25 | name string 26 | position []int 27 | init [][]int 28 | isPointer bool 29 | } 30 | 31 | type mapping []mapinfo 32 | 33 | func (m mapping) cols() []string { 34 | cols := make([]string, len(m)) 35 | for i, info := range m { 36 | cols[i] = info.name 37 | } 38 | 39 | return cols 40 | } 41 | 42 | // Mapper is a function that return the mapping functions. 43 | // Any expensive operation, like reflection should be done outside the returned 44 | // function. 45 | // It is called with the columns from the query to get the mapping functions 46 | // which is then used to map every row. 
47 | // 48 | // The Mapper does not return an error itself to make it less cumbersome 49 | // It is recommended to instead return a function that returns an error 50 | // the [ErrorMapper] is provider for this 51 | type Mapper[T any] func(context.Context, cols) (before BeforeFunc, after func(any) (T, error)) 52 | 53 | // BeforeFunc is returned by a mapper and is called before a row is scanned 54 | // Scans should be scheduled with either 55 | // the [*Row.ScheduleScan] or [*Row.ScheduleScanx] methods 56 | type BeforeFunc = func(*Row) (link any, err error) 57 | 58 | // The generator function does not return an error itself to make it less cumbersome 59 | // so we return a function that only returns an error instead 60 | // This function makes it easy to return this error 61 | func ErrorMapper[T any](err error, meta ...string) (func(*Row) (any, error), func(any) (T, error)) { 62 | err = createError(err, meta...) 63 | 64 | return func(v *Row) (any, error) { 65 | return nil, err 66 | }, func(any) (T, error) { 67 | var t T 68 | return t, err 69 | } 70 | } 71 | 72 | // Returns a [MappingError] with some optional metadata 73 | func createError(err error, meta ...string) error { 74 | if me, ok := err.(*MappingError); ok && len(meta) == 0 { 75 | return me 76 | } 77 | 78 | return &MappingError{cause: err, meta: meta} 79 | } 80 | 81 | // MappingError wraps another error and holds some additional metadata 82 | type MappingError struct { 83 | meta []string // easy compare 84 | cause error 85 | } 86 | 87 | // Unwrap returns the wrapped error 88 | func (m *MappingError) Unwrap() error { 89 | return m.cause 90 | } 91 | 92 | // Error implements the error interface 93 | func (m *MappingError) Error() string { 94 | if m.cause == nil { 95 | return "" 96 | } 97 | 98 | return m.cause.Error() 99 | } 100 | 101 | // For queries that return only one column 102 | // throws an error if there is more than one column 103 | func SingleColumnMapper[T any](ctx context.Context, c cols) (before 
func(*Row) (any, error), after func(any) (T, error)) { 104 | if len(c) != 1 { 105 | err := fmt.Errorf("Expected 1 column but got %d columns", len(c)) 106 | return ErrorMapper[T](err, "wrong column count", "1", strconv.Itoa(len(c))) 107 | } 108 | 109 | return func(v *Row) (any, error) { 110 | var t T 111 | v.ScheduleScan(c[0], &t) 112 | return &t, nil 113 | }, func(v any) (T, error) { 114 | return *(v.(*T)), nil 115 | } 116 | } 117 | 118 | // Map a column by name. 119 | func ColumnMapper[T any](name string) func(ctx context.Context, c cols) (before func(*Row) (any, error), after func(any) (T, error)) { 120 | return func(ctx context.Context, c cols) (before func(*Row) (any, error), after func(any) (T, error)) { 121 | return func(v *Row) (any, error) { 122 | var t T 123 | v.ScheduleScan(name, &t) 124 | return &t, nil 125 | }, func(v any) (T, error) { 126 | return *(v.(*T)), nil 127 | } 128 | } 129 | } 130 | 131 | // Maps each row into []any in the order 132 | func SliceMapper[T any](ctx context.Context, c cols) (before func(*Row) (any, error), after func(any) ([]T, error)) { 133 | return func(v *Row) (any, error) { 134 | row := make([]T, len(c)) 135 | 136 | for index, name := range c { 137 | v.ScheduleScan(name, &row[index]) 138 | } 139 | 140 | return row, nil 141 | }, func(v any) ([]T, error) { 142 | return v.([]T), nil 143 | } 144 | } 145 | 146 | // Maps all rows into map[string]T 147 | // Most likely used with interface{} to get a map[string]interface{} 148 | func MapMapper[T any](ctx context.Context, c cols) (before func(*Row) (any, error), after func(any) (map[string]T, error)) { 149 | return func(v *Row) (any, error) { 150 | row := make([]*T, len(c)) 151 | 152 | for index, name := range c { 153 | var t T 154 | v.ScheduleScan(name, &t) 155 | row[index] = &t 156 | } 157 | 158 | return row, nil 159 | }, func(v any) (map[string]T, error) { 160 | row := make(map[string]T, len(c)) 161 | slice := v.([]*T) 162 | for index, name := range c { 163 | row[name] = 
*slice[index] 164 | } 165 | 166 | return row, nil 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /mapper_struct.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "reflect" 7 | ) 8 | 9 | // CtxKeyAllowUnknownColumns makes it possible to allow unknown columns using the context 10 | var CtxKeyAllowUnknownColumns contextKey = "allow unknown columns" 11 | 12 | // Uses reflection to create a mapping function for a struct type 13 | // using the default options 14 | func StructMapper[T any](opts ...MappingOption) Mapper[T] { 15 | return CustomStructMapper[T](defaultStructMapper, opts...) 16 | } 17 | 18 | func StructMapperColumns[T any](opts ...MappingOption) ([]string, error) { 19 | return CustomStructMapperColumns[T](defaultStructMapper, opts...) 20 | } 21 | 22 | // Uses reflection to create a mapping function for a struct type 23 | // using with custom options 24 | func CustomStructMapper[T any](src StructMapperSource, optMod ...MappingOption) Mapper[T] { 25 | opts := mappingOptions{} 26 | for _, o := range optMod { 27 | o(&opts) 28 | } 29 | 30 | mod := func(ctx context.Context, c cols) (func(*Row) (any, error), func(any) (T, error)) { 31 | return structMapperFrom[T](ctx, c, src, opts) 32 | } 33 | 34 | if len(opts.mapperMods) > 0 { 35 | mod = Mod(mod, opts.mapperMods...) 
36 | } 37 | 38 | return mod 39 | } 40 | 41 | func CustomStructMapperColumns[T any](src StructMapperSource, optMod ...MappingOption) ([]string, error) { 42 | opts := mappingOptions{} 43 | for _, o := range optMod { 44 | o(&opts) 45 | } 46 | 47 | if len(opts.mapperMods) > 0 { 48 | return nil, fmt.Errorf("Mapper mods are not supported in CustomStructMapperColumns") 49 | } 50 | 51 | typ := typeOf[T]() 52 | 53 | _, err := checks(typ) 54 | if err != nil { 55 | return nil, err 56 | } 57 | 58 | mapping, err := src.getMapping(typ) 59 | if err != nil { 60 | return nil, err 61 | } 62 | 63 | return mapping.cols(), nil 64 | } 65 | 66 | func structMapperFrom[T any](ctx context.Context, c cols, s StructMapperSource, opts mappingOptions) (func(*Row) (any, error), func(any) (T, error)) { 67 | typ := typeOf[T]() 68 | 69 | isPointer, err := checks(typ) 70 | if err != nil { 71 | return ErrorMapper[T](err) 72 | } 73 | 74 | mapping, err := s.getMapping(typ) 75 | if err != nil { 76 | return ErrorMapper[T](err) 77 | } 78 | 79 | return mapperFromMapping[T](mapping, typ, isPointer, opts)(ctx, c) 80 | } 81 | 82 | // Check if there are any errors, and returns if it is a pointer or not 83 | func checks(typ reflect.Type) (bool, error) { 84 | if typ == nil { 85 | return false, fmt.Errorf("Nil type passed to StructMapper") 86 | } 87 | 88 | var isPointer bool 89 | 90 | switch { 91 | case typ.Kind() == reflect.Struct: 92 | case typ.Kind() == reflect.Pointer: 93 | isPointer = true 94 | 95 | if typ.Elem().Kind() != reflect.Struct { 96 | return false, fmt.Errorf("Type %q is not a struct or pointer to a struct", typ.String()) 97 | } 98 | default: 99 | return false, fmt.Errorf("Type %q is not a struct or pointer to a struct", typ.String()) 100 | } 101 | 102 | return isPointer, nil 103 | } 104 | 105 | type mappingOptions struct { 106 | typeConverter TypeConverter 107 | rowValidator RowValidator 108 | mapperMods []MapperMod 109 | structTagPrefix string 110 | } 111 | 112 | // MappingeOption is a function 
type that changes how the mapper is generated 113 | type MappingOption func(*mappingOptions) 114 | 115 | // WithRowValidator sets the [RowValidator] for the struct mapper 116 | // after scanning all values in a row, they are passed to the RowValidator 117 | // if it returns false, the zero value for that row is returned 118 | func WithRowValidator(rv RowValidator) MappingOption { 119 | return func(opt *mappingOptions) { 120 | opt.rowValidator = rv 121 | } 122 | } 123 | 124 | // TypeConverter sets the [TypeConverter] for the struct mapper 125 | // it is called to modify the type of a column and get the original value back 126 | func WithTypeConverter(tc TypeConverter) MappingOption { 127 | return func(opt *mappingOptions) { 128 | opt.typeConverter = tc 129 | } 130 | } 131 | 132 | // WithStructTagPrefix should be used when every column from the database has a prefix. 133 | func WithStructTagPrefix(prefix string) MappingOption { 134 | return func(opt *mappingOptions) { 135 | opt.structTagPrefix = prefix 136 | } 137 | } 138 | 139 | // WithMapperMods accepts mods used to modify the mapper 140 | func WithMapperMods(mods ...MapperMod) MappingOption { 141 | return func(opt *mappingOptions) { 142 | opt.mapperMods = append(opt.mapperMods, mods...) 
143 | } 144 | } 145 | 146 | func mapperFromMapping[T any](m mapping, typ reflect.Type, isPointer bool, opts mappingOptions) func(context.Context, cols) (func(*Row) (any, error), func(any) (T, error)) { 147 | return func(ctx context.Context, c cols) (func(*Row) (any, error), func(any) (T, error)) { 148 | // Filter the mapping so we only ask for the available columns 149 | filtered, err := filterColumns(ctx, c, m, opts.structTagPrefix) 150 | if err != nil { 151 | return ErrorMapper[T](err) 152 | } 153 | 154 | mapper := regular[T]{ 155 | typ: typ, 156 | isPointer: isPointer, 157 | filtered: filtered, 158 | converter: opts.typeConverter, 159 | validator: opts.rowValidator, 160 | } 161 | switch { 162 | case opts.typeConverter == nil && opts.rowValidator == nil: 163 | return mapper.regular() 164 | 165 | default: 166 | return mapper.allOptions() 167 | } 168 | } 169 | } 170 | 171 | type regular[T any] struct { 172 | isPointer bool 173 | typ reflect.Type 174 | filtered mapping 175 | converter TypeConverter 176 | validator RowValidator 177 | } 178 | 179 | func (s regular[T]) regular() (func(*Row) (any, error), func(any) (T, error)) { 180 | return func(v *Row) (any, error) { 181 | var row reflect.Value 182 | if s.isPointer { 183 | row = reflect.New(s.typ.Elem()).Elem() 184 | } else { 185 | row = reflect.New(s.typ).Elem() 186 | } 187 | 188 | for _, info := range s.filtered { 189 | for _, v := range info.init { 190 | pv := row.FieldByIndex(v) 191 | if !pv.IsZero() { 192 | continue 193 | } 194 | 195 | pv.Set(reflect.New(pv.Type().Elem())) 196 | } 197 | 198 | fv := row.FieldByIndex(info.position) 199 | v.ScheduleScanx(info.name, fv.Addr()) 200 | } 201 | 202 | return row, nil 203 | }, func(v any) (T, error) { 204 | row := v.(reflect.Value) 205 | 206 | if s.isPointer { 207 | row = row.Addr() 208 | } 209 | 210 | return row.Interface().(T), nil 211 | } 212 | } 213 | 214 | func (s regular[T]) allOptions() (func(*Row) (any, error), func(any) (T, error)) { 215 | return func(v *Row) 
(any, error) { 216 | row := make([]reflect.Value, len(s.filtered)) 217 | 218 | for i, info := range s.filtered { 219 | var ft reflect.Type 220 | if s.isPointer { 221 | ft = s.typ.Elem().FieldByIndex(info.position).Type 222 | } else { 223 | ft = s.typ.FieldByIndex(info.position).Type 224 | } 225 | 226 | if s.converter != nil { 227 | row[i] = s.converter.TypeToDestination(ft) 228 | } else { 229 | row[i] = reflect.New(ft) 230 | } 231 | 232 | v.ScheduleScanx(info.name, row[i]) 233 | } 234 | 235 | return row, nil 236 | }, func(v any) (T, error) { 237 | vals := v.([]reflect.Value) 238 | 239 | if s.validator != nil && !s.validator(s.filtered.cols(), vals) { 240 | var t T 241 | return t, nil 242 | } 243 | 244 | var row reflect.Value 245 | if s.isPointer { 246 | row = reflect.New(s.typ.Elem()).Elem() 247 | } else { 248 | row = reflect.New(s.typ).Elem() 249 | } 250 | 251 | for i, info := range s.filtered { 252 | for _, v := range info.init { 253 | pv := row.FieldByIndex(v) 254 | if !pv.IsZero() { 255 | continue 256 | } 257 | 258 | pv.Set(reflect.New(pv.Type().Elem())) 259 | } 260 | 261 | var val reflect.Value 262 | if s.converter != nil { 263 | val = s.converter.ValueFromDestination(vals[i]) 264 | } else { 265 | val = vals[i].Elem() 266 | } 267 | 268 | fv := row.FieldByIndex(info.position) 269 | if info.isPointer { 270 | fv.Elem().Set(val) 271 | } else { 272 | fv.Set(val) 273 | } 274 | } 275 | 276 | if s.isPointer { 277 | row = row.Addr() 278 | } 279 | 280 | return row.Interface().(T), nil 281 | } 282 | } 283 | -------------------------------------------------------------------------------- /mapper_test.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "reflect" 7 | "strings" 8 | "testing" 9 | "time" 10 | 11 | "github.com/google/go-cmp/cmp" 12 | ) 13 | 14 | type MapperTests[T any] map[string]MapperTest[T] 15 | 16 | type MapperTest[T any] struct { 17 | row *Row 18 | scanned []any 
19 | Context map[contextKey]any 20 | Mapper Mapper[T] 21 | ExpectedVal T 22 | ExpectedBeforeError error 23 | ExpectedAfterError error 24 | } 25 | // RunMapperTests runs every case in cases through RunMapperTest, using the map key as the subtest name. 26 | func RunMapperTests[T any](t *testing.T, cases MapperTests[T]) { 27 | t.Helper() 28 | for name, tc := range cases { 29 | RunMapperTest(t, name, tc) 30 | } 31 | } 32 | // RunMapperTest calls tc.Mapper, stores tc.scanned into the destinations scheduled by its before func, then checks the before/after errors and the final value; an empty name runs on t directly instead of as a subtest. 33 | func RunMapperTest[T any](t *testing.T, name string, tc MapperTest[T]) { 34 | t.Helper() 35 | 36 | f := func(t *testing.T) { 37 | t.Helper() 38 | ctx := context.Background() 39 | for k, v := range tc.Context { 40 | ctx = context.WithValue(ctx, k, v) 41 | } 42 | 43 | tc.row.scanDestinations = make([]reflect.Value, len(tc.row.columns)) 44 | 45 | before, after := tc.Mapper(ctx, tc.row.columnsCopy()) 46 | 47 | link, err := before(tc.row) 48 | if diff := diffErr(tc.ExpectedBeforeError, err); diff != "" { 49 | t.Fatalf("diff: %s", diff) 50 | } 51 | 52 | for i, ref := range tc.row.scanDestinations { 53 | if ref == zeroValue { 54 | continue 55 | } 56 | ref.Elem().Set(reflect.ValueOf(tc.scanned[i])) 57 | } 58 | 59 | val, err := after(link) 60 | if diff := diffErr(tc.ExpectedAfterError, err); diff != "" { 61 | t.Fatalf("diff: %s", diff) 62 | } 63 | if diff := cmp.Diff(tc.ExpectedVal, val); diff != "" { 64 | t.Fatalf("diff: %s", diff) 65 | } 66 | } 67 | 68 | if name == "" { 69 | f(t) 70 | } else { 71 | t.Run(name, f) 72 | } 73 | } 74 | // CustomStructMapperTest is a MapperTest whose mapper is built by CustomStructMapper from a source created with Options. 75 | type CustomStructMapperTest[T any] struct { 76 | MapperTest[T] 77 | Options []MappingSourceOption 78 | } 79 | // RunStructMapperTest runs tc and then, unless CtxKeyAllowUnknownColumns is set in tc.Context, fails if any row column is missing from StructMapperColumns[T]. 80 | func RunStructMapperTest[T any](t *testing.T, name string, tc MapperTest[T]) { 81 | t.Helper() 82 | RunMapperTest(t, name, tc) 83 | 84 | allowUnknown, _ := tc.Context[CtxKeyAllowUnknownColumns].(bool) 85 | if allowUnknown { 86 | return 87 | } 88 | 89 | cols, err := StructMapperColumns[T]() 90 | if err != nil { 91 | t.Fatalf("couldn't get columns: %v", err) 92 | } 93 | 94 | ColumnLoop: 95 | for _, col := range tc.row.columns { 96 | for _, c := range cols { 97 | if col == c { 98 | continue ColumnLoop 99 | } 100
| } 101 | t.Fatalf("column %s not found in struct mapper columns", col) 102 | } 103 | } 104 | // RunCustomStructMapperTest is like RunStructMapperTest, but builds the mapper from NewStructMapperSource(tc.Options...); an error from the source is checked against tc.ExpectedBeforeError. 105 | func RunCustomStructMapperTest[T any](t *testing.T, name string, tc CustomStructMapperTest[T]) { 106 | t.Helper() 107 | m := tc.MapperTest 108 | src, err := NewStructMapperSource(tc.Options...) 109 | if diff := diffErr(tc.ExpectedBeforeError, err); diff != "" { // same helper as RunMapperTest: cmp.Diff cannot compare non-nil errors with unexported fields 110 | t.Fatalf("diff: %s", diff) 111 | } 112 | if err != nil { 113 | return 114 | } 115 | 116 | m.Mapper = CustomStructMapper[T](src) 117 | RunMapperTest(t, name, m) 118 | 119 | allowUnknown, _ := tc.Context[CtxKeyAllowUnknownColumns].(bool) 120 | if allowUnknown { 121 | return 122 | } 123 | 124 | cols, err := CustomStructMapperColumns[T](src) 125 | if err != nil { 126 | t.Fatalf("couldn't get columns: %v", err) 127 | } 128 | 129 | ColumnLoop: 130 | for _, col := range tc.row.columns { 131 | for _, c := range cols { 132 | if col == c { 133 | continue ColumnLoop 134 | } 135 | } 136 | t.Fatalf("column %s not found in struct mapper columns", col) 137 | } 138 | } 139 | 140 | func TestColumnMapper(t *testing.T) { 141 | RunMapperTest(t, "single column", MapperTest[int]{ 142 | row: &Row{ 143 | columns: columns(1), 144 | }, 145 | scanned: []any{100}, 146 | Mapper: ColumnMapper[int]("0"), 147 | ExpectedVal: 100, 148 | }) 149 | 150 | RunMapperTest(t, "multiple columns", MapperTest[int]{ 151 | row: &Row{ 152 | columns: columns(3), 153 | }, 154 | scanned: []any{100, 200, 300}, 155 | Mapper: ColumnMapper[int]("1"), 156 | ExpectedVal: 200, 157 | }) 158 | } 159 | 160 | func TestSingleColumnMapper(t *testing.T) { 161 | RunMapperTest(t, "multiple columns", MapperTest[int]{ 162 | row: &Row{ 163 | columns: columns(2), 164 | }, 165 | scanned: []any{100}, 166 | Mapper: SingleColumnMapper[int], 167 | ExpectedBeforeError: createError(nil, "wrong column count", "1", "2"), 168 | ExpectedAfterError: createError(nil, "wrong column count", "1", "2"), 169 | }) 170 | 171 | RunMapperTest(t, "int", MapperTest[int]{ 172 | row: &Row{ 173 |
columns: columns(1), 174 | }, 175 | scanned: []any{100}, 176 | Mapper: SingleColumnMapper[int], 177 | ExpectedVal: 100, 178 | }) 179 | 180 | RunMapperTest(t, "int64", MapperTest[int64]{ 181 | row: &Row{ 182 | columns: columns(1), 183 | }, 184 | scanned: []any{int64(100)}, 185 | Mapper: SingleColumnMapper[int64], 186 | ExpectedVal: 100, 187 | }) 188 | 189 | RunMapperTest(t, "string", MapperTest[string]{ 190 | row: &Row{ 191 | columns: columns(1), 192 | }, 193 | scanned: []any{"A fancy string"}, 194 | Mapper: SingleColumnMapper[string], 195 | ExpectedVal: "A fancy string", 196 | }) 197 | 198 | RunMapperTest(t, "time.Time", MapperTest[time.Time]{ 199 | row: &Row{ 200 | columns: columns(1), 201 | }, 202 | scanned: []any{now}, 203 | Mapper: SingleColumnMapper[time.Time], 204 | ExpectedVal: now, 205 | }) 206 | } 207 | 208 | func TestSliceMapper(t *testing.T) { 209 | RunMapperTest(t, "any slice", MapperTest[[]any]{ 210 | row: &Row{ 211 | columns: columns(len(goodSlice)), 212 | }, 213 | scanned: goodSlice, 214 | Mapper: SliceMapper[any], 215 | ExpectedVal: goodSlice, 216 | }) 217 | 218 | RunMapperTest(t, "int slice", MapperTest[[]int]{ 219 | row: &Row{ 220 | columns: columns(1), 221 | }, 222 | scanned: []any{100}, 223 | Mapper: SliceMapper[int], 224 | ExpectedVal: []int{100}, 225 | }) 226 | } 227 | 228 | func TestMapMapper(t *testing.T) { 229 | RunMapperTest(t, "MapMapper", MapperTest[map[string]any]{ 230 | row: &Row{ 231 | columns: columns(len(goodSlice)), 232 | }, 233 | scanned: goodSlice, 234 | Mapper: MapMapper[any], 235 | ExpectedVal: mapToVals[any](goodSlice), 236 | }) 237 | } 238 | 239 | func TestStructMapper(t *testing.T) { 240 | RunStructMapperTest(t, "Unknown cols permitted", MapperTest[User]{ 241 | row: &Row{ 242 | columns: columnNames("random"), 243 | }, 244 | Mapper: StructMapper[User](), 245 | Context: map[contextKey]any{CtxKeyAllowUnknownColumns: true}, 246 | ExpectedVal: User{}, 247 | }) 248 | 249 | RunStructMapperTest(t, "flat struct", MapperTest[User]{
250 | row: &Row{ 251 | columns: columnNames("id", "name"), 252 | }, 253 | scanned: []any{1, "The Name"}, 254 | Mapper: StructMapper[User](), 255 | ExpectedVal: User{ID: 1, Name: "The Name"}, 256 | }) 257 | 258 | RunStructMapperTest(t, "with pointer columns 1", MapperTest[PtrUser1]{ 259 | row: &Row{ 260 | columns: columnNames("id", "name", "created_at", "updated_at"), 261 | }, 262 | scanned: []any{toPtr(1), "The Name", &now, toPtr(now.Add(time.Hour))}, 263 | Mapper: StructMapper[PtrUser1](), 264 | ExpectedVal: PtrUser1{ 265 | ID: toPtr(1), Name: "The Name", 266 | PtrTimestamps: PtrTimestamps{CreatedAt: &now, UpdatedAt: toPtr(now.Add(time.Hour))}, 267 | }, 268 | }) 269 | 270 | RunStructMapperTest(t, "with pointer columns 2", MapperTest[PtrUser2]{ 271 | row: &Row{ 272 | columns: columnNames("id", "name", "created_at", "updated_at"), 273 | }, 274 | scanned: []any{1, toPtr("The Name"), &now, toPtr(now.Add(time.Hour))}, 275 | Mapper: StructMapper[PtrUser2](), 276 | ExpectedVal: PtrUser2{ 277 | ID: 1, Name: toPtr("The Name"), 278 | PtrTimestamps: &PtrTimestamps{CreatedAt: &now, UpdatedAt: toPtr(now.Add(time.Hour))}, 279 | }, 280 | }) 281 | 282 | RunStructMapperTest(t, "anonymous embeds", MapperTest[UserWithTimestamps]{ 283 | row: &Row{ 284 | columns: columnNames("id", "name", "created_at", "updated_at"), 285 | }, 286 | scanned: []any{10, "The Name", now, now.Add(time.Hour)}, 287 | Mapper: StructMapper[UserWithTimestamps](), 288 | ExpectedVal: UserWithTimestamps{ 289 | User: User{ID: 10, Name: "The Name"}, 290 | Timestamps: &Timestamps{CreatedAt: now, UpdatedAt: now.Add(time.Hour)}, 291 | }, 292 | }) 293 | 294 | RunStructMapperTest(t, "prefixed structs", MapperTest[Blog]{ 295 | row: &Row{ 296 | columns: columnNames("id", "user.id", "user.name", "user.created_at"), 297 | }, 298 | scanned: []any{100, 10, "The Name", now}, 299 | Mapper: StructMapper[Blog](), 300 | ExpectedVal: Blog{ 301 | ID: 100, 302 | User: UserWithTimestamps{ 303 | User: User{ID: 10, Name: "The Name"}, 304
| Timestamps: &Timestamps{CreatedAt: now}, 305 | }, 306 | }, 307 | }) 308 | 309 | RunStructMapperTest(t, "tagged", MapperTest[Tagged]{ 310 | row: &Row{ 311 | columns: columnNames("tag_id", "tag_name", "EMAIL"), 312 | }, 313 | scanned: []any{1, "The Name", "user@example.com"}, 314 | Mapper: StructMapper[Tagged](), 315 | ExpectedVal: Tagged{ID: 1, Name: "The Name", Email: "user@example.com"}, 316 | }) 317 | 318 | RunCustomStructMapperTest(t, "custom column separator", CustomStructMapperTest[Blog]{ 319 | MapperTest: MapperTest[Blog]{ 320 | row: &Row{ 321 | columns: columnNames("id", "user,id", "user,name", "user,created_at"), 322 | }, 323 | scanned: []any{100, 10, "The Name", now}, 324 | ExpectedVal: Blog{ 325 | ID: 100, 326 | User: UserWithTimestamps{ 327 | User: User{ID: 10, Name: "The Name"}, 328 | Timestamps: &Timestamps{CreatedAt: now}, 329 | }, 330 | }, 331 | }, 332 | Options: []MappingSourceOption{WithColumnSeparator(",")}, 333 | }) 334 | 335 | RunCustomStructMapperTest(t, "custom name mapper", CustomStructMapperTest[Blog]{ 336 | MapperTest: MapperTest[Blog]{ 337 | row: &Row{ 338 | columns: columnNames("ID", "USER.ID", "USER.NAME", "USER.CREATEDAT"), 339 | }, 340 | scanned: []any{100, 10, "The Name", now}, 341 | ExpectedVal: Blog{ 342 | ID: 100, 343 | User: UserWithTimestamps{ 344 | User: User{ID: 10, Name: "The Name"}, 345 | Timestamps: &Timestamps{CreatedAt: now}, 346 | }, 347 | }, 348 | }, 349 | Options: []MappingSourceOption{WithFieldNameMapper(strings.ToUpper)}, 350 | }) 351 | 352 | RunCustomStructMapperTest(t, "custom tag", CustomStructMapperTest[Tagged]{ 353 | MapperTest: MapperTest[Tagged]{ 354 | row: &Row{ 355 | columns: columnNames("custom_id", "custom_name"), 356 | }, 357 | scanned: []any{1, "The Name"}, 358 | ExpectedVal: Tagged{ID: 1, Name: "The Name"}, 359 | }, 360 | Options: []MappingSourceOption{WithStructTagKey("custom")}, 361 | }) 362 | 363 | RunMapperTest(t, "with prefix", MapperTest[User]{ 364 | row: &Row{ 365 | columns:
columnNames("prefix--id", "prefix--name"), 366 | }, 367 | scanned: []any{1, "The Name"}, 368 | Mapper: StructMapper[User](WithStructTagPrefix("prefix--")), 369 | ExpectedVal: User{ID: 1, Name: "The Name"}, 370 | }) 371 | 372 | RunMapperTest(t, "with prefix and non-prefixed column", MapperTest[User]{ 373 | row: &Row{ 374 | columns: columnNames("id", "prefix--name"), 375 | }, 376 | scanned: []any{1, "The Name"}, 377 | Mapper: StructMapper[User](WithStructTagPrefix("prefix--")), 378 | ExpectedVal: User{ID: 0, Name: "The Name"}, 379 | }) 380 | 381 | RunMapperTest(t, "with type converter", MapperTest[User]{ 382 | row: &Row{ 383 | columns: columnNames("id", "name"), 384 | }, 385 | scanned: []any{wrapper{toPtr(1)}, wrapper{toPtr("The Name")}}, 386 | Mapper: StructMapper[User](WithTypeConverter(typeConverter{})), 387 | ExpectedVal: User{ID: 1, Name: "The Name"}, 388 | }) 389 | 390 | RunMapperTest(t, "with type converter ptr", MapperTest[*User]{ 391 | row: &Row{ 392 | columns: columnNames("id", "name"), 393 | }, 394 | scanned: []any{wrapper{toPtr(1)}, wrapper{toPtr("The Name")}}, 395 | Mapper: StructMapper[*User](WithTypeConverter(typeConverter{})), 396 | ExpectedVal: &User{ID: 1, Name: "The Name"}, 397 | }) 398 | 399 | RunMapperTest(t, "with type converter deep", MapperTest[PtrUser2]{ 400 | row: &Row{ 401 | columns: columnNames("id", "name", "created_at", "updated_at"), 402 | }, 403 | scanned: []any{ 404 | wrapper{toPtr(1)}, 405 | wrapper{toPtr("The Name")}, 406 | wrapper{&now}, 407 | wrapper{toPtr(now.Add(time.Hour))}, 408 | }, 409 | Mapper: StructMapper[PtrUser2](WithTypeConverter(typeConverter{})), 410 | ExpectedVal: PtrUser2{ 411 | ID: 1, Name: toPtr("The Name"), 412 | PtrTimestamps: &PtrTimestamps{CreatedAt: &now, UpdatedAt: toPtr(now.Add(time.Hour))}, 413 | }, 414 | }) 415 | 416 | RunMapperTest(t, "with row validator pass", MapperTest[User]{ 417 | row: &Row{ 418 | columns: columnNames("id", "name"), 419 | }, 420 | scanned: []any{1, "The Name"}, 421 | Mapper:
StructMapper[User](WithRowValidator(func(cols []string, vals []reflect.Value) bool { 422 | for i, c := range cols { 423 | if c == "id" { 424 | return vals[i].Elem().Int() == 1 425 | } 426 | } 427 | 428 | return false 429 | })), 430 | ExpectedVal: User{ID: 1, Name: "The Name"}, 431 | }) 432 | 433 | RunMapperTest(t, "with row validator fail", MapperTest[User]{ 434 | row: &Row{ 435 | columns: columnNames("id", "name"), 436 | }, 437 | scanned: []any{1, "The Name"}, 438 | Mapper: StructMapper[User](WithRowValidator(func(cols []string, vals []reflect.Value) bool { 439 | for i, c := range cols { 440 | if c == "id" { 441 | return vals[i].Elem().Int() == 0 442 | } 443 | } 444 | 445 | return false 446 | })), 447 | ExpectedVal: User{ID: 0, Name: ""}, 448 | }) 449 | 450 | RunMapperTest(t, "with mod", MapperTest[*User]{ 451 | row: &Row{ 452 | columns: columnNames("id", "name"), 453 | }, 454 | scanned: []any{2, "The Name"}, 455 | Mapper: CustomStructMapper[*User](defaultStructMapper, WithMapperMods(userMod)), 456 | ExpectedVal: &User{ID: 400, Name: "The Name modified"}, 457 | }) 458 | } 459 | 460 | func TestScannable(t *testing.T) { 461 | type scannable interface { 462 | Scan() 463 | } 464 | 465 | type BlogWithScannableUser struct { 466 | ID int 467 | User ScannableUser 468 | } 469 | 470 | src, err := NewStructMapperSource(WithScannableTypes( 471 | (*scannable)(nil), 472 | )) 473 | if err != nil { 474 | t.Fatalf("couldn't get mapper source: %v", err) 475 | } 476 | 477 | m, err := src.getMapping(reflect.TypeOf(BlogWithScannableUser{})) 478 | if err != nil { 479 | t.Fatalf("couldn't get mapping: %v", err) 480 | } 481 | 482 | var marked bool 483 | for _, info := range m { 484 | if info.name == "user" { 485 | marked = true 486 | } 487 | } 488 | 489 | if !marked { 490 | t.Fatal("did not mark user as scannable") 491 | } 492 | } 493 | 494 | func TestScannableErrors(t *testing.T) { 495 | cases := map[string]struct { 496 | typ any 497 | err error 498 | }{ 499 | "nil": { 500 | typ: nil,
501 | err: fmt.Errorf("scannable type must be a pointer, got "), 502 | }, 503 | "non-pointer": { 504 | typ: User{}, 505 | err: fmt.Errorf("scannable type must be a pointer, got struct: scan.User"), 506 | }, 507 | "non-interface": { 508 | typ: &User{}, 509 | err: fmt.Errorf("scannable type must be a pointer to an interface, got struct: scan.User"), 510 | }, 511 | } 512 | 513 | for name, test := range cases { 514 | t.Run(name, func(t *testing.T) { 515 | _, err := NewStructMapperSource(WithScannableTypes(test.typ)) 516 | if diff := diffErr(test.err, err); diff != "" { 517 | t.Fatalf("diff: %s", diff) 518 | } 519 | }) 520 | } 521 | } 522 | -------------------------------------------------------------------------------- /mod.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | type ( 8 | // MapperMod is a function that can be used to convert an existing mapper 9 | // into a new mapper using [Mod] 10 | MapperMod = func(context.Context, cols) (BeforeFunc, AfterMod) 11 | // AfterMod receives both the link of the [MapperMod] and the retrieved value from 12 | // the original mapper 13 | AfterMod = func(link any, retrieved any) error 14 | ) 15 | 16 | // Mod converts an existing mapper into a new mapper with [MapperMod]s 17 | func Mod[T any](m Mapper[T], mods ...MapperMod) Mapper[T] { 18 | return func(ctx context.Context, c cols) (func(*Row) (any, error), func(any) (T, error)) { 19 | before, after := m(ctx, c) 20 | befores := make([]BeforeFunc, len(mods)) 21 | afters := make([]AfterMod, len(mods)) 22 | links := make([]any, len(mods)) 23 | for i, m := range mods { 24 | befores[i], afters[i] = m(ctx, c) 25 | } 26 | 27 | return func(v *Row) (any, error) { 28 | a, err := before(v) 29 | if err != nil { 30 | return nil, err 31 | } 32 | 33 | for i, b := range befores { 34 | if links[i], err = b(v); err != nil { 35 | return nil, err 36 | } 37 | } 38 | 39 | return a, nil 40 | }, func(v any)
(T, error) { 41 | t, err := after(v) 42 | if err != nil { 43 | return t, err 44 | } 45 | 46 | for i, a := range afters { 47 | if err := a(links[i], t); err != nil { 48 | return t, err 49 | } 50 | } 51 | 52 | return t, nil 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /pgxscan/pgxscan.go: -------------------------------------------------------------------------------- 1 | package pgxscan 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/jackc/pgx/v5" 7 | "github.com/stephenafamo/scan" 8 | ) 9 | 10 | // One scans a single row from the query and maps it to T using a [Queryer]. 11 | // This is for use with *pgx.Conn, pgx.Tx, *pgxpool.Pool or any similar implementation 12 | // that returns pgx.Rows. 13 | func One[T any](ctx context.Context, exec Queryer, m scan.Mapper[T], sql string, args ...any) (T, error) { 14 | return scan.One(ctx, convert(exec), m, sql, args...) 15 | } 16 | 17 | // All scans all rows from the query and returns a slice []T of all rows using a [Queryer]. This is for use with *pgx.Conn, pgx.Tx, *pgxpool.Pool or any similar implementation 18 | // that returns pgx.Rows. 19 | func All[T any](ctx context.Context, exec Queryer, m scan.Mapper[T], sql string, args ...any) ([]T, error) { 20 | return scan.All(ctx, convert(exec), m, sql, args...) 21 | } 22 | 23 | // Cursor returns a cursor that works similar to *sql.Rows 24 | func Cursor[T any](ctx context.Context, exec Queryer, m scan.Mapper[T], sql string, args ...any) (scan.ICursor[T], error) { 25 | return scan.Cursor(ctx, convert(exec), m, sql, args...) 26 | } 27 | 28 | // Each returns a function that can be used to iterate over the rows of a query 29 | // this function works with range-over-func so it is possible to do 30 | // 31 | // for val, err := range pgxscan.Each(ctx, exec, m, query, args...)
{ 32 | // if err != nil { 33 | // return err 34 | // } 35 | // // do something with val 36 | // } 37 | func Each[T any](ctx context.Context, exec Queryer, m scan.Mapper[T], query string, args ...any) func(func(T, error) bool) { 38 | return scan.Each(ctx, convert(exec), m, query, args...) 39 | } 40 | 41 | // Queryer is implemented by pgx types such as *pgx.Conn, pgx.Tx and *pgxpool.Pool, whose Query returns [pgx.Rows] 42 | type Queryer interface { 43 | Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) 44 | } 45 | 46 | // convert wraps a Queryer so that it satisfies [scan.Queryer] 47 | func convert(wrapped Queryer) scan.Queryer { 48 | return queryer{wrapped: wrapped} 49 | } 50 | 51 | type queryer struct { 52 | wrapped Queryer 53 | } 54 | // rows adapts pgx.Rows to [scan.Rows]: Close reports no error and Columns is built from the field descriptions. 55 | type rows struct { 56 | pgx.Rows 57 | } 58 | 59 | func (r rows) Close() error { 60 | r.Rows.Close() 61 | return nil 62 | } 63 | 64 | func (r rows) Columns() ([]string, error) { 65 | fields := r.FieldDescriptions() 66 | cols := make([]string, len(fields)) 67 | 68 | for i, field := range fields { 69 | cols[i] = field.Name 70 | } 71 | 72 | return cols, nil 73 | } 74 | 75 | // QueryContext executes a query that returns rows, typically a SELECT. The args are for any placeholder parameters in the query. 76 | func (q queryer) QueryContext(ctx context.Context, query string, args ...any) (scan.Rows, error) { 77 | r, err := q.wrapped.Query(ctx, query, args...)
78 | if err != nil { 79 | // Return an untyped nil so a failed query never yields a non-nil scan.Rows wrapping a nil or errored pgx.Rows. 80 | return nil, err 81 | } 82 | 83 | return rows{r}, nil 84 | } 85 | -------------------------------------------------------------------------------- /row.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | var zeroValue reflect.Value 9 | // wrapRows reads the column names from r and returns a Row with one empty scan destination per column. 10 | func wrapRows(r Rows, allowUnknown bool) (*Row, error) { 11 | cols, err := r.Columns() 12 | if err != nil { 13 | return nil, err 14 | } 15 | 16 | return &Row{ 17 | r: r, 18 | columns: cols, 19 | scanDestinations: make([]reflect.Value, len(cols)), 20 | allowUnknown: allowUnknown, 21 | }, nil 22 | } 23 | 24 | // Row represents a single row from the query and is passed to the [BeforeFunc] 25 | // when sent to a mapper's before function, scans should be scheduled 26 | // with either the [ScheduleScan] or [ScheduleScanx] methods 27 | type Row struct { 28 | r Rows 29 | columns []string 30 | scanDestinations []reflect.Value 31 | unknownDestinations []string 32 | allowUnknown bool 33 | } 34 | 35 | // ScheduleScan schedules a scan for the column name into the given value 36 | // val should be a pointer 37 | func (r *Row) ScheduleScan(colName string, val any) { 38 | r.ScheduleScanx(colName, reflect.ValueOf(val)) 39 | } 40 | 41 | // ScheduleScanx schedules a scan for the column name into the given reflect.Value 42 | // val.Kind() should be reflect.Pointer; a colName that is not among the row's columns is recorded as an unknown destination 43 | func (r *Row) ScheduleScanx(colName string, val reflect.Value) { 44 | for i, n := range r.columns { 45 | if n == colName { 46 | r.scanDestinations[i] = val 47 | return 48 | } 49 | } 50 | 51 | r.unknownDestinations = append(r.unknownDestinations, colName) 52 | } 53 | 54 | // columnsCopy returns a copy of the columns to pass to mapper generators, 55 | // since modifying the slice can have unintended side effects.
56 | // Ideally, a generator should only call this once. 57 | func (r *Row) columnsCopy() []string { 58 | m := make([]string, len(r.columns)) 59 | copy(m, r.columns) 60 | return m 61 | } 62 | // scanCurrentRow scans the current row into the scheduled destinations and, on success, clears them for the next row. It fails if any scan was scheduled for a column that is not in the result. 63 | func (r *Row) scanCurrentRow() error { 64 | if len(r.unknownDestinations) > 0 { 65 | return createError(fmt.Errorf("unknown columns to map to: %v", r.unknownDestinations), r.unknownDestinations...) 66 | } 67 | 68 | targets, err := r.createTargets() 69 | if err != nil { 70 | return err 71 | } 72 | 73 | err = r.r.Scan(targets...) 74 | if err != nil { 75 | return err 76 | } 77 | 78 | r.scanDestinations = make([]reflect.Value, len(r.columns)) 79 | return nil 80 | } 81 | // createTargets returns one Scan target per column: the scheduled destination or, when unknown columns are allowed, a throwaway *any. 82 | func (r *Row) createTargets() ([]any, error) { 83 | targets := make([]any, len(r.columns)) 84 | 85 | for i, name := range r.columns { 86 | dest := r.scanDestinations[i] 87 | if dest.IsValid() { // a scan was scheduled for this column 88 | targets[i] = dest.Interface() 89 | continue 90 | } 91 | 92 | if !r.allowUnknown { 93 | err := fmt.Errorf("No destination for column %s", name) 94 | return nil, createError(err, "no destination", name) 95 | } 96 | 97 | // See https://github.com/golang/go/issues/41607: 98 | // Some drivers cannot work with nil values, so valid pointers should be 99 | // used for all column targets, even if they are discarded afterwards. 100 | targets[i] = new(any) 101 | } 102 | 103 | return targets, nil 104 | } 105 | -------------------------------------------------------------------------------- /source.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "reflect" 8 | "regexp" 9 | "strings" 10 | "sync" 11 | ) 12 | 13 | var ( 14 | matchFirstCapRe = regexp.MustCompile("(.)([A-Z][a-z]+)") 15 | matchAllCapRe = regexp.MustCompile("([a-z0-9])([A-Z])") 16 | defaultStructMapper = newDefaultMapperSourceImpl() 17 | ) 18 | 19 | // snakeCaseFieldFunc maps a struct field name to snake_case, e.g. "CreatedAt" becomes "created_at".
20 | func snakeCaseFieldFunc(str string) string { 21 | snake := matchFirstCapRe.ReplaceAllString(str, "${1}_${2}") 22 | snake = matchAllCapRe.ReplaceAllString(snake, "${1}_${2}") 23 | return strings.ToLower(snake) 24 | } 25 | // newDefaultMapperSourceImpl returns a source with the package defaults: tag key "db", separator ".", snake_case field names, sql.Scanner as the only scannable type and a maxDepth of 3. 26 | func newDefaultMapperSourceImpl() *mapperSourceImpl { 27 | return &mapperSourceImpl{ 28 | structTagKey: "db", 29 | columnSeparator: ".", 30 | fieldMapperFn: snakeCaseFieldFunc, 31 | scannableTypes: []reflect.Type{reflect.TypeOf((*sql.Scanner)(nil)).Elem()}, 32 | maxDepth: 3, 33 | cache: make(map[reflect.Type]mapping), 34 | } 35 | } 36 | 37 | // NewStructMapperSource creates a StructMapperSource from the defaults and applies opts in order, returning the first error an option reports. 38 | func NewStructMapperSource(opts ...MappingSourceOption) (StructMapperSource, error) { 39 | src := newDefaultMapperSourceImpl() 40 | for _, o := range opts { 41 | if err := o(src); err != nil { 42 | return nil, err 43 | } 44 | } 45 | return src, nil 46 | } 47 | 48 | // MappingSourceOption is an option that modifies how a struct's mappings are interpreted. 49 | type MappingSourceOption func(src *mapperSourceImpl) error 50 | 51 | // WithStructTagKey sets the struct tag key that column names are read from. 52 | // The default tag key is `db`. 53 | func WithStructTagKey(tagKey string) MappingSourceOption { 54 | return func(src *mapperSourceImpl) error { 55 | src.structTagKey = tagKey 56 | return nil 57 | } 58 | } 59 | 60 | // WithColumnSeparator sets the separator placed between a prefix and a field name when combining nested structs. 61 | // The default separator is the "." character. 62 | func WithColumnSeparator(separator string) MappingSourceOption { 63 | return func(src *mapperSourceImpl) error { 64 | src.columnSeparator = separator 65 | return nil 66 | } 67 | } 68 | 69 | // WithFieldNameMapper sets the function that maps a field name to a column name when the field has no tag.
70 | // The default maps field names to "snake_case". 71 | func WithFieldNameMapper(mapperFn func(string) string) MappingSourceOption { 72 | return func(src *mapperSourceImpl) error { 73 | src.fieldMapperFn = mapperFn 74 | return nil 75 | } 76 | } 77 | 78 | // WithScannableTypes specifies a list of interfaces that the underlying database library can scan into. 79 | // If the destination type passed to scan implements one of those interfaces, 80 | // scan treats it as a primitive type, i.e. it simply passes the destination to the database library 81 | // instead of attempting to map database columns to destination struct fields or map keys. 82 | // In order for reflection to capture the interface type, you must pass it by pointer. 83 | // 84 | // For example your database library defines a scanner interface like this: 85 | // 86 | // type Scanner interface { 87 | // Scan(...) error 88 | // } 89 | // 90 | // You can pass it to scan this way: 91 | // scan.WithScannableTypes((*Scanner)(nil)). 92 | func WithScannableTypes(scannableTypes ...any) MappingSourceOption { 93 | return func(src *mapperSourceImpl) error { 94 | for _, stOpt := range scannableTypes { 95 | st := reflect.TypeOf(stOpt) 96 | if st == nil { 97 | return fmt.Errorf("scannable type must be a pointer, got %T", stOpt) 98 | } 99 | if st.Kind() != reflect.Pointer { 100 | return fmt.Errorf("scannable type must be a pointer, got %s: %s", 101 | st.Kind(), st.String()) 102 | } 103 | st = st.Elem() 104 | if st.Kind() != reflect.Interface { 105 | return fmt.Errorf("scannable type must be a pointer to an interface, got %s: %s", 106 | st.Kind(), st.String()) 107 | } 108 | src.scannableTypes = append(src.scannableTypes, st) 109 | } 110 | return nil 111 | } 112 | } 113 | 114 | // mapperSourceImpl is an implementation of StructMapperSource.
115 | type mapperSourceImpl struct { 116 | structTagKey string 117 | columnSeparator string 118 | fieldMapperFn func(string) string 119 | scannableTypes []reflect.Type 120 | maxDepth int 121 | cache map[reflect.Type]mapping 122 | mutex sync.RWMutex 123 | } 124 | 125 | func (s *mapperSourceImpl) getMapping(typ reflect.Type) (mapping, error) { 126 | s.mutex.RLock() 127 | m, ok := s.cache[typ] 128 | s.mutex.RUnlock() 129 | 130 | if ok { 131 | return m, nil 132 | } 133 | 134 | s.setMappings(typ, "", make(visited), &m, nil) 135 | 136 | s.mutex.Lock() 137 | s.cache[typ] = m 138 | s.mutex.Unlock() 139 | 140 | return m, nil 141 | } 142 | 143 | func (s *mapperSourceImpl) setMappings(typ reflect.Type, prefix string, v visited, m *mapping, inits [][]int, position ...int) { 144 | count := v[typ] 145 | if count > s.maxDepth { 146 | return 147 | } 148 | v[typ] = count + 1 149 | 150 | var hasExported bool 151 | 152 | var isPointer bool 153 | if typ.Kind() == reflect.Pointer { 154 | isPointer = true 155 | typ = typ.Elem() 156 | } 157 | 158 | // If it implements a scannable type, then it can be used 159 | // as a value itself. Return it 160 | for _, scannable := range s.scannableTypes { 161 | if reflect.PtrTo(typ).Implements(scannable) { 162 | *m = append(*m, mapinfo{ 163 | name: prefix, 164 | position: position, 165 | init: inits, 166 | isPointer: isPointer, 167 | }) 168 | return 169 | } 170 | } 171 | 172 | // Go through the struct fields and populate the map. 
173 | // Recursively go into any child structs, adding a prefix where necessary 174 | for i := 0; i < typ.NumField(); i++ { 175 | field := typ.Field(i) 176 | 177 | // Don't consider unexported fields 178 | if !field.IsExported() { 179 | continue 180 | } 181 | 182 | // Skip columns that have the tag "-" 183 | tag := strings.Split(field.Tag.Get(s.structTagKey), ",")[0] 184 | if tag == "-" { 185 | continue 186 | } 187 | 188 | hasExported = true 189 | 190 | key := prefix 191 | 192 | if !field.Anonymous { 193 | var sep string 194 | if prefix != "" { 195 | sep = s.columnSeparator 196 | } 197 | 198 | name := tag 199 | if tag == "" { 200 | name = s.fieldMapperFn(field.Name) 201 | } 202 | 203 | key = strings.Join([]string{key, name}, sep) 204 | } 205 | 206 | currentIndex := append(position, i) 207 | fieldType := field.Type 208 | var isPointer bool 209 | 210 | if fieldType.Kind() == reflect.Pointer { 211 | inits = append(inits, currentIndex) 212 | fieldType = fieldType.Elem() 213 | isPointer = true 214 | } 215 | 216 | if fieldType.Kind() == reflect.Struct { 217 | s.setMappings(field.Type, key, v.copy(), m, inits, currentIndex...) 
218 | continue 219 | } 220 | 221 | *m = append(*m, mapinfo{ 222 | name: key, 223 | position: currentIndex, 224 | init: inits, 225 | isPointer: isPointer, 226 | }) 227 | } 228 | 229 | // If it has no exported field (such as time.Time) then we attempt to 230 | // directly scan into it 231 | if !hasExported { 232 | *m = append(*m, mapinfo{ 233 | name: prefix, 234 | position: position, 235 | init: inits, 236 | isPointer: isPointer, 237 | }) 238 | } 239 | } 240 | 241 | func filterColumns(ctx context.Context, c cols, m mapping, prefix string) (mapping, error) { 242 | // Filter the mapping so we only ask for the available columns 243 | filtered := make(mapping, 0, len(c)) 244 | for _, name := range c { 245 | key := name 246 | if prefix != "" { 247 | if !strings.HasPrefix(name, prefix) { 248 | continue 249 | } 250 | 251 | key = name[len(prefix):] 252 | } 253 | 254 | for _, info := range m { 255 | if key == info.name { 256 | info.name = name 257 | filtered = append(filtered, info) 258 | break 259 | } 260 | } 261 | } 262 | 263 | return filtered, nil 264 | } 265 | -------------------------------------------------------------------------------- /stdscan/stdscan.go: -------------------------------------------------------------------------------- 1 | package stdscan 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | 7 | "github.com/stephenafamo/scan" 8 | ) 9 | 10 | // One scans a single row from the query and maps it to T using a [StdQueryer] 11 | // this is for use with *sql.DB, *sql.Tx or *sql.Conn or any similar implementations 12 | // that return *sql.Rows 13 | func One[T any](ctx context.Context, exec Queryer, m scan.Mapper[T], sql string, args ...any) (T, error) { 14 | return scan.One(ctx, convert(exec), m, sql, args...) 
15 | } 16 | 17 | // All scans all rows from the query and returns a slice []T of all rows using a [Queryer]. This is for use with *sql.DB, *sql.Tx or *sql.Conn or any similar implementations 18 | // that return *sql.Rows. 19 | func All[T any](ctx context.Context, exec Queryer, m scan.Mapper[T], sql string, args ...any) ([]T, error) { 20 | return scan.All(ctx, convert(exec), m, sql, args...) 21 | } 22 | 23 | // Cursor returns a cursor that works similar to *sql.Rows 24 | func Cursor[T any](ctx context.Context, exec Queryer, m scan.Mapper[T], sql string, args ...any) (scan.ICursor[T], error) { 25 | return scan.Cursor(ctx, convert(exec), m, sql, args...) 26 | } 27 | 28 | // Each returns a function that can be used to iterate over the rows of a query 29 | // this function works with range-over-func so it is possible to do 30 | // 31 | // for val, err := range stdscan.Each(ctx, exec, m, query, args...) { 32 | // if err != nil { 33 | // return err 34 | // } 35 | // // do something with val 36 | // } 37 | func Each[T any](ctx context.Context, exec Queryer, m scan.Mapper[T], query string, args ...any) func(func(T, error) bool) { 38 | return scan.Each(ctx, convert(exec), m, query, args...) 39 | } 40 | 41 | // Queryer is implemented by *sql.DB, *sql.Tx, *sql.Conn and any other type whose QueryContext returns the concrete type [*sql.Rows] 42 | type Queryer interface { 43 | QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) 44 | } 45 | 46 | // convert wraps a Queryer so that it satisfies [scan.Queryer] 47 | func convert(wrapped Queryer) scan.Queryer { 48 | return queryer{wrapped: wrapped} 49 | } 50 | 51 | type queryer struct { 52 | wrapped Queryer 53 | } 54 | 55 | // QueryContext executes a query that returns rows, typically a SELECT. The args are for any placeholder parameters in the query. 56 | func (q queryer) QueryContext(ctx context.Context, query string, args ...any) (scan.Rows, error) { 57 | rows, err := q.wrapped.QueryContext(ctx, query, args...)
58 | if err != nil { 59 | // Return an untyped nil: a nil *sql.Rows stored in scan.Rows would be a non-nil interface value. 60 | return nil, err 61 | } 62 | 63 | return rows, nil 64 | } 65 | --------------------------------------------------------------------------------