├── .gitignore ├── renovate.json ├── handlers └── build │ ├── testdata │ ├── test2.sql │ ├── test2.go │ ├── test1.sql │ └── test1.go │ ├── build.go │ ├── build_test.go │ ├── generate.go │ └── fetch.go ├── log.go ├── functions_test.go ├── .github └── workflows │ ├── test.yml │ └── review.yml ├── testdata └── test.sql ├── helpers_test.go ├── LICENSE ├── field.go ├── .golangci.yml ├── cmd └── spnr │ └── main.go ├── mutation.go ├── dml_delete_test.go ├── dml.go ├── reader.go ├── mutation_delete_test.go ├── mutation_delete.go ├── dml_delete.go ├── functions.go ├── dml_insert.go ├── go.mod ├── reader_key.go ├── dml_update.go ├── helpers.go ├── reader_key_test.go ├── mutation_update.go ├── reader_query.go ├── dml_update_test.go ├── mutation_upsert.go ├── reader_query_test.go ├── dml_insert_test.go ├── mutation_upsert_test.go ├── mutation_update_test.go ├── internal └── examples │ ├── examples.go │ └── examples_test.go ├── init_test.go └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | # IntelliJ 2 | .idea/ 3 | go-spnr.iml 4 | 5 | # vscode 6 | .vscode/ 7 | 8 | # Go 9 | .go-version -------------------------------------------------------------------------------- /renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://docs.renovatebot.com/renovate-schema.json", 3 | "extends": [ 4 | "config:base" 5 | ] 6 | } 7 | -------------------------------------------------------------------------------- /handlers/build/testdata/test2.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE Test2 ( 2 | `String` String(10) NOT NULL, 3 | `Bytes` Timestamp NOT NULL, 4 | ) PRIMARY KEY (String) 5 | -------------------------------------------------------------------------------- /handlers/build/testdata/test2.go: -------------------------------------------------------------------------------- 1 | package entity_test 2 | 
3 | import "time" 4 | 5 | type Test2 struct { 6 | String string `spanner:"String" pk:"1"` 7 | Bytes time.Time `spanner:"Bytes"` 8 | } 9 | -------------------------------------------------------------------------------- /log.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import "log" 4 | // logger is the Printf-style logging interface held by DML, Mutation and Reader. 5 | type logger interface { 6 | Printf(format string, v ...any) 7 | } 8 | // defaultLogger implements logger by delegating to the standard library log package. 9 | type defaultLogger struct{} 10 | // newDefaultLogger returns the logger that NewDMLWithOptions and NewMutationWithOptions fall back to when Options.Logger is nil. 11 | func newDefaultLogger() *defaultLogger { 12 | return &defaultLogger{} 13 | } 14 | // Printf passes format and v unchanged to log.Printf. 15 | func (d *defaultLogger) Printf(format string, v ...any) { 16 | log.Printf(format, v...) 17 | } 18 | -------------------------------------------------------------------------------- /functions_test.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "cloud.google.com/go/spanner" 5 | "fmt" 6 | "gotest.tools/assert" 7 | "testing" 8 | ) 9 | // TestToKeySets checks that ToKeySets yields the same KeySet (compared via %+v) as spanner.KeySetFromKeys with one single-column key per element. 10 | func TestToKeySets(t *testing.T) { 11 | expected := spanner.KeySetFromKeys(spanner.Key{"a"}, spanner.Key{"b"}, spanner.Key{"c"}) 12 | actual := ToKeySets([]string{"a", "b", "c"}) 13 | assert.Equal(t, fmt.Sprintf("%+v", expected), fmt.Sprintf("%+v", actual)) 14 | } 15 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | 7 | jobs: 8 | setup: # NOTE(review): every job runs on its own hosted runner, so the gcc installed here is presumably not available to the test job; confirm this job is still needed 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Install Dependency 12 | run: sudo apt update && sudo apt install -y gcc 13 | test: 14 | needs: setup 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: Set up Go 1.19.1 18 | uses: actions/setup-go@v3 19 | with: 20 | go-version: 1.19.1 21 | - uses: actions/checkout@v3 22 | with: 23 | ref: ${{ github.event.pull_request.head.ref }} # NOTE(review): this workflow only triggers on push, where the pull_request payload is absent, so this presumably resolves to an empty ref and checkout uses the pushed commit; confirm 24 | - name: go mod download 25 | run: go mod download 26 | - name: test 27 | run: go test ./...
28 | -------------------------------------------------------------------------------- /testdata/test.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE Test ( 2 | `String` STRING(MAX) NOT NULL, 3 | `Bytes` BYTES(4) NOT NULL, 4 | `Int64` INT64 NOT NULL, 5 | `Float64` FLOAT64 NOT NULL, 6 | `Numeric` NUMERIC NOT NULL, 7 | `Bool` BOOL NOT NULL, 8 | `Date` DATE NOT NULL, 9 | `Timestamp` TIMESTAMP NOT NULL, 10 | `NullString` STRING(MAX), 11 | `NullInt64` INT64, 12 | `NullFloat64` FLOAT64, 13 | `NullNumeric` NUMERIC, 14 | `NullBool` BOOL, 15 | `NullDate` DATE, 16 | `NullTimestamp` TIMESTAMP, 17 | ArrayString ARRAY, 18 | ArrayBytes ARRAY, 19 | ArrayInt64 ARRAY, 20 | ArrayFloat64 ARRAY, 21 | ArrayNumeric ARRAY, 22 | ArrayBool ARRAY, 23 | ArrayDate ARRAY, 24 | ArrayTimestamp ARRAY, 25 | ) PRIMARY KEY (`String`, `Int64`) 26 | -------------------------------------------------------------------------------- /handlers/build/testdata/test1.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE Test1 ( 2 | `String` STRING(MAX) NOT NULL, 3 | `Bytes` BYTES(4) NOT NULL, 4 | `Int64` INT64 NOT NULL, 5 | `Float64` FLOAT64 NOT NULL, 6 | `Numeric` NUMERIC NOT NULL, 7 | `Bool` BOOL NOT NULL, 8 | `Date` DATE NOT NULL, 9 | `Timestamp` TIMESTAMP NOT NULL, 10 | `NullString` STRING(MAX), 11 | `NullInt64` INT64, 12 | `NullFloat64` FLOAT64, 13 | `NullNumeric` NUMERIC, 14 | `NullBool` BOOL, 15 | `NullDate` DATE, 16 | `NullTimestamp` TIMESTAMP, 17 | ArrayString ARRAY, 18 | ArrayBytes ARRAY, 19 | ArrayInt64 ARRAY, 20 | ArrayFloat64 ARRAY, 21 | ArrayNumeric ARRAY, 22 | ArrayBool ARRAY, 23 | ArrayDate ARRAY, 24 | ArrayTimestamp ARRAY, 25 | ) PRIMARY KEY (`String`, `Int64`) 26 | -------------------------------------------------------------------------------- /helpers_test.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 |
"github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | // TestValidateType checks which argument kinds the validate*Type helpers accept and the exact error messages they return otherwise. 8 | func TestValidateType(t *testing.T) { 9 | var str string 10 | var stru Test 11 | var sl []string 12 | var structSl []Test 13 | 14 | err := validateStructType(&stru) 15 | assert.Nil(t, err) 16 | err = validateSliceType(&sl) 17 | assert.Nil(t, err) 18 | err = validateStructSliceType(&structSl) 19 | assert.Nil(t, err) 20 | 21 | err = validateStructType(&sl) 22 | assert.NotNil(t, err) 23 | assert.Equal(t, "final argument must be struct but got slice", err.Error()) 24 | err = validateSliceType(&str) 25 | assert.NotNil(t, err) 26 | assert.Equal(t, "final argument must be slice but got string", err.Error()) 27 | err = validateStructSliceType(&sl) 28 | assert.NotNil(t, err) 29 | assert.Equal(t, "final argument must be slice of struct but got slice of string", err.Error()) 30 | } 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 kanjih 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /field.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strconv" 7 | ) 8 | 9 | const ( 10 | tagColumnName = "spanner" 11 | tagPkOrder = "pk" 12 | noPk = -1 13 | // tagIgnore is the spanner tag value that excludes a field from column mapping, mirroring the Spanner Go client. 14 | tagIgnore = "-" 15 | ) 16 | 17 | // field is one column-mapped struct field: its Spanner column name, its current value and its pk order (noPk when it is not part of the primary key). 18 | type field struct { 19 | name string 20 | value any 21 | pkOrder int 22 | } 23 | 24 | // isPk reports whether the field is part of the primary key, i.e. its pkOrder is not noPk. 25 | func (f *field) isPk() bool { 26 | return f.pkOrder != noPk 27 | } 28 | 29 | // toFields returns the column-mapped fields of the struct that target points to. 30 | func toFields(target any) []field { 31 | return structValToFields(reflect.ValueOf(target).Elem()) 32 | } 33 | 34 | // structValToFields returns, in declaration order, the fields of val (a struct or a pointer to one) 35 | // whose `spanner` tag is set and is not "-"; untagged fields are skipped. 36 | func structValToFields(val reflect.Value) []field { 37 | if val.Kind() == reflect.Ptr { 38 | val = val.Elem() 39 | } 40 | tp := val.Type() 41 | var v []field 42 | for i := 0; i < val.NumField(); i++ { 43 | sf := tp.Field(i) 44 | name := sf.Tag.Get(tagColumnName) 45 | if name == "" || name == tagIgnore { 46 | continue 47 | } 48 | v = append(v, field{ 49 | name: name, 50 | value: val.Field(i).Interface(), 51 | pkOrder: getPkOrder(sf), 52 | }) 53 | } 54 | return v 55 | } 56 | 57 | // getPkOrder parses the `pk` tag of s, returning noPk when the tag is absent. 58 | // A non-integer tag is a mistake in the struct definition, so it panics with an error naming the field. 59 | func getPkOrder(s reflect.StructField) int { 60 | pk := s.Tag.Get(tagPkOrder) 61 | if pk == "" { 62 | return noPk 63 | } 64 | pkOrder, err := strconv.Atoi(pk) 65 | if err != nil { 66 | panic(fmt.Errorf("spnr: invalid %s tag %q on field %s: %w", tagPkOrder, pk, s.Name, err)) 67 | } 68 | return pkOrder 69 | } 70 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | run: 2 | skip-dirs: 3 | - .github/ 4 | skip-files: 5 | - .gitignore 6 | - .go-version 7 | - .golangci.yml 8 | - go.mod 9 | - go.sum 10 | linters: 11 | disable-all: true 12 | enable: 13 | - staticcheck 14 | -
errcheck 15 | - gosimple 16 | - govet 17 | - ineffassign 18 | - unused 19 | - varcheck 20 | - bodyclose 21 | - errorlint 22 | - godox 23 | - gomnd 24 | - goprintffuncname 25 | - gosec 26 | - nakedret 27 | - nestif 28 | - unconvert 29 | - wastedassign 30 | # NOTE: keep the go versions below in sync with the go directive in go.mod (1.19). 31 | linters-settings: 32 | staticcheck: 33 | go: "1.19" 34 | checks: ["all"] 35 | gosimple: 36 | go: "1.19" 37 | checks: ["all"] 38 | unused: 39 | go: "1.19" 40 | errorlint: 41 | errorf: true 42 | asserts: true 43 | comparison: true 44 | gomnd: 45 | settings: 46 | mnd: 47 | # the list of enabled checks, see https://github.com/tommy-muehle/go-mnd/#checks for description. 48 | checks: argument,case,condition,operation,return,assign 49 | nakedret: 50 | # make an issue if func has more lines of code than this setting and it has naked returns; default is 30 51 | max-func-lines: 30 52 | -------------------------------------------------------------------------------- /cmd/spnr/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "github.com/kanjih/go-spnr/v2/handlers/build" // this module's own build handler, not the published v1 copy 6 | "github.com/urfave/cli/v2" 7 | "os" 8 | ) 9 | // main runs the spnr CLI, whose single "build" command passes its flags to build.Run. 10 | func main() { 11 | app := cli.NewApp() 12 | app.Usage = "Reducing boilerplate code for spanner" 13 | app.EnableBashCompletion = true 14 | app.Commands = []*cli.Command{ 15 | { 16 | Name: "build", 17 | Usage: "build structs to map records", 18 | Flags: []cli.Flag{ 19 | &cli.StringFlag{ 20 | Name: build.FlagNameProjectId, 21 | Usage: "gcp project id", 22 | Required: true, 23 | }, 24 | &cli.StringFlag{ 25 | Name: build.FlagNameInstanceName, 26 | Usage: "spanner instance name", 27 | Required: true, 28 | }, 29 | &cli.StringFlag{ 30 | Name: build.FlagNameDatabaseName, 31 | Usage: "spanner database name", 32 | Required: true, 33 | }, 34 | &cli.StringFlag{ 35 | Name: build.FlagNameOut, 36 | Usage: "output folder", 37 | Required: true, 38 | }, 39 | &cli.StringFlag{ 40 | Name: build.FlagNamePackageName, 41 | Usage:
"package name", 42 | Required: false, 43 | }, 44 | }, 45 | Action: build.Run, 46 | }, 47 | } 48 | 49 | if err := app.Run(os.Args); err != nil { 50 | fmt.Fprintln(os.Stderr, "[ERROR] "+err.Error()) 51 | os.Exit(1) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /mutation.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package spnr provides the orm for Cloud Spanner. 3 | */ 4 | package spnr 5 | 6 | import ( 7 | "context" 8 | ) 9 | 10 | // Mutation offers ORM with the Mutation API. 11 | // It also contains read operations (call the Reader method). 12 | type Mutation struct { 13 | table string 14 | logger logger 15 | logEnabled bool 16 | } 17 | 18 | // New is an alias for NewMutation. 19 | func New(tableName string) *Mutation { 20 | return NewMutation(tableName) 21 | } 22 | 23 | // NewMutation initializes ORM with Mutation API. 24 | // It also contains read operations (call Reader method of Mutation.) 25 | // If you want to use DML, use NewDML() instead. 26 | func NewMutation(tableName string) *Mutation { 27 | return &Mutation{table: tableName} 28 | } 29 | 30 | // NewMutationWithOptions initializes Mutation with options. 31 | // Check Options for the available options. A nil op is treated like an empty Options. 32 | // When op.Logger is nil, log output goes through the standard library log package. 33 | func NewMutationWithOptions(tableName string, op *Options) *Mutation { 34 | if op == nil { 35 | op = &Options{} 36 | } 37 | m := &Mutation{table: tableName, logger: op.Logger, logEnabled: op.LogEnabled} 38 | if m.logger == nil { 39 | m.logger = newDefaultLogger() 40 | } 41 | return m 42 | } 43 | 44 | // Reader returns Reader struct to call read operations. 45 | func (m *Mutation) Reader(ctx context.Context, tx Transaction) *Reader { 46 | return &Reader{table: m.table, ctx: ctx, tx: tx, logger: m.logger, logEnabled: m.logEnabled} 47 | } 48 | 49 | // GetTableName returns table name 50 | func (m *Mutation) GetTableName() string { 51 | return m.table 52 | } 53 | 54 | // logf writes through m.logger, but only when logging was enabled via NewMutationWithOptions. 55 | func (m *Mutation) logf(format string, v ...any) { 56 | if m.logEnabled { 57 | m.logger.Printf(format, v...)
58 | } 59 | } 60 | -------------------------------------------------------------------------------- /handlers/build/testdata/test1.go: -------------------------------------------------------------------------------- 1 | package entity_test 2 | 3 | import ( 4 | "cloud.google.com/go/civil" 5 | "cloud.google.com/go/spanner" 6 | "math/big" 7 | "time" 8 | ) 9 | 10 | type Test1 struct { 11 | String string `spanner:"String" pk:"1"` 12 | Bytes []byte `spanner:"Bytes"` 13 | Int64 int64 `spanner:"Int64" pk:"2"` 14 | Float64 float64 `spanner:"Float64"` 15 | Numeric big.Rat `spanner:"Numeric"` 16 | Bool bool `spanner:"Bool"` 17 | Date civil.Date `spanner:"Date"` 18 | Timestamp time.Time `spanner:"Timestamp"` 19 | NullString spanner.NullString `spanner:"NullString"` 20 | NullInt64 spanner.NullInt64 `spanner:"NullInt64"` 21 | NullFloat64 spanner.NullFloat64 `spanner:"NullFloat64"` 22 | NullNumeric spanner.NullNumeric `spanner:"NullNumeric"` 23 | NullBool spanner.NullBool `spanner:"NullBool"` 24 | NullDate spanner.NullDate `spanner:"NullDate"` 25 | NullTimestamp spanner.NullTime `spanner:"NullTimestamp"` 26 | ArrayString []string `spanner:"ArrayString"` 27 | ArrayBytes [][]byte `spanner:"ArrayBytes"` 28 | ArrayInt64 []int64 `spanner:"ArrayInt64"` 29 | ArrayFloat64 []float64 `spanner:"ArrayFloat64"` 30 | ArrayNumeric []big.Rat `spanner:"ArrayNumeric"` 31 | ArrayBool []bool `spanner:"ArrayBool"` 32 | ArrayDate []civil.Date `spanner:"ArrayDate"` 33 | ArrayTimestamp []time.Time `spanner:"ArrayTimestamp"` 34 | } 35 | -------------------------------------------------------------------------------- /dml_delete_test.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | var testDMLRepository = NewDML("Test") 9 | // TestDML_buildDeleteStmt checks the single-record DELETE statement and its w_-prefixed parameter. 10 | func TestDML_buildDeleteStmt(t *testing.T) { 11 | stmt := testDMLRepository.buildDeleteStmt(testRecord1) 12 | assert.Equal(t, "DELETE FROM
`Test` WHERE `String`=@w_String AND `Int64`=@w_Int64", stmt.SQL) 13 | assert.Equal(t, testRecord1.String, stmt.Params["w_String"].(string)) 14 | } 15 | // TestDML_buildDeleteAllStmt checks the OR-of-AND statement and the w_-prefixed, index-suffixed params for a slice of structs. 16 | func TestDML_buildDeleteAllStmt(t *testing.T) { 17 | stmt := testDMLRepository.buildDeleteAllStmt(&([]Test{*testRecord1, *testRecord2})) 18 | assert.Equal(t, "DELETE FROM `Test` WHERE (`String`=@w_String_0 AND `Int64`=@w_Int64_0) OR (`String`=@w_String_1 AND `Int64`=@w_Int64_1)", stmt.SQL) 19 | assert.Equal(t, testRecord1.String, stmt.Params["w_String_0"].(string)) 20 | assert.Equal(t, testRecord1.Int64, stmt.Params["w_Int64_0"].(int64)) 21 | assert.Equal(t, testRecord2.String, stmt.Params["w_String_1"].(string)) 22 | assert.Equal(t, testRecord2.Int64, stmt.Params["w_Int64_1"].(int64)) 23 | } 24 | // TestDML_buildDeleteAllStmtPointer checks the same output when the slice holds struct pointers. 25 | func TestDML_buildDeleteAllStmtPointer(t *testing.T) { 26 | stmt := testDMLRepository.buildDeleteAllStmt(&([]*Test{testRecord1, testRecord2})) 27 | assert.Equal(t, "DELETE FROM `Test` WHERE (`String`=@w_String_0 AND `Int64`=@w_Int64_0) OR (`String`=@w_String_1 AND `Int64`=@w_Int64_1)", stmt.SQL) 28 | assert.Equal(t, testRecord1.String, stmt.Params["w_String_0"].(string)) 29 | assert.Equal(t, testRecord1.Int64, stmt.Params["w_Int64_0"].(int64)) 30 | assert.Equal(t, testRecord2.String, stmt.Params["w_String_1"].(string)) 31 | assert.Equal(t, testRecord2.Int64, stmt.Params["w_Int64_1"].(int64)) 32 | } 33 | -------------------------------------------------------------------------------- /handlers/build/build.go: -------------------------------------------------------------------------------- 1 | package build 2 | 3 | import ( 4 | "context" 5 | "github.com/pkg/errors" 6 | "github.com/urfave/cli/v2" 7 | "os" 8 | "path/filepath" 9 | "strings" 10 | ) 11 | 12 | const ( 13 | FlagNameProjectId = "p" 14 | FlagNameInstanceName = "i" 15 | FlagNameDatabaseName = "d" 16 | FlagNameOut = "o" 17 | FlagNamePackageName = "n" 18 | ) 19 | 20 | const ( 21 | // defaultPackageName is used when the package-name flag is empty. 22 | defaultPackageName = "entity" 23 | // generatedFileMode is the permission os.Create would use (before umask). 24 | generatedFileMode = 0o666 25 | ) 26 | 27 | // Run is the action of the "build" command. It verifies that the output flag names an 28 | // existing directory, generates code for the tables of the given database and writes 29 | // each table's code to its own file in that directory. 30 | func Run(c *cli.Context) error { 31 | out := c.String(FlagNameOut) 32 | info, err := os.Stat(out) 33 | if errors.Is(err, os.ErrNotExist) { 34 | return errors.Errorf("%s doesn't exist", out) 35 | } 36 | if err != nil { 37 | // e.g. permission denied: report it now instead of failing later on file creation. 38 | return errors.WithStack(err) 39 | } 40 | if !info.IsDir() { 41 | return errors.Errorf("%s is not a directory", out) 42 | } 43 | 44 | packageName := c.String(FlagNamePackageName) 45 | if packageName == "" { 46 | packageName = defaultPackageName 47 | } 48 | 49 | codes, err := generateCode( 50 | c.Context, 51 | c.String(FlagNameProjectId), 52 | c.String(FlagNameInstanceName), 53 | c.String(FlagNameDatabaseName), 54 | packageName, 55 | ) 56 | if err != nil { 57 | return err 58 | } 59 | 60 | for tableName, code := range codes { 61 | if err := writeFile(out, tableName, code); err != nil { 62 | return err 63 | } 64 | } 65 | return nil 66 | } 67 | 68 | // generateCode fetches the column metadata of the database and renders Go source for it, keyed by table name. 69 | func generateCode(ctx context.Context, projectId, instanceName, dbName, packageName string) (map[string][]byte, error) { 70 | columns, err := fetchColumns(ctx, projectId, instanceName, dbName) 71 | if err != nil { 72 | return nil, err 73 | } 74 | return generate(packageName, columns) 75 | } 76 | 77 | // writeFile writes code to the file named after the lowercased tableName plus ".go" inside dirName, 78 | // truncating any existing file. Errors from creating, writing and closing the file are all returned. 79 | func writeFile(dirName, tableName string, code []byte) error { 80 | path := filepath.Join(dirName, strings.ToLower(tableName)+".go") 81 | if err := os.WriteFile(path, code, generatedFileMode); err != nil { 82 | return errors.Wrapf(err, "failed to write %s", path) 83 | } 84 | return nil 85 | } 86 |
-------------------------------------------------------------------------------- /dml.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sort" 7 | "strings" 8 | ) 9 | 10 | const dmlLogTemplate = "executing dml... sql:%s, params:%s" 11 | 12 | // DML offers ORM with DML. 13 | // It also contains read operations (call Reader method.) 14 | type DML struct { 15 | table string 16 | logger logger 17 | logEnabled bool 18 | } 19 | 20 | // Options is for specifying the options for spnr.Mutation and spnr.DML. 21 | // LogEnabled turns on logging of executed operations; Logger receives that output 22 | // (the *WithOptions constructors fall back to the standard library log package when it is nil). 23 | type Options struct { 24 | Logger logger 25 | LogEnabled bool 26 | } 27 | 28 | // NewDML initializes ORM with DML. 29 | // It also contains read operations (call Reader method of DML.) 30 | // If you want to use Mutation API, use New() or NewMutation() instead. 31 | func NewDML(tableName string) *DML { 32 | return &DML{table: tableName} 33 | } 34 | 35 | // NewDMLWithOptions initializes DML with options. 36 | // Check Options for the available options. A nil op is treated like an empty Options. 37 | func NewDMLWithOptions(tableName string, op *Options) *DML { 38 | if op == nil { 39 | op = &Options{} 40 | } 41 | dml := &DML{table: tableName, logger: op.Logger, logEnabled: op.LogEnabled} 42 | if dml.logger == nil { 43 | dml.logger = newDefaultLogger() 44 | } 45 | return dml 46 | } 47 | 48 | // Reader returns Reader struct to call read operations. 49 | func (d *DML) Reader(ctx context.Context, tx Transaction) *Reader { 50 | return &Reader{table: d.table, ctx: ctx, tx: tx, logger: d.logger, logEnabled: d.logEnabled} 51 | } 52 | 53 | // GetTableName returns table name 54 | func (d *DML) GetTableName() string { 55 | return d.table 56 | } 57 | 58 | // getTableName returns the table name passed through quote, for embedding in generated SQL. 59 | func (d *DML) getTableName() string { 60 | return quote(d.table) 61 | } 62 | 63 | // log writes sql and its parameters through d.logger when logging is enabled. 64 | // Parameters are rendered as comma-separated name=value pairs sorted by name so the 65 | // output is deterministic; an empty params map yields an empty list. 66 | func (d *DML) log(sql string, params map[string]any) { 67 | if !d.logEnabled { 68 | return 69 | } 70 | names := make([]string, 0, len(params)) 71 | for name := range params { 72 | names = append(names, name) 73 | } 74 | sort.Strings(names) 75 | pairs := make([]string, 0, len(names)) 76 | for _, name := range names { 77 | pairs = append(pairs, fmt.Sprintf("%s=%+v", name, params[name])) 78 | } 79 | d.logger.Printf(dmlLogTemplate, sql, strings.Join(pairs, ",")) 80 | } 81 |
-------------------------------------------------------------------------------- /reader.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "reflect" 7 | 8 | "cloud.google.com/go/spanner" 9 | "github.com/googleapis/gax-go/v2/apierror" 10 | "github.com/pkg/errors" 11 | "google.golang.org/grpc/codes" 12 | ) 13 | 14 | var ( 15 | // ErrNotFound is returned when a read operation cannot find any records unexpectedly.
16 | ErrNotFound = errors.New("record not found") 17 | // ErrMoreThanOneRecordFound is returned when a read operation found multiple records unexpectedly. 18 | ErrMoreThanOneRecordFound = errors.New("more than one record found") 19 | ) 20 | 21 | const readLogTemplate = "executing read... %s, %+v" 22 | 23 | // Transaction is the interface for spanner.ReadOnlyTransaction and spanner.ReadWriteTransaction 24 | type Transaction interface { 25 | Read(ctx context.Context, table string, keys spanner.KeySet, columns []string) *spanner.RowIterator 26 | ReadRow(ctx context.Context, table string, key spanner.Key, columns []string) (*spanner.Row, error) 27 | Query(ctx context.Context, statement spanner.Statement) *spanner.RowIterator 28 | } 29 | 30 | // Reader executes read operations. 31 | type Reader struct { 32 | table string 33 | ctx context.Context 34 | tx Transaction 35 | logger logger 36 | logEnabled bool 37 | } 38 | // logf logs through r.logger when logging is enabled, falling back to the standard library log package when no logger is set. 39 | func (r *Reader) logf(format string, v ...any) { 40 | if !r.logEnabled { 41 | return 42 | } 43 | if r.logger != nil { 44 | r.logger.Printf(format, v...) 45 | } else { 46 | log.Printf(format, v...)
47 | } 48 | } 49 | 50 | // toColumnNames returns one column name per field of the struct type val, in declaration order: 51 | // the `spanner` tag when it is set, otherwise the Go field name. Fields tagged `spanner:"-"` are 52 | // skipped, mirroring the Spanner Go client's convention for ignored fields. 53 | func toColumnNames(val reflect.Type) []string { 54 | var columns []string 55 | for i := 0; i < val.NumField(); i++ { 56 | f := val.Field(i) 57 | name := f.Tag.Get(tagColumnName) 58 | switch name { 59 | case "-": 60 | continue 61 | case "": 62 | name = f.Name 63 | } 64 | columns = append(columns, name) 65 | } 66 | return columns 67 | } 68 | 69 | // isNotFound reports whether err wraps an *apierror.APIError whose gRPC status code is NotFound. 70 | func isNotFound(err error) bool { 71 | var apiErr *apierror.APIError 72 | return errors.As(err, &apiErr) && 73 | apiErr.GRPCStatus().Code() == codes.NotFound 74 | } 75 | -------------------------------------------------------------------------------- /mutation_delete_test.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "cloud.google.com/go/spanner" 5 | "context" 6 | "github.com/stretchr/testify/assert" 7 | "testing" 8 | ) 9 | // TestMutation_Delete checks that a record written with ApplyInsertOrUpdate can no longer be found after ApplyDelete. 10 | func TestMutation_Delete(t *testing.T) { 11 | ctx := context.Background() 12 | _, err := testRepository.ApplyInsertOrUpdate(ctx, dataClient, testRecord3) 13 | assert.Nil(t, err) 14 | _, err = testRepository.ApplyDelete(ctx, dataClient, testRecord3) 15 | assert.Nil(t, err) 16 | var fetched Test 17 | err = testRepository.Reader(ctx, dataClient.Single()).FindOne(spanner.Key{testRecord3.String, testRecord3.Int64}, &fetched) 18 | assert.Equal(t, ErrNotFound, err) 19 | } 20 | // TestMutation_DeleteWithSlice deletes two records passed as a pointer to a slice of structs. 21 | func TestMutation_DeleteWithSlice(t *testing.T) { 22 | ctx := context.Background() 23 | _, err := testRepository.ApplyInsertOrUpdate(ctx, dataClient, &([]*Test{testRecord3, testRecord4})) 24 | assert.Nil(t, err) 25 | 26 | _, err = testRepository.ApplyDelete(ctx, dataClient, &([]Test{*testRecord3, *testRecord4})) 27 | assert.Nil(t, err) 28 | 29 | var fetched []Test 30 | keySet := spanner.KeySetFromKeys(spanner.Key{testRecord3.String, testRecord3.Int64}, spanner.Key{testRecord4.String, testRecord4.Int64}) 31 | _ = testRepository.Reader(ctx, dataClient.Single()).FindAll(keySet, &fetched) 32 | assert.Empty(t, fetched) 33 | } 34 | // TestMutation_DeleteWithSlicePointer deletes two records passed as a pointer to a slice of struct pointers. 35 | func TestMutation_DeleteWithSlicePointer(t *testing.T) { 36 | ctx := context.Background() 37 | _, err :=
testRepository.ApplyInsertOrUpdate(ctx, dataClient, &([]*Test{testRecord3, testRecord4})) 38 | assert.Nil(t, err) 39 | 40 | _, err = testRepository.ApplyDelete(ctx, dataClient, &([]*Test{testRecord3, testRecord4})) 41 | assert.Nil(t, err) 42 | 43 | var fetched []Test 44 | keySet := spanner.KeySetFromKeys(spanner.Key{testRecord3.String, testRecord3.Int64}, spanner.Key{testRecord4.String, testRecord4.Int64}) 45 | _ = testRepository.Reader(ctx, dataClient.Single()).FindAll(keySet, &fetched) 46 | assert.Empty(t, fetched) 47 | } 48 | -------------------------------------------------------------------------------- /.github/workflows/review.yml: -------------------------------------------------------------------------------- 1 | name: review 2 | 3 | on: 4 | pull_request: 5 | branches: [ main ] 6 | 7 | jobs: 8 | setup: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Install Dependency 12 | run: sudo apt update && sudo apt install -y gcc 13 | build: 14 | needs: setup 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: Set up Go 1.19.1 18 | uses: actions/setup-go@v3 19 | with: 20 | go-version: 1.19.1 21 | - uses: actions/checkout@v3 22 | with: 23 | ref: ${{ github.event.pull_request.head.ref }} 24 | - name: go mod download 25 | run: go mod download 26 | - name: test 27 | run: go test ./... 28 | formatting: 29 | needs: build 30 | runs-on: ubuntu-latest 31 | steps: 32 | - name: Set up Go 1.19.1 33 | uses: actions/setup-go@v3 34 | with: 35 | go-version: 1.19.1 36 | - uses: actions/checkout@v3 37 | with: 38 | ref: ${{ github.event.pull_request.head.ref }} # branch the PR was opened from 39 | - name: go fmt 40 | run: go fmt ./...
41 | - name: commit & push when file is changed 42 | run: | 43 | git config user.name github-actions 44 | git config user.email github-actions@github.com 45 | git add -A 46 | git commit -m "Formatting by github actions" && git push origin HEAD:${{ github.event.pull_request.head.ref }} || true 47 | glangci_lint: 48 | needs: formatting 49 | runs-on: ubuntu-latest 50 | steps: 51 | - uses: actions/checkout@v3 52 | with: 53 | ref: ${{ github.event.pull_request.head.ref }} 54 | - name: Code Review by golangci-lint 55 | uses: reviewdog/action-golangci-lint@v2 56 | with: 57 | golangci_lint_flags: "--config=.golangci.yml" 58 | filter_mode: nofilter 59 | cache: false 60 | level: warning 61 | reporter: github-pr-review 62 | -------------------------------------------------------------------------------- /mutation_delete.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "cloud.google.com/go/spanner" 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | // Delete builds and executes a delete operation using the mutation API. 12 | // You can pass either a struct or a slice of structs. 13 | // If you pass a slice of structs, this method will build a mutation for each struct. 14 | // This method requires spanner.ReadWriteTransaction, and will call spanner.ReadWriteTransaction.BufferWrite to save the mutation to transaction. 15 | func (m *Mutation) Delete(tx *spanner.ReadWriteTransaction, target any) error { 16 | isStruct, err := validateStructOrStructSliceType(target) 17 | if err != nil { 18 | return err 19 | } 20 | if isStruct { 21 | return errors.WithStack(tx.BufferWrite(m.buildDelete([]any{target}))) 22 | } 23 | return errors.WithStack(tx.BufferWrite(m.buildDelete(toStructSlice(target)))) 24 | } 25 | 26 | // ApplyDelete is basically the same as Delete, but it doesn't require a transaction. 27 | // This method directly calls mutation API without transaction by calling spanner.Client.Apply method.
28 | func (m *Mutation) ApplyDelete(ctx context.Context, client *spanner.Client, target any) (time.Time, error) { 29 | isStruct, err := validateStructOrStructSliceType(target) 30 | if err != nil { 31 | return time.Time{}, err 32 | } 33 | if isStruct { 34 | t, err := client.Apply(ctx, m.buildDelete([]any{target})) 35 | return t, errors.WithStack(err) 36 | } 37 | t, err := client.Apply(ctx, m.buildDelete(toStructSlice(target))) 38 | return t, errors.WithStack(err) 39 | } 40 | // buildDelete returns one spanner.Delete mutation per target, keyed by the values of the target's primary-key fields (as returned by extractPks), and logs each key through m.logf. 41 | func (m *Mutation) buildDelete(targets []any) []*spanner.Mutation { 42 | ms := make([]*spanner.Mutation, 0, len(targets)) 43 | for _, target := range targets { 44 | var pks spanner.Key 45 | for _, pk := range extractPks(toFields(target)) { 46 | pks = append(pks, pk.value) 47 | } 48 | ms = append(ms, spanner.Delete(m.table, pks)) 49 | m.logf("Deleting from %s, key=%+v", m.table, pks) 50 | } 51 | return ms 52 | } 53 | -------------------------------------------------------------------------------- /dml_delete.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "reflect" 7 | "strings" 8 | 9 | "cloud.google.com/go/spanner" 10 | "github.com/pkg/errors" 11 | ) 12 | 13 | // Delete builds and executes a delete statement from the passed struct. 14 | // You can pass either a struct or a slice of structs to target. 15 | // If you pass a slice of structs, this method will build a single statement which deletes multiple records, like the following.
16 | // DELETE FROM `T` WHERE (`COL1` = 'a' AND `COL2` = 'b') OR (`COL1` = 'c' AND `COL2` = 'd'); 17 | // 18 | // Passing an empty slice is a no-op: it returns 0 and a nil error without sending any statement. 19 | func (d *DML) Delete(ctx context.Context, tx *spanner.ReadWriteTransaction, target any) (rowCount int64, err error) { 20 | isStruct, err := validateStructOrStructSliceType(target) 21 | if err != nil { 22 | return 0, err 23 | } 24 | var stmt *spanner.Statement 25 | if isStruct { 26 | stmt = d.buildDeleteStmt(target) 27 | } else { 28 | // An empty slice would render "DELETE FROM ... WHERE " with no condition, which is invalid SQL. 29 | if reflect.ValueOf(target).Elem().Len() == 0 { 30 | return 0, nil 31 | } 32 | stmt = d.buildDeleteAllStmt(target) 33 | } 34 | rowCount, err = tx.Update(ctx, *stmt) 35 | return rowCount, errors.WithStack(err) 36 | } 37 | 38 | // buildDeleteStmt builds a DELETE statement whose WHERE clause matches the primary-key fields of the single struct target. 39 | func (d *DML) buildDeleteStmt(target any) *spanner.Statement { 40 | fields := toFields(target) 41 | whereClause, params := buildWherePK(fields) 42 | sql := fmt.Sprintf("DELETE FROM %s WHERE %s", 43 | d.getTableName(), 44 | whereClause, 45 | ) 46 | d.log(sql, params) 47 | return &spanner.Statement{ 48 | SQL: sql, 49 | Params: params, 50 | } 51 | } 52 | 53 | // buildDeleteAllStmt builds one DELETE statement covering every struct in the slice target points to: 54 | // each element contributes a parenthesized group of pk conditions joined by AND, and the groups are joined by OR. 55 | func (d *DML) buildDeleteAllStmt(target any) *spanner.Statement { 56 | var valuesList []string 57 | params := map[string]any{} 58 | 59 | slice := reflect.ValueOf(target).Elem() 60 | for i := 0; i < slice.Len(); i++ { 61 | var values []string 62 | for _, field := range extractPks(structValToFields(slice.Index(i))) { 63 | param := addW(addIdx(field.name, i)) 64 | values = append(values, quote(field.name)+"="+addPlaceHolder(param)) 65 | params[param] = field.value 66 | } 67 | valuesList = append(valuesList, fmt.Sprintf("(%s)", strings.Join(values, " AND "))) 68 | } 69 | 70 | sql := fmt.Sprintf("DELETE FROM %s WHERE %s", 71 | d.getTableName(), 72 | strings.Join(valuesList, " OR "), 73 | ) 74 | 75 | d.log(sql, params) 76 | return &spanner.Statement{ 77 | SQL: sql, 78 | Params: params, 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /functions.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 |
"math/big" 5 | "reflect" 6 | "strings" 7 | "time" 8 | 9 | "cloud.google.com/go/civil" 10 | "cloud.google.com/go/spanner" 11 | ) 12 | 13 | // NewNullString initializes spanner.NullString setting Valid as true 14 | func NewNullString(str string) spanner.NullString { 15 | return spanner.NullString{ 16 | StringVal: str, 17 | Valid: true, 18 | } 19 | } 20 | 21 | // NewNullBool initializes spanner.NullBool setting Valid as true 22 | func NewNullBool(b bool) spanner.NullBool { 23 | return spanner.NullBool{ 24 | Bool: b, 25 | Valid: true, 26 | } 27 | } 28 | 29 | // NewNullInt64 initializes spanner.NullInt64 setting Valid as true 30 | func NewNullInt64(val int64) spanner.NullInt64 { 31 | return spanner.NullInt64{ 32 | Int64: val, 33 | Valid: true, 34 | } 35 | } 36 | 37 | // NewNullNumeric initializes spanner.NullNumeric to the fraction a/b, setting Valid as true. Like big.NewRat, it panics if b is 0. 38 | func NewNullNumeric(a, b int64) spanner.NullNumeric { 39 | return spanner.NullNumeric{ 40 | Numeric: *big.NewRat(a, b), 41 | Valid: true, 42 | } 43 | } 44 | 45 | // NewNullDate initializes spanner.NullDate setting Valid as true 46 | func NewNullDate(d civil.Date) spanner.NullDate { 47 | return spanner.NullDate{ 48 | Date: d, 49 | Valid: true, 50 | } 51 | } 52 | 53 | // NewNullTime initializes spanner.NullTime setting Valid as true 54 | func NewNullTime(t time.Time) spanner.NullTime { 55 | return spanner.NullTime{ 56 | Time: t, 57 | Valid: true, 58 | } 59 | } 60 | 61 | // ToKeySets converts a slice (or a pointer to a slice) into a spanner.KeySet with one single-column key per element. 62 | func ToKeySets(target any) spanner.KeySet { 63 | var keys []spanner.Key 64 | slice := reflect.ValueOf(target) 65 | if slice.Kind() == reflect.Ptr { 66 | slice = slice.Elem() 67 | } 68 | for i := 0; i < slice.Len(); i++ { 69 | keys = append(keys, spanner.Key{slice.Index(i).Interface()}) 70 | } 71 | return spanner.KeySetFromKeys(keys...) 72 | } 73 | 74 | // ToAllColumnNames receives a pointer to a struct and returns the column names of its `spanner`-tagged fields, joined by ", ". 75 | // This is useful when you build a query that selects all of the mapped columns.
76 | // Instead of use *(wildcard), you can specify all of the columns using this method. 77 | // Then you can avoid the risk that failing to map record to struct caused by the mismatch of an order of columns in spanner table and fields in struct. 78 | func ToAllColumnNames(target any) string { 79 | var columnNames []string 80 | for _, f := range structValToFields(reflect.ValueOf(target).Elem()) { 81 | columnNames = append(columnNames, f.name) 82 | } 83 | return strings.Join(columnNames, ", ") 84 | } 85 | -------------------------------------------------------------------------------- /dml_insert.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "reflect" 7 | "strings" 8 | 9 | "cloud.google.com/go/spanner" 10 | "github.com/pkg/errors" 11 | ) 12 | 13 | // Insert build and execute insert statement from the passed struct. 14 | // You can pass either a struct or a slice of struct to target. 15 | // If you pass a slice of struct, this method will build a statement which insert multiple records in one statement like the following 16 | // INSERT INTO `TableName` (`Column1`, `Column2`) VALUES ('a', 'b'), ('c', 'd'), ...; 17 | func (d *DML) Insert(ctx context.Context, tx *spanner.ReadWriteTransaction, target any) (rowCount int64, err error) { 18 | isStruct, err := validateStructOrStructSliceType(target) 19 | if err != nil { 20 | return 0, err 21 | } 22 | if isStruct { 23 | rowCount, err := tx.Update(ctx, *d.buildInsertStmt(target)) 24 | return rowCount, errors.WithStack(err) 25 | } else { 26 | rowCount, err := tx.Update(ctx, *d.buildInsertAllStmt(target)) 27 | return rowCount, errors.WithStack(err) 28 | } 29 | } 30 | 31 | func (d *DML) buildInsertStmt(target any) *spanner.Statement { 32 | var columns []string 33 | var values []string 34 | params := map[string]any{} 35 | for _, field := range toFields(target) { 36 | columns = append(columns, quote(field.name)) 37 | values = 
append(values, addPlaceHolder(field.name)) 38 | params[field.name] = field.value 39 | } 40 | 41 | sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", 42 | d.getTableName(), 43 | strings.Join(columns, ", "), 44 | strings.Join(values, ", "), 45 | ) 46 | 47 | d.log(sql, params) 48 | return &spanner.Statement{ 49 | SQL: sql, 50 | Params: params, 51 | } 52 | } 53 | 54 | func (d *DML) buildInsertAllStmt(target any) *spanner.Statement { 55 | var columns []string 56 | var valuesList []string 57 | params := map[string]any{} 58 | 59 | slice := reflect.ValueOf(target).Elem() 60 | for i := 0; i < slice.Len(); i++ { 61 | var values []string 62 | for _, field := range structValToFields(slice.Index(i)) { 63 | if i == 0 { 64 | columns = append(columns, quote(field.name)) 65 | } 66 | param := addIdx(field.name, i) 67 | values = append(values, addPlaceHolder(param)) 68 | params[param] = field.value 69 | } 70 | valuesList = append(valuesList, "("+strings.Join(values, ", ")+")") 71 | } 72 | 73 | sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", 74 | d.getTableName(), 75 | strings.Join(columns, ", "), 76 | strings.Join(valuesList, ", "), 77 | ) 78 | 79 | d.log(sql, params) 80 | return &spanner.Statement{ 81 | SQL: sql, 82 | Params: params, 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/kanjih/go-spnr/v2 2 | 3 | go 1.19 4 | 5 | require ( 6 | cloud.google.com/go v0.108.0 7 | cloud.google.com/go/spanner v1.42.0 8 | github.com/googleapis/gax-go/v2 v2.7.0 9 | github.com/iancoleman/strcase v0.2.0 10 | github.com/kanjih/go-spnr v0.1.1 11 | github.com/pkg/errors v0.9.1 12 | github.com/stretchr/testify v1.8.1 13 | github.com/testcontainers/testcontainers-go v0.14.0 14 | github.com/urfave/cli/v2 v2.23.7 15 | google.golang.org/api v0.108.0 16 | google.golang.org/genproto v0.0.0-20230119192704-9d59e20e5cd1 17 | 
google.golang.org/grpc v1.52.3 18 | gotest.tools v2.2.0+incompatible 19 | ) 20 | 21 | require ( 22 | cloud.google.com/go/compute v1.14.0 // indirect 23 | cloud.google.com/go/compute/metadata v0.2.3 // indirect 24 | cloud.google.com/go/iam v0.8.0 // indirect 25 | cloud.google.com/go/longrunning v0.3.0 // indirect 26 | github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect 27 | github.com/Microsoft/go-winio v0.5.2 // indirect 28 | github.com/Microsoft/hcsshim v0.9.4 // indirect 29 | github.com/cenkalti/backoff/v4 v4.1.3 // indirect 30 | github.com/census-instrumentation/opencensus-proto v0.3.0 // indirect 31 | github.com/cespare/xxhash/v2 v2.1.2 // indirect 32 | github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4 // indirect 33 | github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1 // indirect 34 | github.com/containerd/cgroups v1.0.4 // indirect 35 | github.com/containerd/containerd v1.6.8 // indirect 36 | github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect 37 | github.com/davecgh/go-spew v1.1.1 // indirect 38 | github.com/docker/distribution v2.8.1+incompatible // indirect 39 | github.com/docker/docker v20.10.17+incompatible // indirect 40 | github.com/docker/go-connections v0.4.0 // indirect 41 | github.com/docker/go-units v0.5.0 // indirect 42 | github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1 // indirect 43 | github.com/envoyproxy/protoc-gen-validate v0.1.0 // indirect 44 | github.com/gogo/protobuf v1.3.2 // indirect 45 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 46 | github.com/golang/protobuf v1.5.2 // indirect 47 | github.com/google/go-cmp v0.5.9 // indirect 48 | github.com/google/uuid v1.3.0 // indirect 49 | github.com/googleapis/enterprise-certificate-proxy v0.2.1 // indirect 50 | github.com/magiconair/properties v1.8.6 // indirect 51 | github.com/moby/sys/mount v0.3.3 // indirect 52 | github.com/moby/sys/mountinfo v0.6.2 // indirect 53 | github.com/moby/term 
v0.0.0-20210619224110-3f7ff695adc6 // indirect 54 | github.com/morikuni/aec v1.0.0 // indirect 55 | github.com/opencontainers/go-digest v1.0.0 // indirect 56 | github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 // indirect 57 | github.com/opencontainers/runc v1.1.3 // indirect 58 | github.com/pmezard/go-difflib v1.0.0 // indirect 59 | github.com/russross/blackfriday/v2 v2.1.0 // indirect 60 | github.com/sirupsen/logrus v1.8.1 // indirect 61 | github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect 62 | go.opencensus.io v0.24.0 // indirect 63 | golang.org/x/net v0.4.0 // indirect 64 | golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect 65 | golang.org/x/sys v0.3.0 // indirect 66 | golang.org/x/text v0.5.0 // indirect 67 | golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect 68 | google.golang.org/appengine v1.6.7 // indirect 69 | google.golang.org/protobuf v1.28.1 // indirect 70 | gopkg.in/yaml.v3 v3.0.1 // indirect 71 | ) 72 | -------------------------------------------------------------------------------- /reader_key.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "reflect" 5 | 6 | "cloud.google.com/go/spanner" 7 | "github.com/pkg/errors" 8 | "google.golang.org/api/iterator" 9 | ) 10 | 11 | // FindOne fetches a record by specified primary key, and map the record into the passed pointer of struct. 
12 | func (r *Reader) FindOne(key spanner.Key, target any) error { 13 | if err := validateStructType(target); err != nil { 14 | return err 15 | } 16 | r.logf(readLogTemplate, "table:"+r.table, key) 17 | 18 | row, err := r.tx.ReadRow(r.ctx, r.table, key, toColumnNames(reflect.ValueOf(target).Elem().Type())) 19 | if err != nil { 20 | if isNotFound(err) { 21 | return ErrNotFound 22 | } 23 | return errors.WithStack(err) 24 | } 25 | return row.ToStruct(target) 26 | } 27 | 28 | // FindAll fetches records by specified a set of primary keys, and map the records into the passed pointer of slice of structs. 29 | func (r *Reader) FindAll(keys spanner.KeySet, target any) error { 30 | if err := validateStructSliceType(target); err != nil { 31 | return err 32 | } 33 | if r.logEnabled { 34 | r.logger.Printf(readLogTemplate, "table:"+r.table, keys) 35 | } 36 | slice := reflect.ValueOf(target).Elem() 37 | innerType := slice.Type().Elem() 38 | 39 | rows := r.tx.Read(r.ctx, r.table, keys, toColumnNames(innerType)) 40 | defer rows.Stop() 41 | for { 42 | row, err := rows.Next() 43 | if errors.Is(err, iterator.Done) { 44 | break 45 | } 46 | if err != nil { 47 | return errors.WithStack(err) 48 | } 49 | e := reflect.New(innerType).Elem() 50 | if err := row.ToStruct(e.Addr().Interface()); err != nil { 51 | return errors.WithStack(err) 52 | } 53 | slice.Set(reflect.Append(slice, e)) 54 | } 55 | 56 | return nil 57 | } 58 | 59 | /* 60 | GetColumn fetches the specified column by specified primary key, and map the column into the passed pointer of value. 61 | 62 | Caution: 63 | 64 | It maps fetched column to the passed pointer by just calling spanner.Row.Columns method. 65 | So the type of passed value to map should be compatible to this method. 66 | For example if you fetch an INT64 column from spanner, you need to map this value to int64, not int. 
67 | */ 68 | func (r *Reader) GetColumn(key spanner.Key, column string, target any) error { 69 | r.logf(readLogTemplate, "table:"+r.table, key) 70 | row, err := r.tx.ReadRow(r.ctx, r.table, key, []string{column}) 71 | if err != nil { 72 | if isNotFound(err) { 73 | return ErrNotFound 74 | } 75 | return errors.WithStack(err) 76 | } 77 | return errors.WithStack(row.Columns(target)) 78 | } 79 | 80 | // GetColumn fetches the specified column for the records that matches specified set of primary keys, 81 | // and map the column into the passed pointer of a slice of values. 82 | // Please see the caution commented in GetColumn to check type compatibility. 83 | func (r *Reader) GetColumnAll(keys spanner.KeySet, column string, target any) error { 84 | if err := validateSliceType(target); err != nil { 85 | return err 86 | } 87 | r.logf(readLogTemplate, "table:"+r.table, keys) 88 | slice := reflect.ValueOf(target).Elem() 89 | innerType := slice.Type().Elem() 90 | 91 | rows := r.tx.Read(r.ctx, r.table, keys, []string{column}) 92 | defer rows.Stop() 93 | for { 94 | row, err := rows.Next() 95 | if errors.Is(err, iterator.Done) { 96 | break 97 | } 98 | if err != nil { 99 | return errors.WithStack(err) 100 | } 101 | e := reflect.New(innerType).Elem() 102 | if err := row.Columns(e.Addr().Interface()); err != nil { 103 | return errors.WithStack(err) 104 | } 105 | slice.Set(reflect.Append(slice, e)) 106 | } 107 | 108 | return nil 109 | } 110 | -------------------------------------------------------------------------------- /dml_update.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "reflect" 7 | "strings" 8 | 9 | "cloud.google.com/go/spanner" 10 | "github.com/pkg/errors" 11 | ) 12 | 13 | // Update build and execute update statement from the passed struct. 14 | // You can pass either a struct or slice of struct to target. 
15 | // If you pass a slice of struct, this method will call update statement in for loop. 16 | func (d *DML) Update(ctx context.Context, tx *spanner.ReadWriteTransaction, target any) (rowCount int64, err error) { 17 | isStruct, err := validateStructOrStructSliceType(target) 18 | if err != nil { 19 | return 0, err 20 | } 21 | if isStruct { 22 | rowCount, err := tx.Update(ctx, *d.buildUpdateStmt(target, nil)) 23 | return rowCount, errors.WithStack(err) 24 | } else { 25 | rowCount, err := d.updateAll(ctx, tx, target) 26 | return rowCount, errors.WithStack(err) 27 | } 28 | } 29 | 30 | func (d *DML) updateAll(ctx context.Context, tx *spanner.ReadWriteTransaction, target any) (rowCount int64, err error) { 31 | slice := reflect.ValueOf(target).Elem() 32 | for i := 0; i < slice.Len(); i++ { 33 | cnt, err := tx.Update(ctx, *d.buildUpdateStmt(slice.Index(i).Addr().Interface(), nil)) 34 | if err != nil { 35 | return 0, err 36 | } 37 | rowCount += cnt 38 | } 39 | return rowCount, nil 40 | } 41 | 42 | // UpdateColumns build and execute update statement from the passed column names and struct. 43 | // You can specify the columns to update. 44 | // Also, you can pass either a struct or slice of struct to target. 45 | // If you pass a slice of struct, this method will call update statement in for loop. 
46 | func (d *DML) UpdateColumns(ctx context.Context, tx *spanner.ReadWriteTransaction, columns []string, target any) (rowCount int64, err error) { 47 | isStruct, err := validateStructOrStructSliceType(target) 48 | if err != nil { 49 | return 0, err 50 | } 51 | if isStruct { 52 | rowCount, err := tx.Update(ctx, *d.buildUpdateStmt(target, columns)) 53 | return rowCount, errors.WithStack(err) 54 | } else { 55 | rowCount, err := d.updateAll(ctx, tx, target) 56 | return rowCount, errors.WithStack(err) 57 | } 58 | } 59 | 60 | func (d *DML) buildUpdateStmt(target any, columns []string) *spanner.Statement { 61 | fields := toFields(target) 62 | var setClause string 63 | var params map[string]any 64 | if columns != nil { 65 | setClause, params = buildSetClauseWithColumns(fields, columns) 66 | } else { 67 | setClause, params = buildSetClause(fields) 68 | } 69 | whereClause, whereParams := buildWherePK(fields) 70 | for k, v := range whereParams { 71 | params[k] = v 72 | } 73 | sql := fmt.Sprintf("UPDATE %s SET %s WHERE %s", 74 | d.getTableName(), 75 | setClause, 76 | whereClause, 77 | ) 78 | d.log(sql, params) 79 | return &spanner.Statement{ 80 | SQL: sql, 81 | Params: params, 82 | } 83 | } 84 | 85 | func buildSetClause(fields []field) (string, map[string]any) { 86 | var columns []string 87 | params := map[string]any{} 88 | for _, field := range extractNotPks(fields) { 89 | columns = append(columns, quote(field.name)+"="+addPlaceHolder(field.name)) 90 | params[field.name] = field.value 91 | } 92 | return strings.Join(columns, ", "), params 93 | } 94 | 95 | func buildSetClauseWithColumns(fields []field, columns []string) (string, map[string]any) { 96 | fieldsMap := map[string]field{} 97 | for _, f := range fields { 98 | fieldsMap[f.name] = f 99 | } 100 | 101 | var setColumns []string 102 | params := map[string]any{} 103 | for _, c := range columns { 104 | f := fieldsMap[c] 105 | setColumns = append(setColumns, quote(f.name)+"="+addPlaceHolder(f.name)) 106 | params[f.name] = 
f.value 107 | } 108 | 109 | return strings.Join(setColumns, ", "), params 110 | } 111 | -------------------------------------------------------------------------------- /helpers.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "sort" 7 | "strings" 8 | 9 | "github.com/pkg/errors" 10 | ) 11 | 12 | var errNotPointer = errors.New("final argument must be passed as pointer") 13 | 14 | func extractPks(fields []field) []field { 15 | var pks []field 16 | for _, field := range fields { 17 | if field.isPk() { 18 | pks = append(pks, field) 19 | } 20 | } 21 | sort.Slice(pks, func(i, j int) bool { 22 | return fields[i].pkOrder > fields[j].pkOrder 23 | }) 24 | return pks 25 | } 26 | 27 | func extractNotPks(fields []field) []field { 28 | var notPks []field 29 | for _, field := range fields { 30 | if !field.isPk() { 31 | notPks = append(notPks, field) 32 | } 33 | } 34 | return notPks 35 | } 36 | 37 | func buildWherePK(fields []field) (string, map[string]any) { 38 | var columns []string 39 | params := map[string]any{} 40 | for _, field := range extractPks(fields) { 41 | param := addW(field.name) 42 | columns = append(columns, quote(field.name)+"="+addPlaceHolder(param)) 43 | params[param] = field.value 44 | } 45 | return strings.Join(columns, " AND "), params 46 | } 47 | 48 | func addW(str string) string { 49 | return "w_" + str 50 | } 51 | 52 | func addIdx(str string, idx int) string { 53 | return fmt.Sprintf("%s_%d", str, idx) 54 | } 55 | 56 | func addPlaceHolder(str string) string { 57 | return "@" + str 58 | } 59 | 60 | func quote(str string) string { 61 | return "`" + str + "`" 62 | } 63 | 64 | func validateStructType(target any) error { 65 | rv := reflect.ValueOf(target) 66 | if rv.Kind() != reflect.Ptr { 67 | return errNotPointer 68 | } 69 | if rv.Elem().Kind() != reflect.Struct { 70 | return errors.New("final argument must be struct but got " + rv.Elem().Kind().String()) 71 | } 
72 | return nil 73 | } 74 | 75 | func validateSliceType(target any) error { 76 | rv := reflect.ValueOf(target) 77 | if rv.Kind() != reflect.Ptr { 78 | return errNotPointer 79 | } 80 | if rv.Elem().Kind() != reflect.Slice { 81 | return errors.New("final argument must be slice but got " + rv.Elem().Kind().String()) 82 | } 83 | return nil 84 | } 85 | 86 | func validateStructSliceType(target any) error { 87 | rv := reflect.ValueOf(target) 88 | if rv.Kind() != reflect.Ptr { 89 | return errNotPointer 90 | } 91 | if rv.Elem().Kind() != reflect.Slice { 92 | return errors.New("final argument must be slice of struct but got " + rv.Elem().Kind().String()) 93 | } 94 | if rv.Elem().Type().Elem().Kind() != reflect.Struct { 95 | return errors.New("final argument must be slice of struct but got slice of " + rv.Elem().Type().Elem().Kind().String()) 96 | } 97 | return nil 98 | } 99 | 100 | func validateStructOrStructSliceType(target any) (isStruct bool, err error) { 101 | rv := reflect.ValueOf(target) 102 | if rv.Kind() != reflect.Ptr { 103 | return false, errNotPointer 104 | } 105 | switch rv.Elem().Kind() { 106 | case reflect.Struct: 107 | return true, nil 108 | case reflect.Slice: 109 | el := rv.Elem().Type().Elem() 110 | if el.Kind() == reflect.Struct { 111 | return false, nil 112 | } 113 | if el.Kind() != reflect.Ptr || el.Elem().Kind() != reflect.Struct { 114 | return false, errors.New("final argument must be slice of struct but got slice of " + rv.Elem().Type().Elem().Kind().String()) 115 | } 116 | return false, nil 117 | default: 118 | return false, errors.New("final argument must be struct or slice of struct but got " + rv.Elem().Kind().String()) 119 | } 120 | } 121 | 122 | // toStructSlice converts any to slice of struct s 123 | func toStructSlice(target any) []any { 124 | var parsed []any 125 | slice := reflect.ValueOf(target).Elem() 126 | for i := 0; i < slice.Len(); i++ { 127 | e := slice.Index(i) 128 | if e.Kind() == reflect.Struct { 129 | parsed = append(parsed, 
e.Addr().Interface()) 130 | } else { 131 | parsed = append(parsed, e.Interface()) 132 | } 133 | } 134 | return parsed 135 | } 136 | -------------------------------------------------------------------------------- /handlers/build/build_test.go: -------------------------------------------------------------------------------- 1 | package build 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "testing" 8 | 9 | database "cloud.google.com/go/spanner/admin/database/apiv1" 10 | instance "cloud.google.com/go/spanner/admin/instance/apiv1" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/testcontainers/testcontainers-go" 13 | "github.com/testcontainers/testcontainers-go/wait" 14 | databasepb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" 15 | instancepb "google.golang.org/genproto/googleapis/spanner/admin/instance/v1" 16 | ) 17 | 18 | const ( 19 | projectName = "test-project" 20 | instanceName = "test" 21 | databaseName = "test" 22 | projectID = "projects/" + projectName 23 | instanceID = projectID + "/instances/" + instanceName 24 | ) 25 | 26 | var ( 27 | insAdminClient *instance.InstanceAdminClient 28 | adminClient *database.DatabaseAdminClient 29 | ) 30 | 31 | func TestGenerateCode(t *testing.T) { 32 | codes, err := generateCode(context.Background(), projectName, instanceName, databaseName, "entity_test") 33 | assert.Nil(t, err) 34 | b, err := os.ReadFile("testdata/test1.go") 35 | assert.Nil(t, err) 36 | assert.Equal(t, string(b), string(codes["Test1"])) 37 | b, err = os.ReadFile("testdata/test2.go") 38 | assert.Nil(t, err) 39 | assert.Equal(t, string(b), string(codes["Test2"])) 40 | } 41 | 42 | func TestMain(m *testing.M) { 43 | ctx := context.Background() 44 | c, err := initSpannerContainer(ctx) 45 | if c != nil { 46 | defer c.Terminate(ctx) //nolint:errcheck 47 | } 48 | if err != nil { 49 | panic(err) 50 | } 51 | if err = initClients(ctx); err != nil { 52 | panic(err) 53 | } 54 | if err = initDatabase(ctx); err != nil { 55 | panic(err) 
56 | } 57 | os.Exit(m.Run()) 58 | } 59 | 60 | func initSpannerContainer(ctx context.Context) (testcontainers.Container, error) { 61 | req := testcontainers.ContainerRequest{ 62 | Image: "gcr.io/cloud-spanner-emulator/emulator:1.3.0", 63 | ExposedPorts: []string{"9010/tcp"}, 64 | WaitingFor: wait.ForLog("gateway.go:142: gRPC server listening at 0.0.0.0:9010"), 65 | } 66 | spannerC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ 67 | ContainerRequest: req, 68 | Started: true, 69 | }) 70 | if err != nil { 71 | return nil, err 72 | } 73 | h, err := spannerC.Host(ctx) 74 | if err != nil { 75 | return nil, err 76 | } 77 | p, err := spannerC.MappedPort(ctx, "9010") 78 | if err != nil { 79 | return nil, err 80 | } 81 | return spannerC, os.Setenv("SPANNER_EMULATOR_HOST", fmt.Sprintf("%s:%s", h, p.Port())) 82 | } 83 | 84 | func initClients(ctx context.Context) (err error) { 85 | insAdminClient, err = instance.NewInstanceAdminClient(ctx) 86 | if err != nil { 87 | return err 88 | } 89 | adminClient, err = database.NewDatabaseAdminClient(ctx) 90 | return err 91 | } 92 | 93 | func initDatabase(ctx context.Context) (err error) { 94 | createInstanceReq := &instancepb.CreateInstanceRequest{ 95 | Parent: projectID, 96 | Instance: &instancepb.Instance{ 97 | Name: instanceID, 98 | Config: projectID + "/instanceConfigs/test", 99 | DisplayName: instanceName, 100 | NodeCount: 1, 101 | }, 102 | InstanceId: instanceName, 103 | } 104 | ciOp, err := insAdminClient.CreateInstance(ctx, createInstanceReq) 105 | if err != nil { 106 | return err 107 | } 108 | if _, err = ciOp.Wait(ctx); err != nil { 109 | return err 110 | } 111 | 112 | b1, err := os.ReadFile("testdata/test1.sql") 113 | if err != nil { 114 | return err 115 | } 116 | b2, err := os.ReadFile("testdata/test2.sql") 117 | if err != nil { 118 | return err 119 | } 120 | 121 | createDatabaseReq := &databasepb.CreateDatabaseRequest{ 122 | Parent: instanceID, 123 | CreateStatement: "CREATE DATABASE " + 
databaseName, 124 | ExtraStatements: []string{string(b1), string(b2)}, 125 | } 126 | cdOp, err := adminClient.CreateDatabase(ctx, createDatabaseReq) 127 | if err != nil { 128 | return err 129 | } 130 | _, err = cdOp.Wait(ctx) 131 | return err 132 | } 133 | -------------------------------------------------------------------------------- /reader_key_test.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "cloud.google.com/go/spanner" 5 | "context" 6 | "github.com/stretchr/testify/assert" 7 | "testing" 8 | ) 9 | 10 | func TestFind(t *testing.T) { 11 | ctx := context.Background() 12 | assert.Nil(t, prepareReadTest(ctx)) 13 | 14 | var fetched []TestOrderChanged 15 | keys := spanner.KeySetFromKeys(spanner.Key{testRecord1.String, testRecord1.Int64}, spanner.Key{testRecord2.String, testRecord2.Int64}) 16 | err := testRepository.Reader(ctx, dataClient.Single()).FindAll(keys, &fetched) 17 | assert.Nil(t, err) 18 | assert.Len(t, fetched, 2) 19 | 20 | assert.Equal(t, testRecord1.String, fetched[0].String) 21 | assert.Equal(t, testRecord1.NullString, fetched[0].NullString) 22 | assert.Equal(t, testRecord1.NullInt64, fetched[0].NullInt64) 23 | assert.Equal(t, testRecord1.ArrayInt64, fetched[0].ArrayInt64) 24 | 25 | assert.Equal(t, testRecord2.String, fetched[1].String) 26 | assert.Equal(t, testRecord2.NullString, fetched[1].NullString) 27 | assert.Equal(t, testRecord2.NullInt64, fetched[1].NullInt64) 28 | assert.Equal(t, testRecord2.ArrayInt64, fetched[1].ArrayInt64) 29 | 30 | assert.Nil(t, cleanUpReadTest(ctx)) 31 | } 32 | 33 | func TestFindOne(t *testing.T) { 34 | ctx := context.Background() 35 | assert.Nil(t, prepareReadTest(ctx)) 36 | 37 | var fetched1 TestOrderChanged 38 | err := testRepository.Reader(ctx, dataClient.Single()).FindOne(spanner.Key{testRecord1.String, testRecord1.Int64}, &fetched1) 39 | assert.Nil(t, err) 40 | assert.Equal(t, testRecord1.String, fetched1.String) 41 | assert.Equal(t, 
testRecord1.NullString, fetched1.NullString) 42 | assert.Equal(t, testRecord1.NullInt64, fetched1.NullInt64) 43 | assert.Equal(t, testRecord1.ArrayInt64, fetched1.ArrayInt64) 44 | 45 | var fetched2 Test 46 | err = testRepository.Reader(ctx, dataClient.Single()).FindOne(spanner.Key{testRecord2.String, testRecord2.Int64}, &fetched2) 47 | assert.Nil(t, err) 48 | assert.Equal(t, testRecord2.String, fetched2.String) 49 | assert.Equal(t, testRecord2.NullString, fetched2.NullString) 50 | assert.Equal(t, testRecord2.NullInt64, fetched2.NullInt64) 51 | assert.Equal(t, testRecord2.ArrayInt64, fetched2.ArrayInt64) 52 | 53 | assert.Nil(t, cleanUpReadTest(ctx)) 54 | } 55 | 56 | func TestGetColumnAll(t *testing.T) { 57 | ctx := context.Background() 58 | assert.Nil(t, prepareReadTest(ctx)) 59 | 60 | var nullStrings []spanner.NullString 61 | keys := spanner.KeySetFromKeys(spanner.Key{testRecord1.String, testRecord1.Int64}, spanner.Key{testRecord2.String, testRecord2.Int64}) 62 | err := testRepository.Reader(ctx, dataClient.Single()).GetColumnAll(keys, "NullString", &nullStrings) 63 | assert.Nil(t, err) 64 | assert.Len(t, nullStrings, 2) 65 | assert.True(t, nullStrings[0].Valid) 66 | assert.True(t, nullStrings[1].Valid) 67 | assert.Equal(t, testRecord1.NullString.StringVal, nullStrings[0].StringVal) 68 | assert.Equal(t, testRecord2.NullString.StringVal, nullStrings[1].StringVal) 69 | 70 | assert.Nil(t, cleanUpReadTest(ctx)) 71 | } 72 | 73 | func TestGetColumn(t *testing.T) { 74 | ctx := context.Background() 75 | assert.Nil(t, prepareReadTest(ctx)) 76 | 77 | var nullString1 spanner.NullString 78 | err := testRepository.Reader(ctx, dataClient.Single()).GetColumn(spanner.Key{testRecord1.String, testRecord1.Int64}, "NullString", &nullString1) 79 | assert.Nil(t, err) 80 | assert.True(t, nullString1.Valid) 81 | assert.Equal(t, testRecord1.NullString.StringVal, nullString1.StringVal) 82 | 83 | var nullString2 spanner.NullString 84 | err = testRepository.Reader(ctx, 
dataClient.Single()).GetColumn(spanner.Key{testRecord2.String, testRecord2.Int64}, "NullString", &nullString2) 85 | assert.Nil(t, err) 86 | assert.True(t, nullString2.Valid) 87 | assert.Equal(t, testRecord2.NullString.StringVal, nullString2.StringVal) 88 | 89 | assert.Nil(t, cleanUpReadTest(ctx)) 90 | } 91 | 92 | func prepareReadTest(ctx context.Context) error { 93 | ls := []*Test{testRecord1, testRecord2} 94 | _, err := testRepository.ApplyInsertOrUpdate(ctx, dataClient, &ls) 95 | return err 96 | } 97 | 98 | func cleanUpReadTest(ctx context.Context) error { 99 | ls := []*Test{testRecord1, testRecord2} 100 | _, err := testRepository.ApplyDelete(ctx, dataClient, &ls) 101 | return err 102 | } 103 | -------------------------------------------------------------------------------- /mutation_update.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | "time" 7 | 8 | "cloud.google.com/go/spanner" 9 | "github.com/pkg/errors" 10 | ) 11 | 12 | // Update build and execute update operation using mutation API. 13 | // You can pass either a struct or a slice of structs. 14 | // If you pass a slice of structs, this method will call multiple mutations for each struct. 15 | // This method requires spanner.ReadWriteTransaction, and will call spanner.ReadWriteTransaction.BufferWrite to save the mutation to transaction. 16 | // If you want to update only the specified columns, use UpdateColumns instead. 17 | func (m *Mutation) Update(tx *spanner.ReadWriteTransaction, target any) error { 18 | isStruct, err := validateStructOrStructSliceType(target) 19 | if err != nil { 20 | return err 21 | } 22 | if isStruct { 23 | return errors.WithStack(tx.BufferWrite(m.buildUpdate([]any{target}))) 24 | } 25 | return errors.WithStack(tx.BufferWrite(m.buildUpdate(toStructSlice(target)))) 26 | } 27 | 28 | // ApplyUpdate is basically same as Update, but it doesn't require transaction. 
29 | // This method directly calls mutation API without transaction by calling spanner.Client.Apply method. 30 | // If you want to update only the specified columns, use ApplyUpdateColumns instead. 31 | func (m *Mutation) ApplyUpdate(ctx context.Context, client *spanner.Client, target any) (time.Time, error) { 32 | isStruct, err := validateStructOrStructSliceType(target) 33 | if err != nil { 34 | return time.Time{}, err 35 | } 36 | if isStruct { 37 | t, err := client.Apply(ctx, m.buildUpdate([]any{target})) 38 | return t, errors.WithStack(err) 39 | } 40 | t, err := client.Apply(ctx, m.buildUpdate(toStructSlice(target))) 41 | return t, errors.WithStack(err) 42 | } 43 | 44 | // UpdateColumns build and execute update operation for specified columns using mutation API. 45 | // You can pass either a struct or a slice of structs to target. 46 | // If you pass a slice of structs, this method will build a mutation for each struct. 47 | // This method requires spanner.ReadWriteTransaction, and will call spanner.ReadWriteTransaction.BufferWrite to save the mutation to transaction. 48 | func (m *Mutation) UpdateColumns(tx *spanner.ReadWriteTransaction, columns []string, target any) error { 49 | isStruct, err := validateStructOrStructSliceType(target) 50 | if err != nil { 51 | return err 52 | } 53 | if isStruct { 54 | return errors.WithStack(tx.BufferWrite(m.buildUpdateWithColumns([]any{target}, columns))) 55 | } 56 | return errors.WithStack(tx.BufferWrite(m.buildUpdateWithColumns(toStructSlice(target), columns))) 57 | } 58 | 59 | // ApplyUpdateColumns is basically same as UpdateColumns, but it doesn't require transaction. 60 | // This method directly calls mutation API without transaction by calling spanner.Client.Apply method. 
61 | func (m *Mutation) ApplyUpdateColumns(ctx context.Context, client *spanner.Client, columns []string, target any) (time.Time, error) { 62 | isStruct, err := validateStructOrStructSliceType(target) 63 | if err != nil { 64 | return time.Time{}, err 65 | } 66 | if isStruct { 67 | t, err := client.Apply(ctx, m.buildUpdateWithColumns([]any{target}, columns)) 68 | return t, errors.WithStack(err) 69 | } 70 | t, err := client.Apply(ctx, m.buildUpdateWithColumns(toStructSlice(target), columns)) 71 | return t, errors.WithStack(err) 72 | } 73 | 74 | func (m *Mutation) buildUpdate(targets []any) []*spanner.Mutation { 75 | var ms []*spanner.Mutation 76 | for _, target := range targets { 77 | var columns []string 78 | var values []any 79 | for _, field := range toFields(target) { 80 | columns = append(columns, field.name) 81 | values = append(values, field.value) 82 | } 83 | m.logf("Update %s, columns=%+v, values=%+v", m.table, columns, values) 84 | ms = append(ms, spanner.Update(m.table, columns, values)) 85 | } 86 | return ms 87 | } 88 | 89 | func (m *Mutation) buildUpdateWithColumns(targets []any, columns []string) []*spanner.Mutation { 90 | var ms []*spanner.Mutation 91 | for _, target := range targets { 92 | fieldNameToField := map[string]field{} 93 | for _, f := range toFields(target) { 94 | fieldNameToField[strings.ToLower(f.name)] = f 95 | } 96 | var values []any 97 | for _, c := range columns { 98 | values = append(values, fieldNameToField[strings.ToLower(c)].value) 99 | } 100 | m.logf("Update %s, columns=%+v, values=%+v", m.table, columns, values) 101 | ms = append(ms, spanner.Update(m.table, columns, values)) 102 | } 103 | return ms 104 | } 105 | -------------------------------------------------------------------------------- /reader_query.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "reflect" 5 | 6 | "cloud.google.com/go/spanner" 7 | "github.com/pkg/errors" 8 | 
"google.golang.org/api/iterator" 9 | ) 10 | 11 | /* 12 | QueryOne fetches a record by calling specified query, and map the record into the passed pointer of struct. 13 | 14 | Errors: 15 | 16 | If no records are found, this method will return ErrNotFound. 17 | If multiple records are found, this method will return ErrMoreThanOneRecordFound. 18 | 19 | If you don't need to fetch all columns but only needs one column, use QueryValue instead. 20 | If you don't need to fetch all columns but only needs some columns, please make a temporal struct to map the columns. 21 | */ 22 | func (r *Reader) QueryOne(sql string, params map[string]any, target any) error { 23 | if err := validateStructType(target); err != nil { 24 | return err 25 | } 26 | r.logf(readLogTemplate, "sql:"+sql, params) 27 | 28 | iter := r.tx.Query(r.ctx, spanner.Statement{SQL: sql, Params: params}) 29 | defer iter.Stop() 30 | 31 | row, err := iter.Next() 32 | if errors.Is(err, iterator.Done) { 33 | return ErrNotFound 34 | } 35 | if err != nil { 36 | return errors.WithStack(err) 37 | } 38 | 39 | err = row.ToStruct(target) 40 | if err != nil { 41 | return errors.WithStack(err) 42 | } 43 | 44 | _, err = iter.Next() 45 | if errors.Is(err, iterator.Done) { 46 | return nil 47 | } else { 48 | return ErrMoreThanOneRecordFound 49 | } 50 | } 51 | 52 | // Query fetches records by calling specified query, and map the records into the passed pointer of a slice of struct. 
53 | func (r *Reader) Query(sql string, params map[string]any, target any) error { 54 | if err := validateStructSliceType(target); err != nil { 55 | return err 56 | } 57 | r.logf(readLogTemplate, "sql:"+sql, params) 58 | slice := reflect.ValueOf(target).Elem() 59 | innerType := slice.Type().Elem() 60 | 61 | iter := r.tx.Query(r.ctx, spanner.Statement{SQL: sql, Params: params}) 62 | defer iter.Stop() 63 | 64 | for { 65 | row, err := iter.Next() 66 | if errors.Is(err, iterator.Done) { 67 | break 68 | } 69 | if err != nil { 70 | return errors.WithStack(err) 71 | } 72 | e := reflect.New(innerType).Elem() 73 | if err := row.ToStruct(e.Addr().Interface()); err != nil { 74 | return errors.WithStack(err) 75 | } 76 | slice.Set(reflect.Append(slice, e)) 77 | } 78 | return nil 79 | } 80 | 81 | /* 82 | QueryValue fetches one value by calling specified query, and map the value into the passed pointer of value. 83 | 84 | Errors: 85 | 86 | If no records are found, this method will return ErrNotFound. 87 | If multiple records are found, this method will return ErrMoreThanOneRecordFound. 
88 | 89 | Example: 90 | var cnt int64 91 | QueryValue("select count(*) as cnt from Singers", nil, &cnt) 92 | */ 93 | func (r *Reader) QueryValue(sql string, params map[string]any, target any) error { 94 | r.logf(readLogTemplate, "sql:"+sql, params) 95 | iter := r.tx.Query(r.ctx, spanner.Statement{SQL: sql, Params: params}) 96 | defer iter.Stop() 97 | 98 | row, err := iter.Next() 99 | if err != nil { 100 | if errors.Is(err, iterator.Done) { 101 | return ErrNotFound 102 | } 103 | return errors.WithStack(err) 104 | } 105 | if row == nil { 106 | return ErrNotFound 107 | } 108 | 109 | err = row.Columns(target) 110 | if err != nil { 111 | return errors.WithStack(err) 112 | } 113 | 114 | _, err = iter.Next() 115 | if errors.Is(err, iterator.Done) { 116 | return nil 117 | } else { 118 | return ErrMoreThanOneRecordFound 119 | } 120 | } 121 | 122 | /* 123 | QueryValues fetches each value of multiple records by calling specified query, and map the values into the passed pointer of a slice of struct. 
124 | 125 | Example: 126 | var names []string 127 | QueryValue("select Name from Singers", nil, &names) 128 | */ 129 | func (r *Reader) QueryValues(sql string, params map[string]any, target any) error { 130 | if err := validateSliceType(target); err != nil { 131 | return err 132 | } 133 | r.logf(readLogTemplate, "sql:"+sql, params) 134 | slice := reflect.ValueOf(target).Elem() 135 | innerType := slice.Type().Elem() 136 | 137 | iter := r.tx.Query(r.ctx, spanner.Statement{SQL: sql, Params: params}) 138 | defer iter.Stop() 139 | 140 | for { 141 | row, err := iter.Next() 142 | if errors.Is(err, iterator.Done) { 143 | break 144 | } 145 | if err != nil { 146 | return errors.WithStack(err) 147 | } 148 | e := reflect.New(innerType).Elem() 149 | if err := row.Columns(e.Addr().Interface()); err != nil { 150 | return errors.WithStack(err) 151 | } 152 | slice.Set(reflect.Append(slice, e)) 153 | } 154 | return nil 155 | } 156 | -------------------------------------------------------------------------------- /dml_update_test.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "cloud.google.com/go/spanner" 5 | "context" 6 | "github.com/stretchr/testify/assert" 7 | "testing" 8 | ) 9 | 10 | func TestDML_buildUpdateStmt(t *testing.T) { 11 | stmt := testDMLRepository.buildUpdateStmt(testRecord1, nil) 12 | assert.Equal(t, "UPDATE `Test` SET `Bytes`=@Bytes, `Float64`=@Float64, `Numeric`=@Numeric, `Bool`=@Bool, `Date`=@Date, `Timestamp`=@Timestamp, `NullString`=@NullString, `NullInt64`=@NullInt64, `NullFloat64`=@NullFloat64, `NullNumeric`=@NullNumeric, `NullBool`=@NullBool, `NullDate`=@NullDate, `NullTimestamp`=@NullTimestamp, `ArrayString`=@ArrayString, `ArrayBytes`=@ArrayBytes, `ArrayInt64`=@ArrayInt64, `ArrayFloat64`=@ArrayFloat64, `ArrayNumeric`=@ArrayNumeric, `ArrayBool`=@ArrayBool, `ArrayDate`=@ArrayDate, `ArrayTimestamp`=@ArrayTimestamp WHERE `String`=@w_String AND `Int64`=@w_Int64", stmt.SQL) 13 | 
assert.Equal(t, testRecord1.String, stmt.Params["w_String"].(string)) 14 | assert.Equal(t, testRecord1.NullString.StringVal, (stmt.Params["NullString"].(spanner.NullString)).StringVal) 15 | assert.Equal(t, testRecord1.NullInt64.Int64, (stmt.Params["NullInt64"].(spanner.NullInt64)).Int64) 16 | 17 | stmt = testDMLRepository.buildUpdateStmt(testRecord1, []string{"NullString", "NullInt64"}) 18 | assert.Equal(t, "UPDATE `Test` SET `NullString`=@NullString, `NullInt64`=@NullInt64 WHERE `String`=@w_String AND `Int64`=@w_Int64", stmt.SQL) 19 | assert.Equal(t, testRecord1.String, stmt.Params["w_String"].(string)) 20 | assert.Equal(t, testRecord1.NullString.StringVal, (stmt.Params["NullString"].(spanner.NullString)).StringVal) 21 | assert.Equal(t, testRecord1.NullInt64.Int64, (stmt.Params["NullInt64"].(spanner.NullInt64)).Int64) 22 | } 23 | 24 | func TestDML_buildUpdateStmtWithSlice(t *testing.T) { 25 | _, err := dataClient.ReadWriteTransaction(context.Background(), func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { 26 | _, err := testDMLRepository.Insert(ctx, tx, &([]*Test{testRecord3, testRecord4})) 27 | assert.Nil(t, err) 28 | 29 | testRecord5 := *testRecord3 30 | testRecord6 := *testRecord4 31 | testRecord5.Bytes = testRecord6.Bytes 32 | testRecord6.Bytes = testRecord5.Bytes 33 | 34 | _, err = testDMLRepository.Update(ctx, tx, &([]Test{testRecord5, testRecord6})) 35 | assert.Nil(t, err) 36 | 37 | var fetched Test 38 | err = testRepository.Reader(ctx, tx).FindOne(spanner.Key{testRecord5.String, testRecord5.Int64}, &fetched) 39 | assert.Nil(t, err) 40 | assert.Equal(t, testRecord5.String, fetched.String) 41 | assert.Equal(t, testRecord5.Int64, fetched.Int64) 42 | assert.Equal(t, testRecord6.Bytes, fetched.Bytes) 43 | assert.Equal(t, testRecord5.Float64, fetched.Float64) 44 | 45 | err = testRepository.Reader(ctx, tx).FindOne(spanner.Key{testRecord6.String, testRecord6.Int64}, &fetched) 46 | assert.Nil(t, err) 47 | assert.Equal(t, testRecord6.String, 
fetched.String) 48 | assert.Equal(t, testRecord6.Int64, fetched.Int64) 49 | assert.Equal(t, testRecord5.Bytes, fetched.Bytes) 50 | assert.Equal(t, testRecord6.Float64, fetched.Float64) 51 | 52 | _, err = testDMLRepository.Delete(ctx, tx, &([]*Test{testRecord3, testRecord4})) 53 | assert.Nil(t, err) 54 | 55 | return nil 56 | }) 57 | assert.Nil(t, err) 58 | } 59 | 60 | func TestDML_buildUpdateStmtWithSlicePointer(t *testing.T) { 61 | _, err := dataClient.ReadWriteTransaction(context.Background(), func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { 62 | _, err := testDMLRepository.Insert(ctx, tx, &([]*Test{testRecord3, testRecord4})) 63 | assert.Nil(t, err) 64 | 65 | testRecord5 := *testRecord3 66 | testRecord6 := *testRecord4 67 | testRecord5.Bytes = testRecord6.Bytes 68 | testRecord6.Bytes = testRecord5.Bytes 69 | 70 | _, err = testDMLRepository.Update(ctx, tx, &([]*Test{&testRecord5, &testRecord6})) 71 | assert.Nil(t, err) 72 | 73 | var fetched Test 74 | err = testRepository.Reader(ctx, tx).FindOne(spanner.Key{testRecord5.String, testRecord5.Int64}, &fetched) 75 | assert.Nil(t, err) 76 | assert.Equal(t, testRecord5.String, fetched.String) 77 | assert.Equal(t, testRecord5.Int64, fetched.Int64) 78 | assert.Equal(t, testRecord6.Bytes, fetched.Bytes) 79 | assert.Equal(t, testRecord5.Float64, fetched.Float64) 80 | 81 | err = testRepository.Reader(ctx, tx).FindOne(spanner.Key{testRecord6.String, testRecord6.Int64}, &fetched) 82 | assert.Nil(t, err) 83 | assert.Equal(t, testRecord6.String, fetched.String) 84 | assert.Equal(t, testRecord6.Int64, fetched.Int64) 85 | assert.Equal(t, testRecord5.Bytes, fetched.Bytes) 86 | assert.Equal(t, testRecord6.Float64, fetched.Float64) 87 | 88 | _, err = testDMLRepository.Delete(ctx, tx, &([]*Test{testRecord3, testRecord4})) 89 | assert.Nil(t, err) 90 | 91 | return nil 92 | }) 93 | assert.Nil(t, err) 94 | } 95 | -------------------------------------------------------------------------------- /handlers/build/generate.go: 
-------------------------------------------------------------------------------- 1 | package build 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "github.com/iancoleman/strcase" 7 | "go/format" 8 | "strings" 9 | "text/template" 10 | ) 11 | 12 | const ( 13 | tmplImportSpanner = `"cloud.google.com/go/spanner"` 14 | tmplImportCivil = `"cloud.google.com/go/civil"` 15 | tmplImportBig = `"math/big"` 16 | tmplImportTime = `"time"` 17 | tmpl = `package {{ .PackageName }} 18 | {{ .Import }} 19 | type {{ .StructName }} struct { 20 | {{ range $field := .Fields }} {{ $field }} 21 | {{ end }}} 22 | ` 23 | ) 24 | 25 | type tmplValues struct { 26 | PackageName string 27 | Import string 28 | StructName string 29 | Fields []string 30 | } 31 | 32 | func generate(pkgName string, tableNameColumns map[string][]column) (map[string][]byte, error) { 33 | res := map[string][]byte{} 34 | for tableName, columns := range tableNameColumns { 35 | b, err := buildCode(buildTmplValues(pkgName, tableName, columns)) 36 | if err != nil { 37 | return nil, err 38 | } 39 | res[tableName] = b 40 | } 41 | return res, nil 42 | } 43 | 44 | func buildTmplValues(pkgName, tableName string, columns []column) tmplValues { 45 | var fields []string 46 | var containsNullable, containsDate, containsBig, containsTimestamp bool 47 | for _, c := range columns { 48 | if c.nullable { 49 | if c.tp == tpString || c.tp == tpInt64 || c.tp == tpFloat64 || c.tp == tpNumeric || c.tp == tpBool || c.tp == tpDate || c.tp == tpTimestamp { 50 | containsNullable = true 51 | } 52 | } 53 | if c.tp == tpDate || c.tp == tpArrayDate { 54 | containsDate = true 55 | } else if c.tp == tpNumeric || c.tp == tpArrayNumeric { 56 | containsBig = true 57 | } else if c.tp == tpTimestamp || c.tp == tpArrayTimestamp { 58 | containsTimestamp = true 59 | } 60 | fields = append(fields, fmt.Sprintf("%s %s `%s`", strcase.ToCamel(c.name), buildType(c), buildFieldName(c)+buildPk(c))) 61 | } 62 | return tmplValues{ 63 | PackageName: pkgName, 64 | StructName: 
strcase.ToCamel(tableName), 65 | Import: buildImport(containsNullable, containsDate, containsBig, containsTimestamp), 66 | Fields: fields, 67 | } 68 | } 69 | 70 | func buildType(c column) string { 71 | switch c.tp { 72 | case tpString: 73 | if c.nullable { 74 | return "spanner.NullString" 75 | } 76 | return "string" 77 | case tpBytes: 78 | return "[]byte" 79 | case tpInt64: 80 | if c.nullable { 81 | return "spanner.NullInt64" 82 | } 83 | return "int64" 84 | case tpFloat64: 85 | if c.nullable { 86 | return "spanner.NullFloat64" 87 | } 88 | return "float64" 89 | case tpNumeric: 90 | if c.nullable { 91 | return "spanner.NullNumeric" 92 | } 93 | return "big.Rat" 94 | case tpBool: 95 | if c.nullable { 96 | return "spanner.NullBool" 97 | } 98 | return "bool" 99 | case tpDate: 100 | if c.nullable { 101 | return "spanner.NullDate" 102 | } 103 | return "civil.Date" 104 | case tpTimestamp: 105 | if c.nullable { 106 | return "spanner.NullTime" 107 | } 108 | return "time.Time" 109 | case rpArrayString: 110 | return "[]string" 111 | case tpArrayBytes: 112 | return "[][]byte" 113 | case tpArrayInt64: 114 | return "[]int64" 115 | case tpArrayFloat64: 116 | return "[]float64" 117 | case tpArrayNumeric: 118 | return "[]big.Rat" 119 | case tpArrayBool: 120 | return "[]bool" 121 | case tpArrayDate: 122 | return "[]civil.Date" 123 | case tpArrayTimestamp: 124 | return "[]time.Time" 125 | } 126 | return "undefinedType" 127 | } 128 | 129 | func buildFieldName(c column) string { 130 | return fmt.Sprintf(`spanner:"%s"`, c.name) 131 | } 132 | 133 | func buildPk(c column) string { 134 | if !c.isPk { 135 | return "" 136 | } 137 | return fmt.Sprintf(` pk:"%d"`, c.pkOrder) 138 | } 139 | 140 | func buildImport(containsNullable, containsDate, containsBig, containsTimestamp bool) string { 141 | var imports []string 142 | if containsNullable { 143 | imports = append(imports, tmplImportSpanner) 144 | } 145 | if containsDate { 146 | imports = append(imports, tmplImportCivil) 147 | } 148 | if 
containsBig { 149 | imports = append(imports, tmplImportBig) 150 | } 151 | if containsTimestamp { 152 | imports = append(imports, tmplImportTime) 153 | } 154 | 155 | if len(imports) == 0 { 156 | return "" 157 | } 158 | if len(imports) == 1 { 159 | return "import " + imports[0] 160 | } 161 | return fmt.Sprintf(`import (%s)`, strings.Join(imports, "\n")) 162 | } 163 | 164 | func buildCode(v tmplValues) ([]byte, error) { 165 | var buf bytes.Buffer 166 | tmpl := template.Must(template.New("").Parse(tmpl)) 167 | if err := tmpl.Execute(&buf, v); err != nil { 168 | return nil, err 169 | } 170 | return format.Source(buf.Bytes()) 171 | } 172 | -------------------------------------------------------------------------------- /mutation_upsert.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | "time" 7 | 8 | "cloud.google.com/go/spanner" 9 | "github.com/pkg/errors" 10 | ) 11 | 12 | // InsertOrUpdate build and execute insert_or_update operation using mutation API. 13 | // You can pass either a struct or a slice of structs. 14 | // If you pass a slice of structs, this method will call multiple mutations for each struct. 15 | // This method requires spanner.ReadWriteTransaction, and will call spanner.ReadWriteTransaction.BufferWrite to save the mutation to transaction. 16 | // If you want to insert or update only the specified columns, use InsertOrUpdateColumns instead. 17 | func (m *Mutation) InsertOrUpdate(tx *spanner.ReadWriteTransaction, target any) error { 18 | isStruct, err := validateStructOrStructSliceType(target) 19 | if err != nil { 20 | return err 21 | } 22 | if isStruct { 23 | return errors.WithStack(tx.BufferWrite(m.buildInsertOrUpdate([]any{target}))) 24 | } 25 | return errors.WithStack(tx.BufferWrite(m.buildInsertOrUpdate(toStructSlice(target)))) 26 | } 27 | 28 | // ApplyInsertOrUpdate is basically same as InsertOrUpdate, but it doesn't require transaction. 
29 | // This method directly calls mutation API without transaction by calling spanner.Client.Apply method. 30 | // If you want to insert or update only the specified columns, use ApplyInsertOrUpdateColumns instead. 31 | func (m *Mutation) ApplyInsertOrUpdate(ctx context.Context, client *spanner.Client, target any) (time.Time, error) { 32 | isStruct, err := validateStructOrStructSliceType(target) 33 | if err != nil { 34 | return time.Time{}, err 35 | } 36 | if isStruct { 37 | t, err := client.Apply(ctx, m.buildInsertOrUpdate([]any{target})) 38 | return t, errors.WithStack(err) 39 | } 40 | t, err := client.Apply(ctx, m.buildInsertOrUpdate(toStructSlice(target))) 41 | return t, errors.WithStack(err) 42 | } 43 | 44 | // InsertOrUpdateColumns build and execute insert_or_update operation for specified columns using mutation API. 45 | // You can pass either a struct or a slice of structs to target. 46 | // If you pass a slice of structs, this method will build a mutation for each struct. 47 | // This method requires spanner.ReadWriteTransaction, and will call spanner.ReadWriteTransaction.BufferWrite to save the mutation to transaction. 48 | func (m *Mutation) InsertOrUpdateColumns(tx *spanner.ReadWriteTransaction, columns []string, target any) error { 49 | isStruct, err := validateStructOrStructSliceType(target) 50 | if err != nil { 51 | return err 52 | } 53 | if isStruct { 54 | return errors.WithStack(tx.BufferWrite(m.buildInsertOrUpdateWithColumns(columns, []any{target}))) 55 | } 56 | return errors.WithStack(tx.BufferWrite(m.buildInsertOrUpdateWithColumns(columns, toStructSlice(target)))) 57 | } 58 | 59 | // ApplyInsertOrUpdateColumns is basically same as InsertOrUpdateColumns, but it doesn't require transaction. 60 | // This method directly calls mutation API without transaction by calling spanner.Client.Apply method. 
61 | func (m *Mutation) ApplyInsertOrUpdateColumns(ctx context.Context, client *spanner.Client, columns []string, target any) (time.Time, error) { 62 | isStruct, err := validateStructOrStructSliceType(target) 63 | if err != nil { 64 | return time.Time{}, err 65 | } 66 | if isStruct { 67 | t, err := client.Apply(ctx, m.buildInsertOrUpdateWithColumns(columns, []any{target})) 68 | return t, errors.WithStack(err) 69 | } 70 | t, err := client.Apply(ctx, m.buildInsertOrUpdateWithColumns(columns, toStructSlice(target))) 71 | return t, errors.WithStack(err) 72 | } 73 | 74 | func (m *Mutation) buildInsertOrUpdate(targets []any) []*spanner.Mutation { 75 | var ms []*spanner.Mutation 76 | for _, target := range targets { 77 | var columns []string 78 | var values []any 79 | for _, field := range toFields(target) { 80 | columns = append(columns, field.name) 81 | values = append(values, field.value) 82 | } 83 | m.logf("InsertOrUpdate into %s, columns=%+v, values=%+v", m.table, columns, values) 84 | ms = append(ms, spanner.InsertOrUpdate(m.table, columns, values)) 85 | } 86 | return ms 87 | } 88 | 89 | func (m *Mutation) buildInsertOrUpdateWithColumns(columns []string, targets []any) []*spanner.Mutation { 90 | var ms []*spanner.Mutation 91 | for _, target := range targets { 92 | fieldNameField := map[string]field{} 93 | for _, f := range toFields(target) { 94 | fieldNameField[strings.ToLower(f.name)] = f 95 | } 96 | var values []any 97 | for _, c := range columns { 98 | values = append(values, fieldNameField[c]) 99 | } 100 | m.logf("Update %s, columns=%+v, values=%+v", m.table, columns, values) 101 | ms = append(ms, spanner.InsertOrUpdate(m.table, columns, values)) 102 | } 103 | return ms 104 | 105 | } 106 | -------------------------------------------------------------------------------- /handlers/build/fetch.go: -------------------------------------------------------------------------------- 1 | package build 2 | 3 | import ( 4 | "cloud.google.com/go/spanner" 5 | "context" 6 | 
"fmt" 7 | "github.com/kanjih/go-spnr" 8 | "strings" 9 | ) 10 | 11 | type spannerType int 12 | 13 | const ( 14 | tpUndefined spannerType = iota 15 | tpString 16 | tpBytes 17 | tpInt64 18 | tpFloat64 19 | tpNumeric 20 | tpBool 21 | tpDate 22 | tpTimestamp 23 | rpArrayString 24 | tpArrayBytes 25 | tpArrayInt64 26 | tpArrayFloat64 27 | tpArrayNumeric 28 | tpArrayBool 29 | tpArrayDate 30 | tpArrayTimestamp 31 | ) 32 | 33 | type columnRecord struct { 34 | TableName string `spanner:"TABLE_NAME"` 35 | ColumnsName string `spanner:"COLUMN_NAME"` 36 | Nullable string `spanner:"IS_NULLABLE"` 37 | Type string `spanner:"SPANNER_TYPE"` 38 | } 39 | 40 | type indexColumnRecord struct { 41 | TableName string `spanner:"TABLE_NAME"` 42 | ColumnsName string `spanner:"COLUMN_NAME"` 43 | Order int64 `spanner:"ORDINAL_POSITION"` 44 | } 45 | 46 | type column struct { 47 | name string 48 | tp spannerType 49 | nullable bool 50 | isPk bool 51 | pkOrder int 52 | } 53 | 54 | func fetchColumns(ctx context.Context, projectId, instanceName, dbName string) (map[string][]column, error) { 55 | client, err := spanner.NewClient(ctx, fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectId, instanceName, dbName)) 56 | if err != nil { 57 | return nil, err 58 | } 59 | columns, err := fetchColumnRecords(ctx, client) 60 | if err != nil { 61 | return nil, err 62 | } 63 | primaryKeys, err := fetchPrimaryKeys(ctx, client) 64 | if err != nil { 65 | return nil, err 66 | } 67 | return buildColumns(columns, primaryKeys), nil 68 | } 69 | 70 | func fetchColumnRecords(ctx context.Context, client *spanner.Client) (map[string][]columnRecord, error) { 71 | q := "select TABLE_NAME, COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE from information_schema.columns where TABLE_SCHEMA = '' order by ORDINAL_POSITION" 72 | var columns []columnRecord 73 | if err := spnr.New("").Reader(ctx, client.Single()).Query(q, nil, &columns); err != nil { 74 | return nil, err 75 | } 76 | res := map[string][]columnRecord{} 77 | for _, c := range 
columns { 78 | cols, exists := res[c.TableName] 79 | if !exists { 80 | res[c.TableName] = []columnRecord{c} 81 | } else { 82 | res[c.TableName] = append(cols, c) 83 | } 84 | } 85 | return res, nil 86 | } 87 | 88 | func fetchPrimaryKeys(ctx context.Context, client *spanner.Client) (map[string]map[string]int64, error) { 89 | q := "select TABLE_NAME, COLUMN_NAME, ORDINAL_POSITION from information_schema.INDEX_COLUMNS where TABLE_SCHEMA = '' and INDEX_NAME = 'PRIMARY_KEY'" 90 | var columns []indexColumnRecord 91 | if err := spnr.New("").Reader(ctx, client.Single()).Query(q, nil, &columns); err != nil { 92 | return nil, err 93 | } 94 | res := map[string]map[string]int64{} 95 | for _, c := range columns { 96 | m, exists := res[c.TableName] 97 | if !exists { 98 | m = map[string]int64{} 99 | } 100 | m[c.ColumnsName] = c.Order 101 | res[c.TableName] = m 102 | } 103 | return res, nil 104 | } 105 | 106 | func buildColumns(columnRecords map[string][]columnRecord, pkLists map[string]map[string]int64) map[string][]column { 107 | res := map[string][]column{} 108 | for tableName, columnRecords := range columnRecords { 109 | pks := pkLists[tableName] 110 | var columns []column 111 | for _, r := range columnRecords { 112 | pkOrder, isPk := pks[r.ColumnsName] 113 | columns = append(columns, column{ 114 | name: r.ColumnsName, 115 | tp: parseType(r.Type), 116 | nullable: r.Nullable == "YES", 117 | isPk: isPk, 118 | pkOrder: int(pkOrder), 119 | }) 120 | } 121 | res[tableName] = columns 122 | } 123 | return res 124 | } 125 | 126 | func parseType(tp string) spannerType { 127 | switch tp { 128 | case "INT64": 129 | return tpInt64 130 | case "FLOAT64": 131 | return tpFloat64 132 | case "NUMERIC": 133 | return tpNumeric 134 | case "BOOL": 135 | return tpBool 136 | case "DATE": 137 | return tpDate 138 | case "TIMESTAMP": 139 | return tpTimestamp 140 | case "ARRAY": 141 | return tpArrayInt64 142 | case "ARRAY": 143 | return tpArrayFloat64 144 | case "ARRAY": 145 | return tpArrayNumeric 146 | 
case "ARRAY": 147 | return tpArrayBool 148 | case "ARRAY": 149 | return tpArrayDate 150 | case "ARRAY": 151 | return tpArrayTimestamp 152 | } 153 | if strings.HasPrefix(tp, "STRING") { 154 | return tpString 155 | } 156 | if strings.HasPrefix(tp, "BYTES") { 157 | return tpBytes 158 | } 159 | if strings.HasPrefix(tp, "ARRAY INSERT INTO `Singers` (`SingerId`, `Name`) VALUES (@SingerId, @Name) 128 | singerStore.Insert(ctx, tx, &singers) 129 | // -> INSERT INTO `Singers` (`SingerId`, `Name`) VALUES (@SingerId_0, @Name_0), (@SingerId_1, @Name_1) 130 | 131 | singerStore.Update(ctx, tx, singer) 132 | // -> UPDATE `Singers` SET `Name`=@Name WHERE `SingerId`=@w_SingerId 133 | singerStore.Update(ctx, tx, &singers) 134 | // -> UPDATE `Singers` SET `Name`=@Name WHERE `SingerId`=@w_SingerId 135 | // -> UPDATE `Singers` SET `Name`=@Name WHERE `SingerId`=@w_SingerId 136 | 137 | singerStore.Delete(ctx, tx, singer) 138 | // -> DELETE FROM `Singers` WHERE `SingerId`=@w_SingerId 139 | singerStore.Delete(ctx, tx, &singers) 140 | // -> DELETE FROM `Singers` WHERE (`SingerId`=@w_SingerId_0) OR (`SingerId`=@w_SingerId_1) 141 | } 142 | 143 | // Embedding examples 144 | type SingerStore struct { 145 | spnr.DML // use spnr.Mutation for mutation API 146 | } 147 | 148 | func NewSingerStore() *SingerStore { 149 | return &SingerStore{DML: *spnr.NewDML("Singers")} 150 | } 151 | 152 | // Any methods you want to add 153 | func (s *SingerStore) GetCount(ctx context.Context, tx spnr.Transaction, cnt any) error { 154 | query := "select count(*) as cnt from Singers" 155 | return s.Reader(ctx, tx).QueryValue(query, nil, cnt) 156 | } 157 | 158 | func useSingerStore(ctx context.Context, client *spanner.Client) { 159 | singerStore := NewSingerStore() 160 | 161 | client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { 162 | // You can use all operations that spnr.DML has 163 | singerStore.Insert(ctx, tx, &Singer{SingerID: "a", Name: "Alice"}) 164 | var singer 
Singer 165 | singerStore.Reader(ctx, tx).FindOne(spanner.Key{"a"}, &singer) 166 | 167 | // And you can use the methods you added !! 168 | var cnt int 169 | singerStore.GetCount(ctx, tx, &cnt) 170 | 171 | return nil 172 | }) 173 | } 174 | -------------------------------------------------------------------------------- /init_test.go: -------------------------------------------------------------------------------- 1 | package spnr 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "math/big" 7 | "os" 8 | "testing" 9 | "time" 10 | 11 | "cloud.google.com/go/civil" 12 | "cloud.google.com/go/spanner" 13 | database "cloud.google.com/go/spanner/admin/database/apiv1" 14 | instance "cloud.google.com/go/spanner/admin/instance/apiv1" 15 | "github.com/testcontainers/testcontainers-go" 16 | "github.com/testcontainers/testcontainers-go/wait" 17 | databasepb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" 18 | instancepb "google.golang.org/genproto/googleapis/spanner/admin/instance/v1" 19 | ) 20 | 21 | const ( 22 | instanceName = "test" 23 | databaseName = "test" 24 | projectID = "projects/test-project" 25 | instanceID = projectID + "/instances/" + instanceName 26 | databaseID = instanceID + "/databases/" + databaseName 27 | ) 28 | 29 | var ( 30 | insAdminClient *instance.InstanceAdminClient 31 | adminClient *database.DatabaseAdminClient 32 | dataClient *spanner.Client 33 | testRecord1 = &Test{ 34 | String: "testId1", 35 | Bytes: []byte{1}, 36 | Int64: 10, 37 | Float64: 84.217403, 38 | Numeric: *big.NewRat(17893, 8473), 39 | Date: civil.DateOf(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)), 40 | Timestamp: time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC), 41 | NullString: NewNullString("a"), 42 | NullInt64: NewNullInt64(100), 43 | NullNumeric: NewNullNumeric(53, 10384), 44 | ArrayInt64: []int64{1, 2, 3}, 45 | ArrayBytes: [][]byte{{80}, {90}}, 46 | } 47 | testRecord2 = &Test{ 48 | String: "testId2", 49 | Bytes: []byte{2}, 50 | Int64: 20, 51 | Date: civil.DateOf(time.Date(1999, 
1, 1, 0, 0, 0, 0, time.UTC)), 52 | Timestamp: time.Date(2999, 1, 1, 0, 0, 0, 0, time.UTC), 53 | NullString: NewNullString("b"), 54 | NullInt64: NewNullInt64(200), 55 | ArrayInt64: []int64{4}, 56 | } 57 | testRepository = NewMutation("Test") 58 | ) 59 | 60 | type Test struct { 61 | String string `spanner:"String" pk:"1"` 62 | Bytes []byte `spanner:"Bytes"` 63 | Int64 int64 `spanner:"Int64" pk:"2"` 64 | Float64 float64 `spanner:"Float64"` 65 | Numeric big.Rat `spanner:"Numeric"` 66 | Bool bool `spanner:"Bool"` 67 | Date civil.Date `spanner:"Date"` 68 | Timestamp time.Time `spanner:"Timestamp"` 69 | NullString spanner.NullString `spanner:"NullString"` 70 | NullInt64 spanner.NullInt64 `spanner:"NullInt64"` 71 | NullFloat64 spanner.NullFloat64 `spanner:"NullFloat64"` 72 | NullNumeric spanner.NullNumeric `spanner:"NullNumeric"` 73 | NullBool spanner.NullBool `spanner:"NullBool"` 74 | NullDate spanner.NullDate `spanner:"NullDate"` 75 | NullTimestamp spanner.NullTime `spanner:"NullTimestamp"` 76 | ArrayString []string `spanner:"ArrayString"` 77 | ArrayBytes [][]byte `spanner:"ArrayBytes"` 78 | ArrayInt64 []int64 `spanner:"ArrayInt64"` 79 | ArrayFloat64 []float64 `spanner:"ArrayFloat64"` 80 | ArrayNumeric []big.Rat `spanner:"ArrayNumeric"` 81 | ArrayBool []bool `spanner:"ArrayBool"` 82 | ArrayDate []civil.Date `spanner:"ArrayDate"` 83 | ArrayTimestamp []time.Time `spanner:"ArrayTimestamp"` 84 | } 85 | 86 | type TestOrderChanged struct { 87 | ArrayString []string `spanner:"ArrayString"` 88 | ArrayBytes [][]byte `spanner:"ArrayBytes"` 89 | ArrayInt64 []int64 `spanner:"ArrayInt64"` 90 | ArrayFloat64 []float64 `spanner:"ArrayFloat64"` 91 | ArrayNumeric []big.Rat `spanner:"ArrayNumeric"` 92 | ArrayBool []bool `spanner:"ArrayBool"` 93 | ArrayDate []civil.Date `spanner:"ArrayDate"` 94 | ArrayTimestamp []time.Time `spanner:"ArrayTimestamp"` 95 | String string `spanner:"String" pk:"1"` 96 | Bytes []byte `spanner:"Bytes"` 97 | Int64 int64 `spanner:"Int64"` 98 | Float64 float64 
`spanner:"Float64"` 99 | Numeric big.Rat `spanner:"Numeric"` 100 | Bool bool `spanner:"Bool"` 101 | Date civil.Date `spanner:"Date"` 102 | Timestamp time.Time `spanner:"Timestamp"` 103 | NullString spanner.NullString `spanner:"NullString"` 104 | NullInt64 spanner.NullInt64 `spanner:"NullInt64"` 105 | NullFloat64 spanner.NullFloat64 `spanner:"NullFloat64"` 106 | NullNumeric spanner.NullNumeric `spanner:"NullNumeric"` 107 | NullBool spanner.NullBool `spanner:"NullBool"` 108 | NullDate spanner.NullDate `spanner:"NullDate"` 109 | NullTimestamp spanner.NullTime `spanner:"NullTimestamp"` 110 | } 111 | 112 | func TestMain(m *testing.M) { 113 | ctx := context.Background() 114 | c, err := initSpannerContainer(ctx) 115 | if c != nil { 116 | defer c.Terminate(ctx) //nolint:errcheck 117 | } 118 | if err != nil { 119 | panic(err) 120 | } 121 | if err = initClients(ctx, databaseID); err != nil { 122 | panic(err) 123 | } 124 | if err = initDatabase(ctx); err != nil { 125 | panic(err) 126 | } 127 | os.Exit(m.Run()) 128 | } 129 | 130 | func initSpannerContainer(ctx context.Context) (testcontainers.Container, error) { 131 | req := testcontainers.ContainerRequest{ 132 | Image: "gcr.io/cloud-spanner-emulator/emulator:1.3.0", 133 | ExposedPorts: []string{"9010/tcp"}, 134 | WaitingFor: wait.ForLog("gateway.go:142: gRPC server listening at 0.0.0.0:9010"), 135 | } 136 | spannerC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ 137 | ContainerRequest: req, 138 | Started: true, 139 | }) 140 | if err != nil { 141 | return nil, err 142 | } 143 | h, err := spannerC.Host(ctx) 144 | if err != nil { 145 | return nil, err 146 | } 147 | p, err := spannerC.MappedPort(ctx, "9010") 148 | if err != nil { 149 | return nil, err 150 | } 151 | return spannerC, os.Setenv("SPANNER_EMULATOR_HOST", fmt.Sprintf("%s:%s", h, p.Port())) 152 | } 153 | 154 | func initClients(ctx context.Context, databaseId string) (err error) { 155 | insAdminClient, err = 
instance.NewInstanceAdminClient(ctx) 156 | if err != nil { 157 | return err 158 | } 159 | adminClient, err = database.NewDatabaseAdminClient(ctx) 160 | if err != nil { 161 | return err 162 | } 163 | dataClient, err = spanner.NewClient(ctx, databaseId) 164 | return err 165 | } 166 | 167 | func initDatabase(ctx context.Context) (err error) { 168 | createInstanceReq := &instancepb.CreateInstanceRequest{ 169 | Parent: projectID, 170 | Instance: &instancepb.Instance{ 171 | Name: instanceID, 172 | Config: projectID + "/instanceConfigs/test", 173 | DisplayName: instanceName, 174 | NodeCount: 1, 175 | }, 176 | InstanceId: instanceName, 177 | } 178 | ciOp, err := insAdminClient.CreateInstance(ctx, createInstanceReq) 179 | if err != nil { 180 | return err 181 | } 182 | if _, err = ciOp.Wait(ctx); err != nil { 183 | return err 184 | } 185 | 186 | b, err := os.ReadFile("testdata/test.sql") 187 | if err != nil { 188 | return err 189 | } 190 | createDatabaseReq := &databasepb.CreateDatabaseRequest{ 191 | Parent: instanceID, 192 | CreateStatement: "CREATE DATABASE " + databaseName, 193 | ExtraStatements: []string{string(b)}, 194 | } 195 | cdOp, err := adminClient.CreateDatabase(ctx, createDatabaseReq) 196 | if err != nil { 197 | return err 198 | } 199 | _, err = cdOp.Wait(ctx) 200 | if err != nil { 201 | return err 202 | } 203 | 204 | return err 205 | } 206 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ORM for Cloud Spanner to boost your productivity 🚀 4 | 5 | [![Go Reference](https://pkg.go.dev/badge/github.com/kanjih/go-spnr/v2.svg)](https://pkg.go.dev/github.com/kanjih/go-spnr/v2) 6 | [![Actions Status](https://github.com/kanjih/go-spnr/workflows/test/badge.svg?branch=main)](https://github.com/kanjih/go-spnr/actions) 7 | 8 | 9 | ## Example 🔧 10 | ```go 11 | package main 12 | 13 | import ( 14 | "cloud.google.com/go/spanner" 15 | "context" 
16 | "github.com/kanjih/go-spnr" 17 | ) 18 | 19 | type Singer struct { 20 | // spnr supports 2 types of tags. 21 | // - spanner: spanner column name 22 | // - pk: primary key order 23 | SingerID string `spanner:"SingerId" pk:"1"` 24 | Name string `spanner:"Name"` 25 | } 26 | 27 | func main() { 28 | ctx := context.Background() 29 | client, _ := spanner.NewClient(ctx, "projects/{project_id}/instances/{instance_id}/databases/{database_id}") 30 | 31 | // initialize 32 | singerStore := spnr.New("Singers") // specify table name 33 | 34 | // save record (spnr supports both Mutation API & DML!) 35 | singerStore.ApplyInsertOrUpdate(ctx, client, &Singer{SingerID: "a", Name: "Alice"}) 36 | 37 | // fetch record 38 | var singer Singer 39 | singerStore.Reader(ctx, client.Single()).FindOne(spanner.Key{"a"}, &singer) 40 | 41 | // fetch record using raw query 42 | var singers []Singer 43 | query := "select * from Singers where SingerId=@singerId" 44 | params := map[string]any{"singerId": "a"} 45 | singerStore.Reader(ctx, client.Single()).Query(query, params, &singers) 46 | } 47 | ``` 48 | 49 | ## Features 50 | - Supports both **Mutation API** & **DML** 51 | - Supports code generation to map records 52 | - Supports raw SQL for complex cases 53 | 54 | spnr is designed ... 55 | - 🙆‍♂️ for reducing boilerplate code (e.g. mapping selected records to structs or writing simple insert/update/delete operations) 56 | - 🙅‍♀️ not for hiding the queries executed in the background (spnr doesn't provide abstractions for complex operations) 57 | 58 | ## Table of contents 59 | - [Installation](#installation) 60 | - spnr APIs 61 | - [Read operations](#read-operations) 62 | - [Mutation API](#mutation-api) 63 | - [DML](#dml) 64 | - [Embedding](#embedding) 65 | - [Code generation](#code-generation) 66 | - [Helper functions](#helper-functions) 67 | 68 | ## Installation 69 | ``` 70 | go get github.com/kanjih/go-spnr/v2 71 | ``` 72 | \* v2 requires Go 1.18 or later.
If you use an earlier Go version, please use v1. 73 | 74 | ## Read operations 75 | spnr provides the following types of read operations 💪 76 | 1. Select records using primary keys 77 | 2. Select one column using primary keys 78 | 3. Select records using query 79 | 4. Select one value using query 80 | 81 | ### 1. Select records using primary keys 82 | ```go 83 | var singer Singer 84 | singerStore.Reader(ctx, tx).FindOne(spanner.Key{"a"}, &singer) 85 | 86 | var singers []Singer 87 | keys := spanner.KeySetFromKeys(spanner.Key{"a"}, spanner.Key{"b"}) 88 | singerStore.Reader(ctx, tx).FindAll(keys, &singers) 89 | ``` 90 | 91 | #### 📝 Note 92 | `tx` is the transaction object. You can get it by calling `spanner.Client.ReadOnly(ReadWrite)Transaction` or the `spanner.Client.Single` method. 93 | 94 | ### 2. Select one column using primary keys 95 | ```go 96 | var name string 97 | singerStore.Reader(ctx, tx).GetColumn(spanner.Key{"a"}, "Name", &name) 98 | 99 | var names []string 100 | keys := spanner.KeySetFromKeys(spanner.Key{"a"}, spanner.Key{"b"}) 101 | singerStore.Reader(ctx, tx).GetColumnAll(keys, "Name", &names) 102 | ``` 103 | 104 | #### If you want to fetch multiple columns 105 | Define a temporary struct that maps just the columns you need. 106 | ```go 107 | type cols struct { 108 | Name string `spanner:"Name"` 109 | Score spanner.NullInt64 `spanner:"Score"` 110 | } 111 | var res cols 112 | singerStore.Reader(ctx, tx).FindOne(spanner.Key{"1"}, &res) 113 | ``` 114 | 115 | ### 3. Select records using query 116 | ```go 117 | var singer Singer 118 | query := "select * from `Singers` where SingerId=@singerId" 119 | params := map[string]any{"singerId": "a"} 120 | singerStore.Reader(ctx, tx).QueryOne(query, params, &singer) 121 | 122 | var singers []Singer 123 | query = "select * from Singers" 124 | singerStore.Reader(ctx, tx).Query(query, nil, &singers) 125 | ``` 126 | 127 | ### 4.
Select one value using query 128 | ```go 129 | var cnt int64 130 | query := "select count(*) as cnt from Singers" 131 | singerStore.Reader(ctx, tx).QueryValue(query, nil, &cnt) 132 | ``` 133 | 134 | ### * Notes 135 | - `FindOne` and `GetColumn` use the `ReadRow` method of `spanner.ReadWrite(ReadOnly)Transaction`. 136 | - `FindAll` and `GetColumnAll` use the `Read` method. 137 | - `QueryOne` and `Query` use the `Query` method. 138 | 139 | ## Mutation API 140 | Using the Mutation API with spnr is really simple! Here's an example 👇 141 | ```go 142 | singer := &Singer{SingerID: "a", Name: "Alice"} 143 | singers := []Singer{{SingerID: "b", Name: "Bob"}, {SingerID: "c", Name: "Carol"}} 144 | 145 | singerStore := spnr.New("Singers") // specify table name 146 | 147 | singerStore.InsertOrUpdate(tx, singer) // Insert or update 148 | singerStore.InsertOrUpdate(tx, &singers) // Insert or update multiple records 149 | 150 | singerStore.Update(tx, singer) // Update 151 | singerStore.Update(tx, &singers) // Update multiple records 152 | 153 | singerStore.Delete(tx, singer) // Delete 154 | singerStore.Delete(tx, &singers) // Delete multiple records 155 | ``` 156 | 157 | Don't want to use a transaction? You can use `ApplyXXX`.
158 | ```go 159 | singerStore.ApplyInsertOrUpdate(ctx, client, singer) // client is a *spanner.Client 160 | singerStore.ApplyDelete(ctx, client, &singers) 161 | ``` 162 | 163 | ## DML 164 | spnr parses your struct and builds the DML for you 💪 165 | ```go 166 | singer := &Singer{SingerID: "a", Name: "Alice"} 167 | singers := []Singer{{SingerID: "b", Name: "Bob"}, {SingerID: "c", Name: "Carol"}} 168 | 169 | singerStore := spnr.NewDML("Singers") // specify table name 170 | 171 | singerStore.Insert(ctx, tx, singer) 172 | // -> INSERT INTO `Singers` (`SingerId`, `Name`) VALUES (@SingerId, @Name) 173 | 174 | singerStore.Insert(ctx, tx, &singers) 175 | // -> INSERT INTO `Singers` (`SingerId`, `Name`) VALUES (@SingerId_0, @Name_0), (@SingerId_1, @Name_1) 176 | 177 | singerStore.Update(ctx, tx, singer) 178 | // -> UPDATE `Singers` SET `Name`=@Name WHERE `SingerId`=@w_SingerId 179 | singerStore.Update(ctx, tx, &singers) 180 | // -> UPDATE `Singers` SET `Name`=@Name WHERE `SingerId`=@w_SingerId 181 | // -> UPDATE `Singers` SET `Name`=@Name WHERE `SingerId`=@w_SingerId 182 | 183 | singerStore.Delete(ctx, tx, singer) 184 | // -> DELETE FROM `Singers` WHERE `SingerId`=@w_SingerId 185 | ``` 186 | 187 | ### Want to use raw SQL? 188 | You don't need spnr in this case! The plain Spanner client library is enough. 189 | ```go 190 | sql := "UPDATE `Singers` SET `Name` = @Name WHERE `SingerId` = @SingerId" 191 | params := map[string]any{"Name": "Alice", "SingerId": "a"} 192 | tx.Update(ctx, spanner.Statement{SQL: sql, Params: params}) 193 | ``` 194 | 195 | ## Embedding 196 | spnr is also designed to be used with struct embedding.
197 | You can define a struct for each table to manipulate its records and add any methods you want. 198 | 199 | ```go 200 | type SingerStore struct { 201 | spnr.DML // use spnr.Mutation for mutation API 202 | } 203 | 204 | func NewSingerStore() *SingerStore { 205 | return &SingerStore{DML: *spnr.NewDML("Singers")} 206 | } 207 | 208 | // Any methods you want to add 209 | func (s *SingerStore) GetCount(ctx context.Context, tx spnr.Transaction, cnt any) error { 210 | query := "select count(*) as cnt from Singers" 211 | return s.Reader(ctx, tx).QueryValue(query, nil, cnt) 212 | } 213 | 214 | func useSingerStore(ctx context.Context, client *spanner.Client) { 215 | singerStore := NewSingerStore() 216 | 217 | client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { 218 | // You can use all operations that spnr.DML has 219 | singerStore.Insert(ctx, tx, &Singer{SingerID: "a", Name: "Alice"}) 220 | var singer Singer 221 | singerStore.Reader(ctx, tx).FindOne(spanner.Key{"a"}, &singer) 222 | 223 | // And you can use the methods you added !! 224 | var cnt int64 225 | singerStore.GetCount(ctx, tx, &cnt) 226 | 227 | return nil 228 | }) 229 | } 230 | ``` 231 | 232 | ## Code generation 233 | Tired of writing structs to map records for every table?
234 | Don't worry! spnr provides code generation 🚀 235 | ```sh 236 | go install github.com/kanjih/go-spnr/cmd/spnr@latest 237 | spnr build -p {PROJECT_ID} -i {INSTANCE_ID} -d {DATABASE_ID} -n {PACKAGE_NAME} -o {OUTPUT_DIR} 238 | ``` 239 | 240 | ## Helper functions 241 | spnr provides some helper functions to reduce boilerplate. 242 | - **`NewNullXXX`** 243 | - `spanner.NullString{StringVal: "a", Valid: true}` can be written as `spnr.NewNullString("a")` 244 | - **`ToKeySets`** 245 | - You can convert a slice into key sets using `spnr.ToKeySets([]string{"a", "b"})` 246 | 247 | Issue reports are always welcome! 248 | 249 | [godev-image]: https://pkg.go.dev/badge/github.com/kanjih/go-spnr 250 | [godev-url]: https://pkg.go.dev/github.com/kanjih/go-spnr 251 | -------------------------------------------------------------------------------- /internal/examples/examples_test.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "testing" 8 | 9 | "cloud.google.com/go/spanner" 10 | database "cloud.google.com/go/spanner/admin/database/apiv1" 11 | instance "cloud.google.com/go/spanner/admin/instance/apiv1" 12 | "github.com/kanjih/go-spnr" 13 | "github.com/testcontainers/testcontainers-go" 14 | "github.com/testcontainers/testcontainers-go/wait" 15 | databasepb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" 16 | instancepb "google.golang.org/genproto/googleapis/spanner/admin/instance/v1" 17 | "gotest.tools/assert" 18 | ) 19 | 20 | const ( 21 | instanceName = "test" 22 | databaseName = "test" 23 | projectID = "projects/test-project" 24 | instanceID = projectID + "/instances/" + instanceName 25 | databaseID = instanceID + "/databases/" + databaseName 26 | ) 27 | 28 | var ( 29 | insAdminClient *instance.InstanceAdminClient 30 | adminClient *database.DatabaseAdminClient 31 | client *spanner.Client 32 | // Shared fixtures. Tests copy them (e.g. updatedSinger := *singer) before 33 | // changing a field instead of mutating them in place. 34 | singer = &Singer{SingerID: "a", Name: "Alice"} 35 | singers = []Singer{{SingerID: "b", Name: "Bob"}, {SingerID: "c", Name: "Carol"}} 36 | ) 37 |
38 | // TestExample runs the snippet from the README's Example section. 39 | func TestExample(t *testing.T) { 40 | singer := &Singer{SingerID: "a", Name: "Alice"} 41 | ctx := context.Background() 42 | singerStore := spnr.New("Singers") 43 | _, err := singerStore.ApplyInsertOrUpdate(ctx, client, singer) 44 | assert.NilError(t, err) 45 | 46 | var fetched Singer 47 | err = singerStore.Reader(ctx, client.Single()).FindOne(spanner.Key{"a"}, &fetched) 48 | assert.NilError(t, err) 49 | assert.Equal(t, *singer, fetched) 50 | 51 | var singers []Singer 52 | query := "select * from Singers where SingerId=@singerId" 53 | params := map[string]any{"singerId": "a"} 54 | err = singerStore.Reader(ctx, client.Single()).Query(query, params, &singers) 55 | assert.NilError(t, err) 56 | assert.Equal(t, 1, len(singers)) 57 | assert.Equal(t, *singer, singers[0]) 58 | assert.NilError(t, deleteAllSingers()) 59 | } 60 |
61 | func TestSelectRecordsUsingPrimaryKeys(t *testing.T) { 62 | ctx := context.Background() 63 | singerStore := spnr.New("Singers") 64 | _, err := singerStore.ApplyInsertOrUpdate(ctx, client, singer) 65 | assert.NilError(t, err) 66 | _, err = singerStore.ApplyInsertOrUpdate(ctx, client, &singers) 67 | assert.NilError(t, err) 68 | 69 | var fetched Singer 70 | err = singerStore.Reader(ctx, client.Single()).FindOne(spanner.Key{"a"}, &fetched) 71 | assert.NilError(t, err) 72 | assert.Equal(t, *singer, fetched) 73 | 74 | var fetchedSingers []Singer 75 | keys := spanner.KeySetFromKeys(spanner.Key{"a"}, spanner.Key{"b"}) 76 | err = singerStore.Reader(ctx, client.Single()).FindAll(keys, &fetchedSingers) 77 | assert.NilError(t, err) 78 | assert.Equal(t, *singer, fetchedSingers[0]) 79 | assert.Equal(t, singers[0], fetchedSingers[1]) 80 | 81 | var name string 82 | err = singerStore.Reader(ctx, client.Single()).GetColumn(spanner.Key{"a"}, "Name", &name) 83 | assert.NilError(t, err) 84 | assert.Equal(t, singer.Name, name) 85 | 86 | var names []string 87 | err = singerStore.Reader(ctx, client.Single()).GetColumnAll(keys, "Name", &names) 88 | assert.NilError(t, err) 89 | assert.Equal(t, singer.Name, names[0]) 90 | assert.Equal(t, singers[0].Name, names[1]) 91 | 92 | assert.NilError(t, deleteAllSingers()) 93 | } 94 |
95 | func TestSelectMultipleColumnsUsingPrimaryKeys(t *testing.T) { 96 | ctx := context.Background() 97 | album := &Album{ 98 | SingerID: "a", 99 | AlbumID: 1, 100 | Title: spnr.NewNullString("test"), 101 | } 102 | albumStore := spnr.NewMutationWithOptions("Albums", &spnr.Options{LogEnabled: true}) 103 | _, err := albumStore.ApplyInsertOrUpdate(ctx, client, album) 104 | assert.NilError(t, err) 105 | 106 | type cols struct { 107 | AlbumID int64 `spanner:"AlbumId"` 108 | Title spanner.NullString `spanner:"Title"` 109 | } 110 | var res cols 111 | err = albumStore.Reader(ctx, client.Single()).FindOne(spanner.Key{"a", 1}, &res) 112 | assert.NilError(t, err) 113 | assert.Equal(t, album.AlbumID, res.AlbumID) 114 | assert.Equal(t, album.Title, res.Title) 115 | } 116 |
117 | func TestSelectRecordsUsingQuery(t *testing.T) { 118 | ctx := context.Background() 119 | singerStore := spnr.New("Singers") 120 | _, err := singerStore.ApplyInsertOrUpdate(ctx, client, singer) 121 | assert.NilError(t, err) 122 | _, err = singerStore.ApplyInsertOrUpdate(ctx, client, &singers) 123 | assert.NilError(t, err) 124 | 125 | var fetched Singer 126 | query := "select * from `Singers` where SingerId=@singerId" 127 | params := map[string]any{"singerId": "a"} 128 | err = singerStore.Reader(ctx, client.Single()).QueryOne(query, params, &fetched) 129 | assert.NilError(t, err) 130 | assert.Equal(t, *singer, fetched) 131 | 132 | var fetchedSingers []Singer 133 | query = "select * from Singers order by SingerId" 134 | err = singerStore.Reader(ctx, client.Single()).Query(query, nil, &fetchedSingers) 135 | assert.NilError(t, err) 136 | assert.Equal(t, *singer, fetchedSingers[0]) 137 | assert.Equal(t, singers[0], fetchedSingers[1]) 138 | assert.Equal(t, singers[1], fetchedSingers[2]) 139 | 140 | assert.NilError(t, deleteAllSingers()) 141 | } 142 |
143 | func TestSelectOneValueUsingQuery(t *testing.T) { 144 | ctx := context.Background() 145 | singerStore := spnr.New("Singers") 146 | _, err := singerStore.ApplyInsertOrUpdate(ctx, client, &singers) 147 | assert.NilError(t, err) 148 | 149 | var cnt int64 150 | query := "select count(*) as cnt from Singers" 151 | err = singerStore.Reader(ctx, client.Single()).QueryValue(query, nil, &cnt) 152 | assert.NilError(t, err) 153 | assert.Equal(t, int64(2), cnt) 154 | 155 | assert.NilError(t, deleteAllSingers()) 156 | } 157 |
158 | func TestMutationAPI(t *testing.T) { 159 | ctx := context.Background() 160 | singerStore := spnr.New("Singers") 161 | 162 | // Mutations are buffered until the transaction commits, so reads inside 163 | // the transaction still see the table as it was before the writes. 164 | _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { 165 | err := singerStore.InsertOrUpdate(tx, singer) 166 | assert.NilError(t, err) 167 | err = singerStore.InsertOrUpdate(tx, &singers) 168 | assert.NilError(t, err) 169 | var cnt int64 170 | query := "select count(*) as cnt from Singers" 171 | err = singerStore.Reader(ctx, tx).QueryValue(query, nil, &cnt) 172 | assert.NilError(t, err) 173 | assert.Equal(t, int64(0), cnt) 174 | return nil 175 | }) 176 | assert.NilError(t, err) 177 | var cnt int64 178 | query := "select count(*) as cnt from Singers" 179 | err = singerStore.Reader(ctx, client.Single()).QueryValue(query, nil, &cnt) 180 | assert.NilError(t, err) 181 | assert.Equal(t, int64(3), cnt) 182 | 183 | var fetched Singer 184 | updatedSinger := *singer 185 | updatedSinger.Name = "Mallory" 186 | 187 | var fetchedSingers []Singer 188 | updatedSinger1 := singers[0] 189 | updatedSinger2 := singers[1] 190 | updatedSinger1.Name = "Marvin" 191 | updatedSinger2.Name = "Mallet" 192 | keySet := spanner.KeySetFromKeys(spanner.Key{singers[0].SingerID}, spanner.Key{singers[1].SingerID}) 193 | 194 | _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { 195 | err := singerStore.Update(tx, &updatedSinger) 196 | assert.NilError(t, err) 197 | 198 | err = singerStore.Reader(ctx, tx).FindOne(spanner.Key{singer.SingerID}, &fetched) 199 | assert.NilError(t, err) 200 | assert.Equal(t, *singer, fetched) 201 | 202 | err = singerStore.Update(tx, &([]Singer{updatedSinger1, updatedSinger2})) 203 | assert.NilError(t, err) 204 | 205 | fetchedSingers = nil // the callback may be re-run if the transaction aborts 206 | err = singerStore.Reader(ctx, tx).FindAll(keySet, &fetchedSingers) 207 | assert.NilError(t, err) 208 | assert.Equal(t, singers[0], fetchedSingers[0]) 209 | assert.Equal(t, singers[1], fetchedSingers[1]) 210 | return nil 211 | }) 212 | assert.NilError(t, err) 213 | 214 | err = singerStore.Reader(ctx, client.Single()).FindOne(spanner.Key{singer.SingerID}, &fetched) 215 | assert.NilError(t, err) 216 | assert.Equal(t, updatedSinger, fetched) 217 | 218 | fetchedSingers = nil 219 | err = singerStore.Reader(ctx, client.Single()).FindAll(keySet, &fetchedSingers) 220 | assert.NilError(t, err) 221 | assert.Equal(t, updatedSinger1, fetchedSingers[0]) 222 | assert.Equal(t, updatedSinger2, fetchedSingers[1]) 223 | 224 | _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { 225 | err := singerStore.Delete(tx, singer) 226 | assert.NilError(t, err) 227 | err = singerStore.Reader(ctx, tx).FindOne(spanner.Key{singer.SingerID}, &fetched) 228 | assert.NilError(t, err) 229 | return nil 230 | }) 231 | assert.NilError(t, err) 232 | err = singerStore.Reader(ctx, client.Single()).FindOne(spanner.Key{singer.SingerID}, &fetched) 233 | assert.Equal(t, spnr.ErrNotFound, err) 234 | 235 | _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { 236 | err := singerStore.Delete(tx, &singers) 237 | assert.NilError(t, err) 238 | fetchedSingers = nil 239 | err = singerStore.Reader(ctx, tx).FindAll(keySet, &fetchedSingers) 240 | assert.NilError(t, err) 241 | assert.Equal(t, 2, len(fetchedSingers)) 242 | return nil 243 | }) 244 | assert.NilError(t, err) 245 | fetchedSingers = nil 246 | err = singerStore.Reader(ctx, client.Single()).FindAll(keySet, &fetchedSingers) 247 | assert.NilError(t, err) 248 | assert.Equal(t, 0, len(fetchedSingers)) 249 | 250 | assert.NilError(t, deleteAllSingers()) 251 | } 252 |
253 | func TestDML(t *testing.T) { 254 | singerStore := spnr.NewDMLWithOptions("Singers", &spnr.Options{LogEnabled: true}) 255 | _, err := client.ReadWriteTransaction(context.Background(), func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { 256 | _, err := singerStore.Insert(ctx, tx, singer) 257 | assert.NilError(t, err) 258 | var fetched Singer 259 | err = singerStore.Reader(ctx, tx).FindOne(spanner.Key{singer.SingerID}, &fetched) 260 | assert.NilError(t, err) 261 | assert.Equal(t, singer.SingerID, fetched.SingerID) 262 | assert.Equal(t, singer.Name, fetched.Name) 263 | 264 | _, err = singerStore.Insert(ctx, tx, &singers) 265 | assert.NilError(t, err) 266 | err = singerStore.Reader(ctx, tx).FindOne(spanner.Key{singers[0].SingerID}, &fetched) 267 | assert.NilError(t, err) 268 | assert.Equal(t, singers[0].SingerID, fetched.SingerID) 269 | assert.Equal(t, singers[0].Name, fetched.Name) 270 | err = singerStore.Reader(ctx, tx).FindOne(spanner.Key{singers[1].SingerID}, &fetched) 271 | assert.NilError(t, err) 272 | assert.Equal(t, singers[1].SingerID, fetched.SingerID) 273 | assert.Equal(t, singers[1].Name, fetched.Name) 274 | 275 | updatedSinger := *singer 276 | updatedSinger.Name = "Mallory" 277 | _, err = singerStore.Update(ctx, tx, &updatedSinger) 278 | assert.NilError(t, err) 279 | err = singerStore.Reader(ctx, tx).FindOne(spanner.Key{updatedSinger.SingerID}, &fetched) 280 | assert.NilError(t, err) 281 | assert.Equal(t, updatedSinger.Name, fetched.Name) 282 | 283 | updatedSinger1 := singers[0] 284 | updatedSinger1.Name = "Marvin" 285 | updatedSinger2 := singers[1] 286 | updatedSinger2.Name = "Mallet" 287 | 288 | _, err = singerStore.Update(ctx, tx, &([]Singer{updatedSinger1, updatedSinger2})) 289 | assert.NilError(t, err) 290 | var fetchedSingers []Singer 291 | keySet := spanner.KeySetFromKeys(spanner.Key{"b"}, spanner.Key{"c"}) 292 | err = singerStore.Reader(ctx, tx).FindAll(keySet, &fetchedSingers) 293 | assert.NilError(t, err) 294 | assert.Equal(t, updatedSinger1.Name, fetchedSingers[0].Name) 295 | assert.Equal(t, updatedSinger2.Name, fetchedSingers[1].Name) 296 | 297 | _, err = singerStore.Delete(ctx, tx, singer) 298 | assert.NilError(t, err) 299 | err = singerStore.Reader(ctx, tx).FindOne(spanner.Key{updatedSinger.SingerID}, &fetched) 300 | assert.Equal(t, spnr.ErrNotFound, err) 301 | 302 | _, err = singerStore.Delete(ctx, tx, &singers) 303 | assert.NilError(t, err) 304 | fetchedSingers = nil 305 | err = singerStore.Reader(ctx, tx).FindAll(keySet, &fetchedSingers) 306 | assert.NilError(t, err) 307 | assert.Equal(t, 0, len(fetchedSingers)) 308 | 309 | return nil 310 | }) 311 | assert.NilError(t, err) 312 | 313 | assert.NilError(t, deleteAllSingers()) 314 | } 315 |
316 | func TestSingerStore(t *testing.T) { 317 | singerStore := NewSingerStore() 318 | _, err := client.ReadWriteTransaction(context.Background(), func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { 319 | _, err := tx.Update(ctx, spanner.Statement{SQL: "delete from Singers where true"}) 320 | assert.NilError(t, err) 321 | 322 | _, err = singerStore.Insert(ctx, tx, singer) 323 | assert.NilError(t, err) 324 | var cnt int64 325 | err = singerStore.GetCount(ctx, tx, &cnt) 326 | assert.NilError(t, err) 327 | assert.Equal(t, int64(1), cnt) 328 | 329 | _, err = singerStore.Delete(ctx, tx, singer) 330 | assert.NilError(t, err) 331 | 332 | return nil 333 | }) 334 | assert.NilError(t, err) 335 | } 336 |
337 | // TestMain delegates to runTests so that its deferred cleanup runs before 338 | // os.Exit, which terminates the process without running deferred calls. 339 | func TestMain(m *testing.M) { 340 | os.Exit(runTests(m)) 341 | } 342 |
343 | // runTests starts the Spanner emulator, creates the clients and the test 344 | // database, runs the tests and returns their exit code. Once the emulator 345 | // container has started it is terminated when runTests returns or panics. 346 | func runTests(m *testing.M) int { 347 | ctx := context.Background() 348 | c, err := initSpannerContainer(ctx) 349 | if c != nil { 350 | defer func() { 351 | if err := c.Terminate(ctx); err != nil { 352 | fmt.Fprintf(os.Stderr, "terminate spanner emulator: %v\n", err) 353 | } 354 | }() 355 | } 356 | if err != nil { 357 | panic(err) 358 | } 359 | if err = initClients(ctx, databaseID); err != nil { 360 | panic(err) 361 | } 362 | defer client.Close() 363 | if err = initDatabase(ctx); err != nil { 364 | panic(err) 365 | } 366 | return m.Run() 367 | } 368 |
369 | // initSpannerContainer starts the Cloud Spanner emulator in a container and 370 | // sets SPANNER_EMULATOR_HOST to its mapped port so that Spanner clients 371 | // created afterwards talk to it. Once the container has started it is 372 | // returned even if a later step fails, so the caller can terminate it. 373 | func initSpannerContainer(ctx context.Context) (testcontainers.Container, error) { 374 | req := testcontainers.ContainerRequest{ 375 | Image: "gcr.io/cloud-spanner-emulator/emulator:1.3.0", 376 | ExposedPorts: []string{"9010/tcp"}, 377 | WaitingFor: wait.ForLog("gateway.go:142: gRPC server listening at 0.0.0.0:9010"), 378 | } 379 | spannerC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ 380 | ContainerRequest: req, 381 | Started: true, 382 | }) 383 | if err != nil { 384 | return nil, err 385 | } 386 | h, err := spannerC.Host(ctx) 387 | if err != nil { 388 | return spannerC, err 389 | } 390 | p, err := spannerC.MappedPort(ctx, "9010") 391 | if err != nil { 392 | return spannerC, err 393 | } 394 | return spannerC, os.Setenv("SPANNER_EMULATOR_HOST", fmt.Sprintf("%s:%s", h, p.Port())) 395 | } 396 |
397 | // initClients creates the package-level admin clients and the data client; 398 | // dbID is the full database path the data client connects to. 399 | func initClients(ctx context.Context, dbID string) (err error) { 400 | insAdminClient, err = instance.NewInstanceAdminClient(ctx) 401 | if err != nil { 402 | return err 403 | } 404 | adminClient, err = database.NewDatabaseAdminClient(ctx) 405 | if err != nil { 406 | return err 407 | } 408 | client, err = spanner.NewClient(ctx, dbID) 409 | return err 410 | } 411 |
412 | // initDatabase creates the test instance on the emulator and then the test 413 | // database with the ddlSingers and ddlAlbums statements, waiting for both 414 | // long-running operations to finish. 415 | func initDatabase(ctx context.Context) error { 416 | createInstanceReq := &instancepb.CreateInstanceRequest{ 417 | Parent: projectID, 418 | Instance: &instancepb.Instance{ 419 | Name: instanceID, 420 | Config: projectID + "/instanceConfigs/test", 421 | DisplayName: instanceName, 422 | NodeCount: 1, 423 | }, 424 | InstanceId: instanceName, 425 | } 426 | ciOp, err := insAdminClient.CreateInstance(ctx, createInstanceReq) 427 | if err != nil { 428 | return err 429 | } 430 | if _, err = ciOp.Wait(ctx); err != nil { 431 | return err 432 | } 433 | 434 | createDatabaseReq := &databasepb.CreateDatabaseRequest{ 435 | Parent: instanceID, 436 | CreateStatement: "CREATE DATABASE " + databaseName, 437 | ExtraStatements: []string{ddlSingers, ddlAlbums}, 438 | } 439 | cdOp, err := adminClient.CreateDatabase(ctx, createDatabaseReq) 440 | if err != nil { 441 | return err 442 | } 443 | _, err = cdOp.Wait(ctx) 444 | return err 445 | } 446 |
447 | // deleteAllSingers removes every row from Singers in its own read-write 448 | // transaction so the next test starts from an empty table. 449 | func deleteAllSingers() error { 450 | _, err := client.ReadWriteTransaction(context.Background(), func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { 451 | _, err := tx.Update(ctx, spanner.Statement{SQL: "delete from Singers where true"}) 452 | return err 453 | }) 454 | return err 455 | } 456 | --------------------------------------------------------------------------------