├── .gitignore ├── contrib ├── .gitignore ├── downloadH2.sh └── runStandalone.sh ├── example ├── .gitignore ├── tx.go ├── prep-exec.go └── crud.go ├── go.mod ├── .github └── workflows │ └── testing.yml ├── util.go ├── go.sum ├── result.go ├── stmt.go ├── tx.go ├── client.go ├── README.md ├── driver.go ├── conn.go ├── session.go ├── driver_test.go └── transfer.go /.gitignore: -------------------------------------------------------------------------------- 1 | study 2 | -------------------------------------------------------------------------------- /contrib/.gitignore: -------------------------------------------------------------------------------- 1 | h2.jar 2 | -------------------------------------------------------------------------------- /example/.gitignore: -------------------------------------------------------------------------------- 1 | example 2 | -------------------------------------------------------------------------------- /contrib/downloadH2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | VER=1.4.200 3 | curl https://repo1.maven.org/maven2/com/h2database/h2/${VER}/h2-${VER}.jar -o h2.jar 4 | -------------------------------------------------------------------------------- /contrib/runStandalone.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | java -classpath h2.jar org.h2.tools.Server -tcp -tcpAllowOthers -ifNotExists -trace -baseDir tmpData 3 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/jmrobles/h2go 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/pkg/errors v0.9.1 7 | github.com/sirupsen/logrus v1.4.2 8 | golang.org/x/text v0.3.2 9 | ) 10 | -------------------------------------------------------------------------------- /.github/workflows/testing.yml: 
-------------------------------------------------------------------------------- 1 | name: Testing 2 | on: 3 | - push 4 | - pull_request 5 | jobs: 6 | testing: 7 | name: Go testing 8 | runs-on: ubuntu-18.04 9 | # container: golang:1.15.5-alpine3.12 10 | services: 11 | h2server: 12 | image: jmrobles/h2:1.4.200 13 | ports: 14 | - 9092:9092 15 | env: 16 | H2_OPTIONS: -ifNotExists 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v2 20 | - name: Setup go 21 | uses: actions/setup-go@v2 22 | with: 23 | go-version: '^1.15.4' 24 | - run: go test -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 JM Robles (@jmrobles) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package h2go 18 | 19 | import ( 20 | "crypto/sha256" 21 | "fmt" 22 | "strings" 23 | 24 | log "github.com/sirupsen/logrus" 25 | "golang.org/x/text/encoding/unicode" 26 | ) 27 | 28 | func getHashedPassword(username string, password string) ([32]byte, error) { 29 | payload := fmt.Sprintf("%s@%s", strings.ToUpper(username), password) 30 | data, err := unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM).NewEncoder().Bytes([]byte(payload)) 31 | if err != nil { 32 | return [32]byte{}, err 33 | } 34 | return sha256.Sum256(data), nil 35 | } 36 | 37 | // L Log if apply 38 | func L(level log.Level, text string, args ...interface{}) { 39 | if !doLogging { 40 | return 41 | } 42 | log.StandardLogger().Logf(level, text, args...) 43 | } 44 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= 4 | github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= 5 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 6 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 7 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 8 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 9 | github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= 10 | github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= 11 | github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 12 | github.com/stretchr/testify v1.2.2 
h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= 13 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 14 | golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc= 15 | golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 16 | golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= 17 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= 18 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 19 | -------------------------------------------------------------------------------- /example/tx.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 JM Robles (@jmrobles) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
// main walks through a minimal transaction against an in-memory H2
// database: it creates a table, inserts one row inside a transaction,
// commits, and finally reads the table back.
//
// Fatal errors terminate the process through log.Fatalf, so resources are
// released explicitly on the success path instead of with defer (deferred
// calls do not run after os.Exit).
func main() {
	log.Printf("H2GO Example")

	conn, err := sql.Open("h2", "h2://sa@localhost/test?mem=true&logging=info")
	if err != nil {
		log.Fatalf("ERROR: %s", err)
	}
	// Create table
	stmt, err := conn.Prepare("CREATE TABLE test (id int)")
	if err != nil {
		log.Fatalf("Can't preparate: %s", err)
	}
	_, err = stmt.Exec()
	if err != nil {
		log.Fatalf("Can't execute exec: %s", err)
	}
	// The prepared statement is not used again; release it.
	if err = stmt.Close(); err != nil {
		log.Printf("Can't close statement: %s", err)
	}
	// Begin TX for INSERT
	ctx := context.Background()
	tx, err := conn.BeginTx(ctx, nil)
	if err != nil {
		log.Fatalf("Can't start tx: %s", err)
	}
	_, err = tx.ExecContext(ctx, "INSERT INTO test VALUES 10")
	if err != nil {
		log.Fatalf("Can't execute insert: %s", err)
	}
	// Commit
	err = tx.Commit()
	if err != nil {
		log.Fatal(err)
	}
	// Check values
	rows, err := conn.Query("SELECT * FROM test")
	if err != nil {
		log.Fatalf("Can't select: %s", err)
	}
	for rows.Next() {
		var v int
		err := rows.Scan(&v)
		if err != nil {
			// Report why the scan failed instead of dropping the error.
			log.Printf("Can't scan row: %s", err)
			continue
		}
		log.Printf("Value: %d", v)
	}
	// Next returning false can mean "no more rows" or "iteration failed";
	// Err tells the two apart.
	if err = rows.Err(); err != nil {
		log.Printf("Can't iterate rows: %s", err)
	}
	if err = rows.Close(); err != nil {
		log.Printf("Can't close rows: %s", err)
	}
	log.Printf("End tx")

	// Close the pool so the driver can shut its connection down cleanly.
	if err = conn.Close(); err != nil {
		log.Printf("Can't close connection: %s", err)
	}
}
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package h2go 18 | 19 | import ( 20 | "database/sql/driver" 21 | "io" 22 | 23 | "github.com/pkg/errors" 24 | ) 25 | 26 | type h2Result struct { 27 | query string 28 | columns []string 29 | numRows int32 30 | curRow int32 31 | trans *transfer 32 | 33 | // Interface 34 | driver.Rows 35 | } 36 | 37 | // Rows interface 38 | 39 | func (h2r *h2Result) Close() error { 40 | return nil 41 | } 42 | 43 | func (h2r *h2Result) Columns() []string { 44 | return h2r.columns 45 | } 46 | 47 | func (h2r *h2Result) Next(dest []driver.Value) error { 48 | var err error 49 | // log.Printf("LEN: %d", len(dest)) 50 | if h2r.curRow == h2r.numRows { 51 | return io.EOF 52 | } 53 | h2r.curRow++ 54 | next, err := h2r.trans.readBool() 55 | if err != nil { 56 | return err 57 | } 58 | if !next { 59 | return io.EOF 60 | } 61 | // log.Printf(">>> DEST: %v", dest) 62 | for i := range h2r.columns { 63 | v, err := h2r.trans.readValue() 64 | if err != nil { 65 | return errors.Wrapf(err, "Can't read value") 66 | } 67 | dest[i] = driver.Value(v) 68 | } 69 | return nil 70 | } 71 | 72 | type h2ExecResult struct { 73 | nUpdated int32 74 | // Interface 75 | driver.Result 76 | } 77 | 78 | func (h2er *h2ExecResult) LastInsertId() (int64, error) { 79 | return 1, nil 80 | } 81 | 82 | func (h2er *h2ExecResult) RowsAffected() (int64, error) { 83 | return int64(h2er.nUpdated), nil 84 | } 85 | -------------------------------------------------------------------------------- /stmt.go: -------------------------------------------------------------------------------- 1 | /* 2 | 
Copyright 2020 JM Robles (@jmrobles) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package h2go 18 | 19 | import ( 20 | "context" 21 | "database/sql/driver" 22 | ) 23 | 24 | type h2stmt struct { 25 | id int32 26 | oID int32 27 | isQuery bool 28 | isRO bool 29 | numParams int32 30 | parameters []h2parameter 31 | client h2client 32 | query string 33 | // Interfaces 34 | driver.Stmt 35 | driver.StmtQueryContext 36 | driver.StmtExecContext 37 | } 38 | 39 | type h2parameter struct { 40 | kind int32 41 | precission int64 42 | scale int32 43 | nullable bool 44 | } 45 | 46 | // Interface Stmt 47 | func (h2s h2stmt) Close() error { 48 | // TODO: check for action 49 | return nil 50 | } 51 | 52 | func (h2s h2stmt) NumInput() int { 53 | return int(h2s.numParams) 54 | } 55 | 56 | // Interface StmtQueryContext 57 | func (h2s h2stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 58 | cols, nRows, err := h2s.client.sess.executeQuery(&h2s, &h2s.client.trans) 59 | if err != nil { 60 | return nil, err 61 | } 62 | return &h2Result{query: h2s.query, columns: cols, numRows: nRows, trans: &h2s.client.trans, curRow: 0}, nil 63 | } 64 | 65 | // Interface StmtExecContext 66 | func (h2s h2stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 67 | var argsValues []driver.Value 68 | for _, arg := range args { 69 | argsValues = append(argsValues, arg.Value) 70 | } 71 | nUpdated, err := 
h2s.client.sess.executeQueryUpdate(&h2s, &h2s.client.trans, argsValues) 72 | if err != nil { 73 | return nil, err 74 | } 75 | return &h2ExecResult{nUpdated: nUpdated}, nil 76 | } 77 | -------------------------------------------------------------------------------- /tx.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 JM Robles (@jmrobles) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package h2go 18 | 19 | import ( 20 | "database/sql/driver" 21 | 22 | log "github.com/sirupsen/logrus" 23 | ) 24 | 25 | type h2tx struct { 26 | conn h2Conn 27 | // Interfaces 28 | driver.Tx 29 | } 30 | 31 | // Interface Tx 32 | func (h2t h2tx) Commit() error { 33 | L(log.DebugLevel, "Commit") 34 | stmt, err := h2t.conn.client.sess.prepare2(&h2t.conn.client.trans, "COMMIT") 35 | if err != nil { 36 | return err 37 | } 38 | st, _ := stmt.(h2stmt) 39 | _, err = h2t.conn.client.sess.executeQueryUpdate(&st, &h2t.conn.client.trans, []driver.Value{}) 40 | if err != nil { 41 | return err 42 | } 43 | err = h2t.restoreAutocommit() 44 | if err != nil { 45 | return err 46 | } 47 | return nil 48 | } 49 | 50 | func (h2t h2tx) Rollback() error { 51 | L(log.DebugLevel, "Rollback") 52 | stmt, err := h2t.conn.client.sess.prepare2(&h2t.conn.client.trans, "ROLLBACK") 53 | if err != nil { 54 | return err 55 | } 56 | st, _ := stmt.(h2stmt) 57 | _, err = h2t.conn.client.sess.executeQueryUpdate(&st, 
&h2t.conn.client.trans, []driver.Value{}) 58 | if err != nil { 59 | return err 60 | } 61 | err = h2t.restoreAutocommit() 62 | if err != nil { 63 | return err 64 | } 65 | return nil 66 | } 67 | 68 | // Helpers 69 | 70 | func (h2t h2tx) restoreAutocommit() error { 71 | stmt, err := h2t.conn.client.sess.prepare2(&h2t.conn.client.trans, "SET AUTOCOMMIT TRUE") 72 | if err != nil { 73 | return err 74 | } 75 | st, _ := stmt.(h2stmt) 76 | _, err = h2t.conn.client.sess.executeQueryUpdate(&st, &h2t.conn.client.trans, []driver.Value{}) 77 | if err != nil { 78 | return err 79 | } 80 | return nil 81 | 82 | } 83 | -------------------------------------------------------------------------------- /example/prep-exec.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 JM Robles (@jmrobles) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
// main demonstrates prepared statements against an in-memory H2 database:
// a prepared query, a prepared DDL statement, a prepared INSERT with one
// parameter, and a final prepared SELECT that reads the row back.
//
// Each prepared statement is closed as soon as it is no longer needed, and
// the pool is closed at the end. Fatal errors exit through log.Fatalf.
func main() {
	log.Printf("H2GO Example")

	conn, err := sql.Open("h2", "h2://sa@localhost/test?mem=true&logging=info")
	if err != nil {
		log.Fatalf("ERROR: %s", err)
	}

	// Query through a prepared statement
	stmt, err := conn.Prepare("SELECT 1+1")
	if err != nil {
		log.Fatalf("Can't preparate: %s", err)
	}
	rows, err := stmt.Query()
	if err != nil {
		log.Fatalf("Can't query: %s", err)
	}
	printRows(rows)
	closeStmt(stmt)

	// Create table
	log.Printf("CREATE TABLE")
	stmt, err = conn.Prepare("CREATE TABLE test (id int)")
	if err != nil {
		log.Fatalf("Can't preparate: %s", err)
	}
	_, err = stmt.Exec()
	if err != nil {
		log.Fatalf("Can't execute exec: %s", err)
	}
	closeStmt(stmt)

	// Insert with a parameter
	stmt, err = conn.Prepare("INSERT INTO test VALUES (?)")
	if err != nil {
		log.Fatalf("Can't preparate: %s", err)
	}
	v := 123
	_, err = stmt.Exec(v)
	if err != nil {
		log.Fatalf("Can't execute exec: %s", err)
	}
	closeStmt(stmt)

	// Select
	stmt, err = conn.Prepare("SELECT * FROM test")
	if err != nil {
		log.Fatalf("Can't preparate: %s", err)
	}
	rows, err = stmt.Query()
	if err != nil {
		log.Fatalf("Can't query: %s", err)
	}
	printRows(rows)
	closeStmt(stmt)

	if err = conn.Close(); err != nil {
		log.Printf("Can't close connection: %s", err)
	}
}

// printRows logs every row of a single-integer-column result and then
// closes it. A scan or iteration failure is fatal.
func printRows(rows *sql.Rows) {
	for rows.Next() {
		var v int32
		if err := rows.Scan(&v); err != nil {
			log.Fatalf("Can't scan: %s", err)
		}
		log.Printf("Row: %d", v)
	}
	// Next returning false may also mean the iteration failed.
	if err := rows.Err(); err != nil {
		log.Fatalf("Can't iterate rows: %s", err)
	}
	if err := rows.Close(); err != nil {
		log.Printf("Can't close rows: %s", err)
	}
}

// closeStmt releases a prepared statement and logs a failure to do so.
func closeStmt(stmt *sql.Stmt) {
	if err := stmt.Close(); err != nil {
		log.Printf("Can't close statement: %s", err)
	}
}
// main runs a create/insert/select/update/delete round trip against an
// in-memory H2 database and logs the outcome of every step.
//
// Any failing statement is logged and ends the example. The pool is closed
// on every exit path after a successful sql.Open.
func main() {
	log.Printf("H2GO Example")

	conn, err := sql.Open("h2", "h2://sa@localhost/test?mem=true&logging=info")
	if err != nil {
		log.Fatalf("ERROR: %s", err)
	}
	// Deferred so the early returns below also release the connection.
	defer func() {
		if cerr := conn.Close(); cerr != nil {
			log.Printf("Can't close connection: %s", cerr)
		}
		log.Printf("Done")
	}()

	// Create table
	log.Printf("CREATE TABLE")
	if _, err = conn.Exec("CREATE TABLE test (id int not null, name varchar(100))"); err != nil {
		log.Printf("Can't execute sentence: %s", err)
		return
	}
	// Insert
	ret, err := conn.Exec("INSERT INTO test VALUES (?, ?)",
		1, "John")
	if err != nil {
		log.Printf("Can't execute sentence: %s", err)
		return
	}
	logResult(ret)
	// Query
	rows, err := conn.Query("SELECT * FROM test")
	if err != nil {
		log.Printf("Can't execute query: %s", err)
		// rows is nil when Query fails; iterating it would panic.
		return
	}
	for rows.Next() {
		var (
			id   int
			name string
		)
		err = rows.Scan(&id, &name)
		if err != nil {
			log.Printf("Can't scan values in row: %s", err)
			continue
		}
		log.Printf("Row: %d - %s", id, name)
	}
	// Next returning false may also mean the iteration failed.
	if err = rows.Err(); err != nil {
		log.Printf("Can't iterate rows: %s", err)
	}
	rows.Close()
	// Update
	ret, err = conn.Exec("UPDATE test SET name = 'Juan' WHERE id = 1")
	if err != nil {
		log.Printf("Can't execute sentence: %s", err)
		return
	}
	logResult(ret)
	// Delete
	ret, err = conn.Exec("DELETE FROM test WHERE id = 1")
	if err != nil {
		log.Printf("Can't execute sentence: %s", err)
		return
	}
	logResult(ret)
}

// logResult logs the last insert id and the affected-row count of ret. A
// failure to obtain either value is logged and the zero value is printed in
// its place.
func logResult(ret sql.Result) {
	lastID, err := ret.LastInsertId()
	if err != nil {
		log.Printf("Can't get last ID: %s", err)
	}
	nRows, err := ret.RowsAffected()
	if err != nil {
		log.Printf("Can't get num rows: %s", err)
	}
	log.Printf("LastID: %d - NumRowsAffected: %d", lastID, nRows)
}
15 | */ 16 | 17 | package h2go 18 | 19 | import ( 20 | "net" 21 | 22 | log "github.com/sirupsen/logrus" 23 | 24 | "github.com/pkg/errors" 25 | ) 26 | 27 | type h2client struct { 28 | conn net.Conn 29 | trans transfer 30 | sess session 31 | } 32 | 33 | func (c *h2client) doHandshake(ci h2connInfo) error { 34 | var err error 35 | // 1. send min client version 36 | err = c.trans.writeInt32(9) 37 | if err != nil { 38 | return errors.Wrapf(err, "H2 handshake: can't send min client version") 39 | } 40 | // 2. send max client version 41 | err = c.trans.writeInt32(19) 42 | if err != nil { 43 | return errors.Wrapf(err, "H2 handshake: can't send max client version") 44 | } 45 | // 3. Send db name 46 | err = c.trans.writeString(ci.database) 47 | if err != nil { 48 | return errors.Wrapf(err, "H2 handshake: can't send database name") 49 | } 50 | // 4. Send original url 51 | err = c.trans.writeString("jdbc:h2:" + ci.database) 52 | if err != nil { 53 | return errors.Wrapf(err, "H2 handshake: can't send original url") 54 | } 55 | // 5. Send username 56 | err = c.trans.writeString(ci.username) 57 | if err != nil { 58 | return errors.Wrapf(err, "H2 handshake: can't send username") 59 | } 60 | // 6. Send password 61 | hashedPassword, err := getHashedPassword(ci.username, ci.password) 62 | if err != nil { 63 | return errors.Wrapf(err, "H2 handshake: can't hash password") 64 | } 65 | err = c.trans.writeBytes(hashedPassword[:]) 66 | if err != nil { 67 | return errors.Wrapf(err, "H2 handshake: can't hashed password") 68 | } 69 | // 7. Send file password hash 70 | err = c.trans.writeBytes(nil) 71 | if err != nil { 72 | return errors.Wrapf(err, "H2 handshake: can't send hashed file password") 73 | } 74 | // 8. 
Send aditional properties 75 | // TODO: bynow, 0 properties tos send 76 | err = c.trans.writeInt32(0) 77 | if err != nil { 78 | return errors.Wrapf(err, "H2 handshake: can't send properties") 79 | } 80 | err = c.trans.flush() 81 | if err != nil { 82 | return errors.Wrapf(err, "H2 handshake: can't flush data to socket") 83 | } 84 | // 9. Wait for Status OK ack 85 | code, err := c.trans.readInt32() 86 | if err != nil { 87 | return errors.Wrapf(err, "H2 handshake: can't get H2 Server status code") 88 | } 89 | // 10. Read client version 90 | clientVer, err := c.trans.readInt32() 91 | if err != nil { 92 | return errors.Wrapf(err, "H2 handshake: can't get H2 Server client version ack") 93 | } 94 | L(log.InfoLevel, "H2 server code: %d - client ver: %d", code, clientVer) 95 | return nil 96 | } 97 | 98 | func (c *h2client) close() error { 99 | err := c.sess.close(&c.trans) 100 | if err != nil { 101 | return err 102 | } 103 | // Close client 104 | return c.conn.Close() 105 | } 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Apache H2 Database Go Driver 2 | 3 | __This driver is VERY experimental state__ 4 | 5 | __NOT use for production yet__ 6 | 7 | ## Introduction 8 | 9 | [Apache H2 Database](https://h2database.com) is a very-low footprint database with in-memory capabilities. 10 | 11 | It's written in Java and it's fully ACID compliant. 12 | 13 | You can use H2 as embedded database or via TCP/IP. 14 | 15 | It has interfaces for Postgres protocol and native TCP server. 16 | 17 | ## Motivation 18 | 19 | Until now, using H2 in your Go projects could only be done through the Postgres driver. 20 | 21 | This approach has several cons. The poor error messagens or not being able to use native data types are some of them. 22 | 23 | This pure Go driver uses the native TCP interface. 
24 | 25 | ## Prerequisites 26 | 27 | In the "contrib" folder you can find the scripts to download and launch the H2 database server. 28 | You need to have a Java Runtime installed. 29 | 30 | ```bash 31 | cd contrib 32 | ./downloadH2.sh 33 | ./runStandalone.sh 34 | ``` 35 | 36 | ## Usage 37 | 38 | First make sure the H2 server is running in TCP server mode. You can launch it using the `runStandalone.sh` script or with a command similar to the following: 39 | 40 | ```bash 41 | java -classpath h2.jar org.h2.tools.Server -tcp -tcpAllowOthers -ifNotExists 42 | ``` 43 | 44 | This starts the server at the default port (9092). 45 | 46 | The following example connects to H2 and creates an in-memory database. 47 | 48 | ```go 49 | package main 50 | 51 | import ( 52 | "database/sql" 53 | "log" 54 | _ "github.com/jmrobles/h2go" 55 | ) 56 | 57 | func main() { 58 | conn, err := sql.Open("h2", "h2://sa@localhost/testdb?mem=true") 59 | if err != nil { 60 | log.Fatalf("Can't connect to H2 Database: %s", err) 61 | } 62 | err = conn.Ping() 63 | if err != nil { 64 | log.Fatalf("Can't ping to H2 Database: %s", err) 65 | } 66 | log.Printf("H2 Database connected") 67 | conn.Close() 68 | } 69 | ``` 70 | 71 | In the `example` folder you can find more examples. 72 | ## Connection string 73 | 74 | In the connection string you must specify: 75 | 76 | - Database driver: `h2` literal 77 | - Username (optional) 78 | - Password (optional) 79 | - Host: format `<host>(:<port>)?` 80 | - Database name 81 | - Other connection options 82 | 83 | ### Options 84 | 85 | You can use the following options: 86 | 87 | - mem=(true|false): to use an in-memory or an on-disk database 88 | - logging=(none|info|debug|error|warn|panic|trace): the common logging level 89 | 90 | 91 | ## Parameters 92 | 93 | To use parameters in an SQL statement, use the `?` placeholder symbol. 
94 | 95 | For example: 96 | ```go 97 | conn.Exec("INSERT INTO employees VALUES (?,?,?)", name, age, salary) 98 | ``` 99 | 100 | ## Data types 101 | 102 | The following H2 datatypes are implemented: 103 | 104 | | H2 Data type | Go mapping | 105 | |--------------|------------| 106 | | String | string | 107 | | StringIgnoreCase | string | 108 | | StringFixed | string | 109 | | Bool | bool | 110 | | Short | int16 | 111 | | Int | int32 | 112 | | Long | int64 | 113 | | Float | float32 | 114 | | Double | float64 | 115 | | Byte | byte | 116 | | Bytes | []byte | 117 | | Time | time.Time | 118 | | Time with timezone | time.Time | 119 | | Date | time.Time 120 | | Timestamp | time.Time | 121 | | Timestamp with timezone | time.Time 122 | 123 | ## H2 Supported version 124 | 125 | This driver supports H2 database version 1.4.200 or above. 126 | 127 | ## ToDo 128 | 129 | - Rest of native data types (UUID, JSON, Decimal, ...) 130 | - `NamedValue` interface 131 | - Multiple result sets 132 | - Improve `context` usage (timeouts, ...) 133 | - Submit your issue 134 | 135 | ## Contributors 136 | 137 | [jmrobles](https://jmrobles.medium.com) 138 | 139 | Pull Requests are welcome 140 | 141 | ## License 142 | 143 | MIT License 144 | -------------------------------------------------------------------------------- /driver.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 JM Robles (@jmrobles) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package h2go 18 | 19 | import ( 20 | "context" 21 | "database/sql" 22 | "database/sql/driver" 23 | 24 | "net" 25 | "net/url" 26 | "strconv" 27 | "strings" 28 | 29 | "github.com/pkg/errors" 30 | log "github.com/sirupsen/logrus" 31 | ) 32 | 33 | var doLogging = false 34 | 35 | type h2connInfo struct { 36 | host string 37 | port int 38 | database string 39 | username string 40 | password string 41 | isMem bool 42 | logging bool 43 | 44 | dialer net.Dialer 45 | } 46 | type h2Driver struct { 47 | driver.DriverContext 48 | driver.Driver 49 | } 50 | 51 | type h2Connector struct { 52 | driver.Connector 53 | 54 | ci h2connInfo 55 | driver h2Driver 56 | } 57 | 58 | func (h2d h2Driver) Open(dsn string) (driver.Conn, error) { 59 | ci, err := parseURL(dsn) 60 | L(log.InfoLevel, "Open") 61 | L(log.DebugLevel, "Open with dsn: %s", dsn) 62 | if err != nil { 63 | return nil, err 64 | } 65 | return connect(ci) 66 | } 67 | 68 | func (h2d *h2Driver) OpenConnector(dsn string) (driver.Connector, error) { 69 | L(log.DebugLevel, "OpenConnector") 70 | ci, err := parseURL(dsn) 71 | if err != nil { 72 | return nil, err 73 | } 74 | return &h2Connector{ci: ci, driver: *h2d}, nil 75 | } 76 | 77 | func (h2c *h2Connector) Connect(ctx context.Context) (driver.Conn, error) { 78 | L(log.DebugLevel, "Connect") 79 | return connect(h2c.ci) 80 | } 81 | 82 | func (h2c *h2Connector) Driver() driver.Driver { 83 | return h2c.driver 84 | } 85 | func init() { 86 | sql.Register("h2", &h2Driver{}) 87 | } 88 | 89 | // Helpers 90 | 91 | func parseURL(dsnurl string) (h2connInfo, error) { 92 | var ci h2connInfo 93 | u, err := url.Parse(dsnurl) 94 | if err != nil { 95 | return ci, errors.Wrapf(err, "failed to parse connection url") 96 | } 97 | // Set host 98 | if ci.host = u.Hostname(); len(ci.host) == 0 { 99 | ci.host = "127.0.0.1" 100 | } 101 | // Set port 102 | ci.port, _ = 
strconv.Atoi(u.Port()) 103 | if ci.port == 0 { 104 | ci.port = defaultH2port 105 | } 106 | // Set database 107 | if ci.database = u.Path; len(ci.database) == 0 { 108 | ci.database = "~/test" 109 | } 110 | // Username & password 111 | userinfo := u.User 112 | if userinfo != nil { 113 | ci.username = userinfo.Username() 114 | if pass, ok := userinfo.Password(); ok { 115 | ci.password = pass 116 | } 117 | } 118 | for k, v := range u.Query() { 119 | var val string 120 | if len(v) > 0 { 121 | val = strings.TrimSpace(v[0]) 122 | } 123 | switch strings.ToLower(k) { 124 | case "mem": 125 | ci.isMem = val == "" || val == "1" || val == "yes" || val == "true" 126 | if ci.isMem { 127 | ci.database = strings.Replace(ci.database, "/", "", 1) 128 | ci.database = "mem:" + ci.database 129 | } 130 | case "logging": 131 | logType := strings.ToLower(v[0]) 132 | switch logType { 133 | case "none": 134 | doLogging = false 135 | case "info": 136 | doLogging = true 137 | log.SetLevel(log.InfoLevel) 138 | case "debug": 139 | doLogging = true 140 | log.SetLevel(log.DebugLevel) 141 | case "error": 142 | doLogging = true 143 | log.SetLevel(log.ErrorLevel) 144 | case "warn": 145 | case "warning": 146 | doLogging = true 147 | log.SetLevel(log.WarnLevel) 148 | case "panic": 149 | doLogging = true 150 | log.SetLevel(log.PanicLevel) 151 | case "trace": 152 | doLogging = true 153 | log.SetLevel(log.TraceLevel) 154 | } 155 | default: 156 | return ci, errors.Errorf("unknown H2 server connection parameters => \"%s\" : \"%s\"", k, val) 157 | } 158 | 159 | } 160 | return ci, nil 161 | } 162 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 JM Robles (@jmrobles) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package h2go 18 | 19 | import ( 20 | "context" 21 | "database/sql/driver" 22 | "fmt" 23 | 24 | "net" 25 | 26 | log "github.com/sirupsen/logrus" 27 | 28 | "github.com/pkg/errors" 29 | ) 30 | 31 | const defaultH2port = 9092 32 | 33 | type h2Conn struct { 34 | connInfo h2connInfo 35 | client h2client 36 | 37 | // Interfaces 38 | driver.Conn 39 | driver.Pinger 40 | driver.Validator 41 | driver.QueryerContext 42 | driver.ExecerContext 43 | driver.ConnBeginTx 44 | } 45 | 46 | // Pinger interface 47 | func (h2c h2Conn) Ping(ctx context.Context) error { 48 | L(log.DebugLevel, "Ping") 49 | var err error 50 | stmt, err := h2c.client.sess.prepare(&h2c.client.trans, "SELECT 1") 51 | if err != nil { 52 | return driver.ErrBadConn 53 | } 54 | st, _ := stmt.(h2stmt) 55 | _, _, err = h2c.client.sess.executeQuery(&st, &h2c.client.trans) 56 | if err != nil { 57 | return driver.ErrBadConn 58 | } 59 | return nil 60 | } 61 | 62 | // Validator interface 63 | func (h2c h2Conn) IsValid() bool { 64 | // TODO: check for real valid connection 65 | L(log.DebugLevel, "IsValid") 66 | return true 67 | } 68 | 69 | // Conn interface 70 | func (h2c h2Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 71 | L(log.DebugLevel, "BeginTx") 72 | // Set autocommit to false 73 | stmt, err := h2c.client.sess.prepare2(&h2c.client.trans, "SET AUTOCOMMIT FALSE") 74 | if err != nil { 75 | return nil, err 76 | } 77 | st, _ := stmt.(h2stmt) 78 | _, err = h2c.client.sess.executeQueryUpdate(&st, &h2c.client.trans, []driver.Value{}) 79 | if err != nil 
{ 80 | return nil, err 81 | } 82 | return &h2tx{conn: h2c}, nil 83 | } 84 | func (h2c *h2Conn) Close() error { 85 | L(log.DebugLevel, "Close conn") 86 | 87 | return h2c.client.close() 88 | } 89 | 90 | func (h2c *h2Conn) Prepare(query string) (driver.Stmt, error) { 91 | L(log.DebugLevel, "Prepare: %s", query) 92 | var err error 93 | stmt, err := h2c.client.sess.prepare2(&h2c.client.trans, query) 94 | if err != nil { 95 | return nil, err 96 | } 97 | h2stmtIns := stmt.(h2stmt) 98 | h2stmtIns.client = h2c.client 99 | h2stmtIns.query = query 100 | return h2stmtIns, nil 101 | } 102 | 103 | // QuerierContext interface 104 | func (h2c *h2Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 105 | L(log.DebugLevel, "QueryContext: %s", query) 106 | var err error 107 | stmt, err := h2c.client.sess.prepare(&h2c.client.trans, query) 108 | if err != nil { 109 | return nil, err 110 | } 111 | st, _ := stmt.(h2stmt) 112 | cols, nRows, err := h2c.client.sess.executeQuery(&st, &h2c.client.trans) 113 | if err != nil { 114 | return nil, err 115 | } 116 | return &h2Result{query: query, columns: cols, numRows: nRows, trans: &h2c.client.trans, curRow: 0}, nil 117 | } 118 | 119 | func (h2c *h2Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 120 | L(log.DebugLevel, "ExecContext: %s", query) 121 | var err error 122 | var argsValues []driver.Value 123 | for _, arg := range args { 124 | argsValues = append(argsValues, arg.Value) 125 | } 126 | stmt, err := h2c.client.sess.prepare2(&h2c.client.trans, query) 127 | if err != nil { 128 | return nil, err 129 | } 130 | st, _ := stmt.(h2stmt) 131 | nUpdated, err := h2c.client.sess.executeQueryUpdate(&st, &h2c.client.trans, argsValues) 132 | if err != nil { 133 | return nil, err 134 | } 135 | return &h2ExecResult{nUpdated: nUpdated}, nil 136 | } 137 | 138 | // Specific code 139 | 140 | func connect(ci h2connInfo) (driver.Conn, error) { 141 | var conn 
net.Conn 142 | var err error 143 | address := fmt.Sprintf("%s:%d", ci.host, ci.port) 144 | conn, err = ci.dialer.Dial("tcp", address) 145 | if err != nil { 146 | return nil, errors.Wrapf(err, "failed to open H2 connection") 147 | } 148 | t := newTransfer(conn) 149 | c := h2client{conn: conn, trans: t, sess: newSession()} 150 | err = c.doHandshake(ci) 151 | if err != nil { 152 | return nil, errors.Wrapf(err, "error doing H2 server handshake") 153 | } 154 | // ci.client = c 155 | return &h2Conn{connInfo: ci, client: c}, nil 156 | } 157 | -------------------------------------------------------------------------------- /session.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 JM Robles (@jmrobles) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package h2go 18 | 19 | import ( 20 | "database/sql/driver" 21 | "fmt" 22 | "time" 23 | 24 | "github.com/pkg/errors" 25 | log "github.com/sirupsen/logrus" 26 | ) 27 | 28 | const ( 29 | sessionPrepare = 0 30 | sessionClose = 1 31 | sessionCommandExecuteQuery = 2 32 | sessionCommandExecuteUpdate = 3 33 | sessionCommandClose = 4 34 | sessionResultFetchRows = 5 35 | sessionResultReset = 6 36 | sessionResultClose = 7 37 | sessionCommandCommit = 8 38 | sessionChangeID = 9 39 | sessionCommandgetMetaData = 10 40 | sessionPrepareReadParams = 11 41 | sessionSetID = 12 42 | sessionCancelStatement = 13 43 | sessionCheckKey = 14 44 | sessionSetAutocommit = 15 45 | sessionHasPendingTransaction = 16 46 | sessionLobRead = 17 47 | sessionPrepareReadParams2 = 18 48 | 49 | sessionStatusError = 0 50 | sessionStatusOk = 1 51 | sessionStatusClosed = 2 52 | sessionStatusOkStateChanged = 3 53 | ) 54 | 55 | type session struct { 56 | seqID int32 57 | } 58 | 59 | func newSession() session { 60 | return session{} 61 | } 62 | 63 | func (s *session) prepare(t *transfer, sql string) (driver.Stmt, error) { 64 | var err error 65 | stmt := h2stmt{} 66 | // 0. Write SESSION_PREPARE 67 | err = t.writeInt32(sessionPrepare) 68 | // 1. Write ID 69 | stmt.id = s.getNextID() 70 | err = t.writeInt32(stmt.id) 71 | if err != nil { 72 | return stmt, err 73 | } 74 | // 2. Write SQL text 75 | err = t.writeString(sql) 76 | if err != nil { 77 | return stmt, err 78 | } 79 | // 4. Flush data and wait server info 80 | err = t.flush() 81 | if err != nil { 82 | return stmt, err 83 | } 84 | // 5. Read old state 85 | state, err := t.readInt32() 86 | if err != nil { 87 | return stmt, err 88 | } 89 | err = s.checkSQLError(state, t) 90 | if err != nil { 91 | return stmt, err 92 | } 93 | // 6. Read Is Query 94 | isQuery, err := t.readBool() 95 | if err != nil { 96 | return stmt, err 97 | } 98 | // 7. 
Read Is Read-only 99 | isRO, err := t.readBool() 100 | if err != nil { 101 | return stmt, err 102 | } 103 | // 8. Read params size 104 | numParams, err := t.readInt32() 105 | if err != nil { 106 | return stmt, err 107 | } 108 | L(log.DebugLevel, "STATUS: %d, IsQuery: %v, Is Read-Only: %v, Num Params: %d", state, isQuery, isRO, numParams) 109 | return stmt, nil 110 | } 111 | 112 | func (s *session) executeQuery(stmt *h2stmt, t *transfer) ([]string, int32, error) { 113 | var err error 114 | // 0. Write COMMAND EXECUTE QUERY 115 | L(log.DebugLevel, "Execute query") 116 | err = t.writeInt32(sessionCommandExecuteQuery) 117 | if err != nil { 118 | return nil, -1, err 119 | } 120 | // 1. Write ID of query 121 | //st := (*stmt).(h2stmt) 122 | err = t.writeInt32(stmt.id) 123 | if err != nil { 124 | return nil, -1, err 125 | } 126 | // 2. Write Object ID 127 | stmt.oID = s.getNextID() 128 | err = t.writeInt32(stmt.oID) 129 | if err != nil { 130 | return nil, -1, err 131 | } 132 | // 3. Write Max rows 133 | err = t.writeInt32(200) 134 | if err != nil { 135 | return nil, -1, err 136 | } 137 | // 4. Write Fetch max size 138 | err = t.writeInt32(64) 139 | if err != nil { 140 | return nil, -1, err 141 | } 142 | // 4. Write Fetch max size 143 | err = t.writeInt32(0) 144 | if err != nil { 145 | return nil, -1, err 146 | } 147 | 148 | // 5. 
Flush data 149 | err = t.flush() 150 | if err != nil { 151 | return nil, -1, err 152 | } 153 | // Read query status 154 | status, err := t.readInt32() 155 | if err != nil { 156 | return nil, -1, err 157 | } 158 | err = s.checkSQLError(status, t) 159 | if err != nil { 160 | return nil, -1, err 161 | } 162 | colCnt, err := t.readInt32() 163 | if err != nil { 164 | return nil, -1, err 165 | } 166 | rowCnt, err := t.readInt32() 167 | if err != nil { 168 | return nil, -1, err 169 | } 170 | L(log.DebugLevel, "Status: %d - Num cols: %d - Num rows: %d", status, colCnt, rowCnt) 171 | cols, err := s.readColumns(t, colCnt) 172 | if err != nil { 173 | return nil, -1, err 174 | } 175 | 176 | return cols, rowCnt, nil 177 | } 178 | func (s *session) readColumns(t *transfer, colCnt int32) ([]string, error) { 179 | // Alias 180 | cols := []string{} 181 | for i := 0; i < int(colCnt); i++ { 182 | alias, err := t.readString() 183 | if err != nil { 184 | return nil, err 185 | } 186 | // Schema 187 | // Ignored 188 | _, err = t.readString() 189 | if err != nil { 190 | return nil, err 191 | } 192 | // TableName 193 | // Ignored 194 | _, err = t.readString() 195 | if err != nil { 196 | return nil, err 197 | } 198 | // Column name 199 | colName, err := t.readString() 200 | if err != nil { 201 | return nil, err 202 | } 203 | // Skip other info 204 | // - Value type (int) 205 | _, err = t.readInt32() 206 | if err != nil { 207 | return nil, err 208 | } 209 | // - Precision (long) 210 | _, err = t.readLong() 211 | if err != nil { 212 | return nil, err 213 | } 214 | // - Scale (int) 215 | _, err = t.readInt32() 216 | if err != nil { 217 | return nil, err 218 | } 219 | // - Display Size (int) 220 | _, err = t.readInt32() 221 | if err != nil { 222 | return nil, err 223 | } 224 | // - Autoincrement (bool) 225 | _, err = t.readBool() 226 | if err != nil { 227 | return nil, err 228 | } 229 | // - Nullable (int) 230 | _, err = t.readInt32() 231 | if err != nil { 232 | return nil, err 233 | } 234 | // 
Set columns name 235 | if alias != "" { 236 | cols = append(cols, alias) 237 | } else { 238 | cols = append(cols, colName) 239 | } 240 | } 241 | return cols, nil 242 | 243 | } 244 | func (s *session) getNextID() int32 { 245 | s.seqID++ 246 | return s.seqID 247 | } 248 | 249 | type h2error struct { 250 | strError string 251 | msg string 252 | sql string 253 | codeError int32 254 | trace string 255 | error 256 | } 257 | 258 | func (s *session) checkSQLError(state int32, t *transfer) error { 259 | if state == 1 { 260 | return nil 261 | } 262 | // SQL Error 263 | sqlError, err := t.readString() 264 | if err != nil { 265 | return errors.Wrapf(err, "SQL Error: unknown") 266 | } 267 | sqlMsg, err := t.readString() 268 | if err != nil { 269 | return errors.Wrapf(err, "SQL Error: unknown") 270 | } 271 | sqlSQL, err := t.readString() 272 | if err != nil { 273 | return errors.Wrapf(err, "SQL Error: unknown") 274 | } 275 | errCode, err := t.readInt32() 276 | if err != nil { 277 | return errors.Wrapf(err, "SQL Error: unknown") 278 | } 279 | sqlTrace, err := t.readString() 280 | if err != nil { 281 | return errors.Wrapf(err, "SQL Error: unknown") 282 | } 283 | 284 | return newError(sqlError, sqlMsg, sqlSQL, errCode, sqlTrace) 285 | 286 | } 287 | 288 | func newError(strError string, msg string, sql string, codeError int32, trace string) *h2error { 289 | return &h2error{strError: strError, msg: msg, sql: sql, codeError: codeError, trace: trace} 290 | } 291 | func (err *h2error) Error() string { 292 | 293 | return fmt.Sprintf("H2 SQL Exception: [%s] %s", err.strError, err.msg) 294 | } 295 | 296 | func (s *session) executeQueryUpdate(stmt *h2stmt, t *transfer, values []driver.Value) (int32, error) { 297 | var err error 298 | // Check for params 299 | if stmt.numParams != int32(len(values)) { 300 | return -1, fmt.Errorf("Num expected parameters mismatch: %d != %d", stmt.numParams, len(values)) 301 | } 302 | // 0. 
Write COMMAND EXECUTE QUERY 303 | L(log.DebugLevel, "Execute query update") 304 | err = t.writeInt32(sessionCommandExecuteUpdate) 305 | if err != nil { 306 | return -1, err 307 | } 308 | // 1. Write ID of query 309 | //st := (*stmt).(h2stmt) 310 | err = t.writeInt32(stmt.id) 311 | if err != nil { 312 | return -1, err 313 | } 314 | // 2. Write params 315 | // -- num parameters 316 | err = t.writeInt32(stmt.numParams) 317 | if err != nil { 318 | return -1, err 319 | } 320 | // -- parameters 321 | for idx, value := range values { 322 | switch value.(type) { 323 | case time.Time: 324 | err = t.writeDatetimeValue(value.(time.Time), stmt.parameters[idx]) 325 | default: 326 | err = t.writeValue(value) 327 | } 328 | if err != nil { 329 | return -1, err 330 | } 331 | } 332 | // 3. Write Generate keys mode support 333 | // TODO 334 | err = t.writeInt32(0) 335 | if err != nil { 336 | return -1, err 337 | } 338 | err = t.flush() 339 | if err != nil { 340 | return -1, err 341 | } 342 | L(log.DebugLevel, "Read status") 343 | // Read query status 344 | status, err := t.readInt32() 345 | if err != nil { 346 | return -1, err 347 | } 348 | err = s.checkSQLError(status, t) 349 | if err != nil { 350 | return -1, err 351 | } 352 | // TODO: assert status == 1 353 | // Read num rows updated 354 | nUpdated, err := t.readInt32() 355 | if err != nil { 356 | return -1, err 357 | } 358 | // Read auto-commit status 359 | // TODO 360 | autoCommit, err := t.readBool() 361 | if err != nil { 362 | return -1, err 363 | } 364 | L(log.DebugLevel, "Status: %d - Num updated: %d - Autocommit: %v", status, nUpdated, autoCommit) 365 | return nUpdated, nil 366 | } 367 | 368 | func (s *session) prepare2(t *transfer, sql string) (driver.Stmt, error) { 369 | var err error 370 | stmt := h2stmt{} 371 | // 0. Write SESSION_PREPARE 372 | err = t.writeInt32(sessionPrepareReadParams2) 373 | // 1. 
Write ID 374 | stmt.id = s.getNextID() 375 | err = t.writeInt32(stmt.id) 376 | if err != nil { 377 | return stmt, err 378 | } 379 | // 2. Write SQL text 380 | err = t.writeString(sql) 381 | if err != nil { 382 | return stmt, err 383 | } 384 | 385 | // 4. Flush data and wait server info 386 | err = t.flush() 387 | if err != nil { 388 | return stmt, err 389 | } 390 | // 5. Read state 391 | state, err := t.readInt32() 392 | if err != nil { 393 | return stmt, err 394 | } 395 | err = s.checkSQLError(state, t) 396 | if err != nil { 397 | return stmt, err 398 | } 399 | 400 | // 6. Read Is Query 401 | isQuery, err := t.readBool() 402 | if err != nil { 403 | return stmt, err 404 | } 405 | // 7. Read Is Read-only 406 | isRO, err := t.readBool() 407 | if err != nil { 408 | return stmt, err 409 | } 410 | // Get command type 411 | cmdType, err := t.readInt32() 412 | if err != nil { 413 | return stmt, err 414 | } 415 | L(log.DebugLevel, "CMD type: %d", cmdType) 416 | // 8. Read params size 417 | numParams, err := t.readInt32() 418 | if err != nil { 419 | return stmt, err 420 | } 421 | L(log.DebugLevel, "STATUS: %d, IsQuery: %v, Is Read-Only: %v, Num Params: %d", state, isQuery, isRO, numParams) 422 | stmt.isQuery = isQuery 423 | stmt.isRO = isRO 424 | stmt.numParams = numParams 425 | // We receive metadata for each parameter 426 | // Metadata parameter type: int:type - long:precission - int:scale - int:nullable 427 | for i := 0; i < int(numParams); i++ { 428 | param := h2parameter{} 429 | // -- Type 430 | param.kind, err = t.readInt32() 431 | if err != nil { 432 | return nil, err 433 | } 434 | // -- Precission 435 | param.precission, err = t.readInt64() 436 | if err != nil { 437 | return nil, err 438 | } 439 | // -- Scale 440 | param.scale, err = t.readInt32() 441 | if err != nil { 442 | return nil, err 443 | } 444 | // -- Nullable 445 | tmp, err := t.readInt32() 446 | if err != nil { 447 | return nil, err 448 | } 449 | // 0 = Not null, 1 == Nullable, 2 == Unknown 450 | 
param.nullable = tmp == 1 451 | L(log.DebugLevel, "PARAM: Kind: %d - Precission: %d - Scale: %d - Nullable: %v", param.kind, param.precission, param.scale, param.nullable) 452 | stmt.parameters = append(stmt.parameters, param) 453 | } 454 | return stmt, nil 455 | } 456 | 457 | func (s *session) close(t *transfer) error { 458 | var err error 459 | // 0. Write SESSION_CLOSE 460 | err = t.writeInt32(sessionClose) 461 | if err != nil { 462 | return err 463 | } 464 | err = t.flush() 465 | if err != nil { 466 | return err 467 | } 468 | // 1. Read ID 469 | status, err := t.readInt32() 470 | if err != nil { 471 | return err 472 | } 473 | L(log.DebugLevel, "Status: %d", status) 474 | t.close() 475 | return nil 476 | } 477 | -------------------------------------------------------------------------------- /driver_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 JM Robles (@jmrobles) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package h2go 18 | 19 | import ( 20 | "database/sql" 21 | "fmt" 22 | "log" 23 | "net" 24 | "os" 25 | "strconv" 26 | "testing" 27 | "time" 28 | ) 29 | 30 | // Using a testing pattern similar to Go MySQL Driver (https://github.com/go-sql-driver/mysql) 31 | // Main testing entry point 32 | var ( 33 | user string 34 | pass string 35 | addr string 36 | dbname string 37 | inMem bool 38 | dsn string 39 | available bool 40 | ) 41 | 42 | type dbTest struct { 43 | *testing.T 44 | conn *sql.DB 45 | } 46 | 47 | func (dt dbTest) checkErr(err error) { 48 | if err != nil { 49 | dt.Errorf("error: %s", err) 50 | } 51 | } 52 | 53 | func init() { 54 | env := func(key, defVal string) string { 55 | if val := os.Getenv(key); val != "" { 56 | return val 57 | } 58 | return defVal 59 | } 60 | user = env("H2_TEST_USER", "sa") 61 | pass = env("H2_TEST_PASSWORD", "") 62 | addr = env("H2_TEST_ADDR", "localhost:9092") 63 | dbname = env("H2_TEST_DBNAME", "test") 64 | inMemS := env("H2_TEST_IN_MEMORY", "true") 65 | inMem, err := strconv.ParseBool(inMemS) 66 | if err != nil { 67 | inMem = true 68 | } 69 | if pass != "" { 70 | dsn = fmt.Sprintf("h2://%s:%s@%s/%s?mem=%t", user, pass, addr, dbname, inMem) 71 | } else { 72 | dsn = fmt.Sprintf("h2://%s@%s/%s?mem=%t", user, addr, dbname, inMem) 73 | } 74 | // Check alive 75 | c, err := net.Dial("tcp", addr) 76 | if err == nil { 77 | available = true 78 | c.Close() 79 | } else { 80 | log.Printf("Can't connect: %s", err) 81 | } 82 | } 83 | func runTests(t *testing.T, tests ...func(dt *dbTest)) { 84 | var err error 85 | if !available { 86 | t.Errorf("H2 Server not running on %s", addr) 87 | } 88 | conn, err := sql.Open("h2", dsn) 89 | if err != nil { 90 | t.Fatalf("Can't connect to the H2 server: %s", err) 91 | } 92 | defer conn.Close() 93 | db := &dbTest{t, conn} 94 | for _, test := range tests { 95 | test(db) 96 | conn.Exec("DROP TABLE IF EXISTS test") 97 | 98 | } 99 | 100 | } 101 | 102 | /* 103 | func TestConnection(t *testing.T) { 104 
| conn, err := sql.Open("h2", "h2://sa@h2server:9092/test?mem=true") 105 | if err != nil { 106 | t.Errorf("Can't connect to the server: %s", err) 107 | } 108 | err = conn.Ping() 109 | if err != nil { 110 | t.Errorf("Can't do PING: %s", err) 111 | } 112 | } 113 | */ 114 | 115 | func TestPing(t *testing.T) { 116 | runTests(t, func(dt *dbTest) { 117 | err := dt.conn.Ping() 118 | dt.checkErr(err) 119 | }) 120 | } 121 | 122 | func TestSimpleCRUD(t *testing.T) { 123 | runTests(t, func(dt *dbTest) { 124 | var err error 125 | // Create table 126 | sent := "CREATE TABLE test (id int, name varchar, age int)" 127 | _, err = dt.conn.Exec(sent) 128 | dt.checkErr(err) 129 | // Insert a row 130 | sent = "INSERT INTO test VALUES (1, 'Paco', 23)" 131 | result, err := dt.conn.Exec(sent) 132 | dt.checkErr(err) 133 | nR, err := result.RowsAffected() 134 | dt.checkErr(err) 135 | if nR != 1 { 136 | dt.Errorf("Num rows inserted not equal to 1") 137 | } 138 | // Query 139 | sent = "SELECT * FROM test" 140 | rows, err := dt.conn.Query(sent) 141 | dt.checkErr(err) 142 | for rows.Next() { 143 | var ( 144 | id int 145 | name string 146 | age int 147 | ) 148 | err = rows.Scan(&id, &name, &age) 149 | dt.checkErr(err) 150 | if id != 1 { 151 | dt.Errorf("ID mismatch (not equal to 1)") 152 | } 153 | if name != "Paco" { 154 | dt.Errorf("Name mismatch (not equal to 'Paco')") 155 | } 156 | if age != 23 { 157 | dt.Errorf("Age mismatch (not equal to 23)") 158 | } 159 | } 160 | err = rows.Close() 161 | // Update row 162 | sent = "UPDATE test SET age = 24 WHERE id = 1" 163 | result, err = dt.conn.Exec(sent) 164 | dt.checkErr(err) 165 | nR, err = result.RowsAffected() 166 | dt.checkErr(err) 167 | if nR != 1 { 168 | dt.Errorf("Num rows updated not equal to 1") 169 | } 170 | // Query again 171 | sent = "SELECT * FROM test" 172 | rows, err = dt.conn.Query(sent) 173 | dt.checkErr(err) 174 | for rows.Next() { 175 | var ( 176 | id int 177 | name string 178 | age int 179 | ) 180 | err = rows.Scan(&id, &name, 
&age) 181 | dt.checkErr(err) 182 | if id != 1 { 183 | dt.Errorf("ID mismatch (not equal to 1)") 184 | } 185 | if name != "Paco" { 186 | dt.Errorf("Name mismatch (not equal to 'Paco')") 187 | } 188 | if age != 24 { 189 | dt.Errorf("Age mismatch (not equal to 24)") 190 | } 191 | } 192 | err = rows.Close() 193 | // Insert another row 194 | sent = "INSERT INTO test VALUES (2, 'John', 24)" 195 | result, err = dt.conn.Exec(sent) 196 | dt.checkErr(err) 197 | nR, err = result.RowsAffected() 198 | dt.checkErr(err) 199 | if nR != 1 { 200 | dt.Errorf("Num rows inserted not equal to 1") 201 | } 202 | // Delete all 203 | sent = "DELETE FROM test" 204 | result, err = dt.conn.Exec(sent) 205 | dt.checkErr(err) 206 | nR, err = result.RowsAffected() 207 | dt.checkErr(err) 208 | if nR != 2 { 209 | dt.Errorf("Num rows deleted not equal to 2") 210 | } 211 | // Skip DROP TABLE (done by the wrapper) 212 | }) 213 | } 214 | 215 | func TestCRUDwithParameters(t *testing.T) { 216 | runTests(t, func(dt *dbTest) { 217 | var err error 218 | var sent string 219 | var ( 220 | id int = 1 221 | name string = "Paco" 222 | age int = 23 223 | ) 224 | // Create table 225 | sent = "CREATE TABLE test (id int, name varchar, age int)" 226 | _, err = dt.conn.Exec(sent) 227 | dt.checkErr(err) 228 | // Insert with parameters 229 | sent = "INSERT INTO test VALUES (?, ?, ?)" 230 | result, err := dt.conn.Exec(sent, id, name, age) 231 | dt.checkErr(err) 232 | nR, err := result.RowsAffected() 233 | dt.checkErr(err) 234 | if nR != 1 { 235 | dt.Errorf("Num rows inserted not equal to 1") 236 | } 237 | // Query 238 | sent = "SELECT * FROM test" 239 | rows, err := dt.conn.Query(sent) 240 | dt.checkErr(err) 241 | for rows.Next() { 242 | var ( 243 | id int 244 | name string 245 | age int 246 | ) 247 | err = rows.Scan(&id, &name, &age) 248 | dt.checkErr(err) 249 | if id != 1 { 250 | dt.Errorf("ID mismatch (not equal to 1)") 251 | } 252 | if name != "Paco" { 253 | dt.Errorf("Name mismatch (not equal to 'Paco')") 254 | } 255 
| if age != 23 { 256 | dt.Errorf("Age mismatch (not equal to 23)") 257 | } 258 | } 259 | err = rows.Close() 260 | }) 261 | } 262 | 263 | func TestDateTimeTypes(t *testing.T) { 264 | runTests(t, func(dt *dbTest) { 265 | var err error 266 | var sent string 267 | var ( 268 | id int = 1 269 | dtFixed time.Time 270 | ) 271 | // Create table 272 | sent = "CREATE TABLE test (id INT, t TIME, ttz TIME WITH TIME ZONE, d DATE, ts TIMESTAMP, tstz TIMESTAMP WITH TIME ZONE)" 273 | _, err = dt.conn.Exec(sent) 274 | dt.checkErr(err) 275 | // Insert a row 276 | sent = "INSERT INTO test VALUES (?, ?, ?, ?, ?, ?)" 277 | loc, err := time.LoadLocation("Europe/Madrid") 278 | if err != nil { 279 | dt.Skipf("Can't get timezone for Europe/Madrid: %s", err) 280 | } 281 | dtFixed = time.Date(2020, 5, 25, 9, 1, 2, 123, loc) 282 | result, err := dt.conn.Exec(sent, id, dtFixed, dtFixed, dtFixed, dtFixed, dtFixed) 283 | dt.checkErr(err) 284 | nR, err := result.RowsAffected() 285 | dt.checkErr(err) 286 | if nR != 1 { 287 | dt.Errorf("Num rows inserted not equal to 1") 288 | } 289 | // Query 290 | sent = "SELECT * FROM test" 291 | rows, err := dt.conn.Query(sent) 292 | dt.checkErr(err) 293 | for rows.Next() { 294 | var ( 295 | vTime time.Time 296 | vTimeTZ time.Time 297 | vDate time.Time 298 | vTimestamp time.Time 299 | vTimestampTZ time.Time 300 | ) 301 | err = rows.Scan(&id, &vTime, &vTimeTZ, &vDate, &vTimestamp, &vTimestampTZ) 302 | dt.checkErr(err) 303 | if id != 1 { 304 | dt.Errorf("ID mismatch (not equal to 1)") 305 | } 306 | // TIME check 307 | if vTime.Hour() != 9 || vTime.Minute() != 1 || vTime.Second() != 2 { 308 | dt.Errorf("Time mismatch: %d %d %d", vTime.Hour(), vTime.Minute(), vTime.Second()) 309 | } 310 | // TIME WITH TIME ZONE check 311 | _, nSecOffset := vTimeTZ.Zone() 312 | if vTimeTZ.Hour() != 9 || vTimeTZ.Minute() != 1 || vTimeTZ.Second() != 2 || nSecOffset != 7200 { 313 | dt.Errorf("Time TZ mismatch: %d %d %d %d", vTimeTZ.Hour(), vTimeTZ.Minute(), vTimeTZ.Second(), nSecOffset) 
314 | } 315 | // DATE check 316 | if vDate.Day() != 25 || vDate.Month() != 5 || vDate.Year() != 2020 { 317 | dt.Errorf("Date mismatch: %d %d %d", vDate.Day(), vDate.Month(), vDate.Year()) 318 | } 319 | // TIMESTAMP check 320 | if vTimestamp.Day() != 25 || vTimestamp.Month() != 5 || vTimestamp.Year() != 2020 || vTimestamp.Hour() != 9 || vTimestamp.Minute() != 1 || vTimestamp.Second() != 2 { 321 | dt.Errorf("Timestamp mismatch: %d %d %d %d %d %d", vTimestamp.Day(), vTimestamp.Month(), vTimestamp.Year(), vTimestamp.Hour(), vTimestamp.Minute(), vTimestamp.Second()) 322 | } 323 | // TIMESTAMP WITH TIME Zone check 324 | _, nSecOffset = vTimeTZ.Zone() 325 | if vTimestampTZ.Day() != 25 || vTimestampTZ.Month() != 5 || vTimestampTZ.Year() != 2020 || vTimestampTZ.Hour() != 9 || vTimestampTZ.Minute() != 1 || vTimestampTZ.Second() != 2 || nSecOffset != 7200 { 326 | dt.Errorf("Timestamp TZ mismatch: %d %d %d %d %d %d %d", vTimestampTZ.Day(), vTimestampTZ.Month(), vTimestampTZ.Year(), vTimestampTZ.Hour(), vTimestampTZ.Minute(), vTimestampTZ.Second(), nSecOffset) 327 | } 328 | } 329 | err = rows.Close() 330 | }) 331 | } 332 | 333 | func TestOtherTypes(t *testing.T) { 334 | runTests(t, func(dt *dbTest) { 335 | var err error 336 | var sent string 337 | // CREATE TABLE 338 | sent = "CREATE TABLE test (id INT, name VARCHAR(100), height FLOAT, isGood BOOLEAN, numAtoms DOUBLE, age SMALLINT)" 339 | _, err = dt.conn.Exec(sent) 340 | dt.checkErr(err) 341 | // INSERT 342 | var ( 343 | id int = 1 344 | name string = "Paco" 345 | height float32 = 1.88 346 | isGood bool = true 347 | numAtoms float64 = 13213123332132.5 348 | age int16 = 23 349 | ) 350 | sent = "INSERT INTO test VALUES (?, ?, ?, ?, ?, ?)" 351 | result, err := dt.conn.Exec(sent, id, name, height, isGood, numAtoms, age) 352 | dt.checkErr(err) 353 | nR, err := result.RowsAffected() 354 | dt.checkErr(err) 355 | if nR != 1 { 356 | dt.Errorf("Num rows inserted not equal to 1") 357 | } 358 | // Query 359 | sent = "SELECT * FROM test" 
360 | rows, err := dt.conn.Query(sent) 361 | dt.checkErr(err) 362 | for rows.Next() { 363 | err = rows.Scan(&id, &name, &height, &isGood, &numAtoms, &age) 364 | dt.checkErr(err) 365 | if id != 1 { 366 | dt.Errorf("ID mismatch (not equal to 1)") 367 | } 368 | if name != "Paco" { 369 | dt.Errorf("Name mismatch (not equal to 'Paco')") 370 | } 371 | if height != 1.88 { 372 | dt.Errorf("Height mismatch (not equal to 1.88)") 373 | } 374 | if !isGood { 375 | dt.Errorf("isGood is false") 376 | } 377 | if numAtoms != 13213123332132.5 { 378 | dt.Errorf("Num atoms mismatch (not equal to 13213123332132.5)") 379 | } 380 | if age != 23 { 381 | dt.Errorf("Age mismatch (not equal to 23)") 382 | } 383 | } 384 | rows.Close() 385 | }) 386 | } 387 | 388 | func TestStmt(t *testing.T) { 389 | runTests(t, func(dt *dbTest) { 390 | var err error 391 | var sent string 392 | // CREATE TABLE 393 | sent = "CREATE TABLE test (id INT, name VARCHAR(100))" 394 | _, err = dt.conn.Exec(sent) 395 | dt.checkErr(err) 396 | // Get Stmt 397 | stmt, err := dt.conn.Prepare("INSERT INTO test VALUES (?,?)") 398 | dt.checkErr(err) 399 | result, err := stmt.Exec(1, "Paco") 400 | dt.checkErr(err) 401 | nR, err := result.RowsAffected() 402 | dt.checkErr(err) 403 | if nR != 1 { 404 | dt.Errorf("Num rows inserted not equal to 1") 405 | } 406 | }) 407 | } 408 | 409 | func TestTx(t *testing.T) { 410 | runTests(t, func(dt *dbTest) { 411 | var err error 412 | var sent string 413 | // CREATE TABLE 414 | sent = "CREATE TABLE test (id INT, name VARCHAR(100))" 415 | _, err = dt.conn.Exec(sent) 416 | dt.checkErr(err) 417 | // TX with commit 418 | tx, err := dt.conn.Begin() 419 | dt.checkErr(err) 420 | result, err := tx.Exec("INSERT INTO test VALUES (1, 'Paco')") 421 | dt.checkErr(err) 422 | nR, err := result.RowsAffected() 423 | dt.checkErr(err) 424 | if nR != 1 { 425 | dt.Errorf("Num rows inserted not equal to 1") 426 | } 427 | err = tx.Commit() 428 | dt.checkErr(err) 429 | // Query 430 | var ( 431 | id int 432 | name 
string 433 | ) 434 | sent = "SELECT * FROM test" 435 | rows, err := dt.conn.Query(sent) 436 | dt.checkErr(err) 437 | for rows.Next() { 438 | err = rows.Scan(&id, &name) 439 | dt.checkErr(err) 440 | if id != 1 { 441 | dt.Errorf("ID mismatch (not equal to 1)") 442 | } 443 | if name != "Paco" { 444 | dt.Errorf("Name mismatch (not equal to 'Paco')") 445 | } 446 | } 447 | rows.Close() 448 | // Tx with rollback 449 | tx, err = dt.conn.Begin() 450 | dt.checkErr(err) 451 | result, err = tx.Exec("INSERT INTO test VALUES (2, 'John')") 452 | dt.checkErr(err) 453 | nR, err = result.RowsAffected() 454 | dt.checkErr(err) 455 | if nR != 1 { 456 | dt.Errorf("Num rows inserted not equal to 1") 457 | } 458 | err = tx.Rollback() 459 | dt.checkErr(err) 460 | // Query 461 | sent = "SELECT * FROM test" 462 | rows, err = dt.conn.Query(sent) 463 | dt.checkErr(err) 464 | for rows.Next() { 465 | err = rows.Scan(&id, &name) 466 | dt.checkErr(err) 467 | if id != 1 { 468 | dt.Errorf("ID mismatch (not equal to 1)") 469 | } 470 | if name != "Paco" { 471 | dt.Errorf("Name mismatch (not equal to 'Paco')") 472 | } 473 | } 474 | rows.Close() 475 | }) 476 | } 477 | -------------------------------------------------------------------------------- /transfer.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 JM Robles (@jmrobles) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
// transfer couples a network connection with the buffered reader/writer
// pair used to exchange protocol values over it.
type transfer struct {
	conn net.Conn
	buff *bufio.ReadWriter
}

// newTransfer wraps conn in a buffered reader and a buffered writer and
// returns the transfer that owns both. Written data stays in the buffer
// until it is flushed.
func newTransfer(conn net.Conn) transfer {
	rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
	return transfer{conn: conn, buff: rw}
}
ret, nil 101 | } 102 | 103 | func (t *transfer) readFloat32() (float32, error) { 104 | var ret float32 105 | err := binary.Read(t.buff, binary.BigEndian, &ret) 106 | if err != nil { 107 | return -1, errors.Wrapf(err, "can't read float32 value from socket") 108 | } 109 | return ret, nil 110 | } 111 | 112 | func (t *transfer) readFloat64() (float64, error) { 113 | var ret float64 114 | err := binary.Read(t.buff, binary.BigEndian, &ret) 115 | if err != nil { 116 | return -1, errors.Wrapf(err, "can't read float64 value from socket") 117 | } 118 | return ret, nil 119 | } 120 | 121 | func (t *transfer) writeInt32(v int32) error { 122 | return binary.Write(t.buff, binary.BigEndian, v) 123 | } 124 | 125 | func (t *transfer) writeInt64(v int64) error { 126 | return binary.Write(t.buff, binary.BigEndian, v) 127 | } 128 | func (t *transfer) writeFloat64(v float64) error { 129 | return binary.Write(t.buff, binary.BigEndian, v) 130 | } 131 | 132 | func (t *transfer) readString() (string, error) { 133 | var err error 134 | n, err := t.readInt32() 135 | if err != nil { 136 | return "", errors.Wrapf(err, "can't read string length from socket") 137 | } 138 | if n == -1 || n == 0 { 139 | return "", nil 140 | } 141 | buf := make([]byte, n*2) 142 | /* 143 | var cur int32 144 | for { 145 | n2, err := t.buff.Read(buf[cur:n]) 146 | if err != nil { 147 | return "", err 148 | } 149 | cur += int32(n2) 150 | if cur == n { 151 | break 152 | } 153 | } 154 | */ 155 | n2, err := t.buff.Read(buf) 156 | if err != nil { 157 | return "", err 158 | } 159 | if n2 != len(buf) { 160 | return "", errors.Errorf("Can't read all data needed") 161 | } 162 | dec := unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM).NewDecoder() 163 | buf, err = dec.Bytes(buf) 164 | if err != nil { 165 | return "", errors.Wrapf(err, "can't convert from UTF-16 a UTF-8 string") 166 | } 167 | return string(buf), nil 168 | 169 | } 170 | 171 | func (t *transfer) writeString(s string) error { 172 | var err error 173 | data := 
[]byte(s) 174 | n := int32(len(data)) 175 | if n == 0 { 176 | n = -1 177 | } 178 | err = t.writeInt32(n) 179 | if err != nil { 180 | return errors.Wrapf(err, "can't write string length to socket") 181 | } 182 | if n == -1 { 183 | return nil 184 | } 185 | enc := unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM).NewEncoder() 186 | data, err = enc.Bytes(data) 187 | if err != nil { 188 | return errors.Wrapf(err, "can't convert to UTF-16") 189 | } 190 | /* 191 | n = int32(len(data)) 192 | for { 193 | n2, err := t.buff.Write(data[pos:n]) 194 | if err != nil { 195 | return errors.Wrapf(err, "can't write string to socket") 196 | } 197 | pos += int32(n2) 198 | if pos == n { 199 | break 200 | } 201 | } 202 | */ 203 | n2, err := t.buff.Write(data) 204 | if err != nil { 205 | return errors.Wrapf(err, "can't write string to socket") 206 | } 207 | if n2 != len(data) { 208 | return errors.Errorf("Data send not equal to wished") 209 | } 210 | return nil 211 | } 212 | 213 | func (t *transfer) readBytes() ([]byte, error) { 214 | n, err := t.readInt32() 215 | if err != nil { 216 | return nil, errors.Wrapf(err, "can't read bytes length from socket") 217 | } 218 | if n == -1 { 219 | return nil, nil 220 | } 221 | return t.readBytesDef(int(n)) 222 | 223 | } 224 | func (t *transfer) writeBool(b bool) error { 225 | var v byte = 0 226 | if b { 227 | v = 1 228 | } 229 | return t.writeByte(v) 230 | } 231 | 232 | func (t *transfer) writeByte(b byte) error { 233 | return t.buff.WriteByte(b) 234 | } 235 | 236 | func (t *transfer) writeBytes(data []byte) error { 237 | var err error 238 | s := int32(len(data)) 239 | if data == nil || s == 0 { 240 | s = -1 241 | } 242 | err = t.writeInt32(s) 243 | if err != nil { 244 | return errors.Wrapf(err, "can't write bytes length to socket") 245 | } 246 | if s == -1 { 247 | return nil 248 | } 249 | n, err := t.buff.Write(data) 250 | if err != nil { 251 | return errors.Wrapf(err, "can't write bytes to socket") 252 | } 253 | if int32(n) != s { 254 | return 
errors.Wrapf(err, "can't write all bytes to socket => %d != %d", n, s) 255 | } 256 | return nil 257 | } 258 | 259 | func (t *transfer) readBool() (bool, error) { 260 | v, err := t.readByte() 261 | if err != nil { 262 | return false, err 263 | } 264 | return v == 1, nil 265 | } 266 | 267 | func (t *transfer) readByte() (byte, error) { 268 | v, err := t.buff.ReadByte() 269 | return v, err 270 | } 271 | 272 | func (t *transfer) readLong() (int64, error) { 273 | var ret int64 274 | err := binary.Read(t.buff, binary.BigEndian, &ret) 275 | if err != nil { 276 | return -1, errors.Wrapf(err, "can't read long value from socket") 277 | } 278 | return ret, nil 279 | } 280 | func (t *transfer) readDate() (time.Time, error) { 281 | n, err := t.readInt64() 282 | if err != nil { 283 | return time.Time{}, err 284 | } 285 | date := bin2date(n) 286 | return date, nil 287 | } 288 | 289 | func (t *transfer) readTimestamp() (time.Time, error) { 290 | nDate, err := t.readInt64() 291 | if err != nil { 292 | return time.Time{}, err 293 | } 294 | nNsecs, err := t.readInt64() 295 | if err != nil { 296 | return time.Time{}, err 297 | } 298 | date := bin2ts(nDate, nNsecs) 299 | return date, nil 300 | } 301 | 302 | func (t *transfer) readTimestampTZ() (time.Time, error) { 303 | nDate, err := t.readInt64() 304 | if err != nil { 305 | return time.Time{}, err 306 | } 307 | nNsecs, err := t.readInt64() 308 | if err != nil { 309 | return time.Time{}, err 310 | } 311 | nDiffTZ, err := t.readInt32() 312 | if err != nil { 313 | return time.Time{}, err 314 | } 315 | date := bin2tsz(nDate, nNsecs, nDiffTZ) 316 | return date, nil 317 | } 318 | 319 | func (t *transfer) flush() error { 320 | return t.buff.Flush() 321 | } 322 | 323 | func (t *transfer) readValue() (interface{}, error) { 324 | var err error 325 | kind, err := t.readInt32() 326 | if err != nil { 327 | return nil, errors.Wrapf(err, "can't read type of value") 328 | } 329 | L(log.DebugLevel, "Value type: %d", kind) 330 | switch kind { 331 | 
case Null: 332 | // TODO: review 333 | return nil, nil 334 | case Bytes: 335 | return t.readBytes() 336 | case UUID: 337 | return nil, errors.Errorf("UUID not implemented") 338 | case JavaObject: 339 | return nil, errors.Errorf("Java Object not implemented") 340 | case Boolean: 341 | return t.readBool() 342 | case Byte: 343 | return t.readByte() 344 | case Date: 345 | return t.readDate() 346 | case Time: 347 | return t.readTime() 348 | case TimeTZQuery, TimeTZ: 349 | return t.readTimeTZ() 350 | case Timestamp: 351 | return t.readTimestamp() 352 | case TimestampTZ: 353 | return t.readTimestampTZ() 354 | case Decimal: 355 | return nil, errors.Errorf("Decimal not implemented") 356 | case Double: 357 | return t.readFloat64() 358 | case Float: 359 | return t.readFloat32() 360 | case Enum: 361 | return nil, errors.Errorf("Enum not implemented") 362 | case Int: 363 | return t.readInt32() 364 | case Long: 365 | return t.readLong() 366 | case Short: 367 | return t.readInt16() 368 | case String: 369 | return t.readString() 370 | case StringIgnoreCase: 371 | return t.readString() 372 | case StringFixed: 373 | return t.readString() 374 | case Blob: 375 | return nil, errors.Errorf("Blob not implemented") 376 | case Clob: 377 | return nil, errors.Errorf("Clob not implemented") 378 | case Array: 379 | return nil, errors.Errorf("Array not implemented") 380 | case Row: 381 | return nil, errors.Errorf("Row not implemented") 382 | case ResultSet: 383 | return nil, errors.Errorf("Result Set not implemented") 384 | case Geometry: 385 | return nil, errors.Errorf("Geometry not implemented") 386 | case JSON: 387 | return nil, errors.Errorf("JSON not implemented") 388 | default: 389 | L(log.ErrorLevel, "Unknown type: %d", kind) 390 | return nil, errors.Errorf("Unknown type: %d", kind) 391 | } 392 | 393 | } 394 | 395 | func (t *transfer) writeValue(v interface{}) error { 396 | switch kind := v.(type) { 397 | case nil: 398 | t.writeInt32(Null) 399 | case bool: 400 | t.writeInt32(Boolean) 401 
| t.writeBool(v.(bool)) 402 | case int: 403 | s := unsafe.Sizeof(v) 404 | if s == 4 { 405 | t.writeInt32(Int) 406 | t.writeInt32(int32(v.(int))) 407 | } else { 408 | // 8 bytes 409 | t.writeInt32(Long) 410 | t.writeInt64(int64(v.(int))) 411 | } 412 | case int16: 413 | t.writeInt32(Short) 414 | t.writeInt32(v.(int32)) 415 | case int32: 416 | t.writeInt32(Int) 417 | t.writeInt32(int32(v.(int32))) 418 | case int64: 419 | t.writeInt32(Long) 420 | t.writeInt64(int64(v.(int64))) 421 | case float64: 422 | t.writeInt32(Double) 423 | t.writeFloat64(v.(float64)) 424 | case string: 425 | t.writeInt32(String) 426 | t.writeString(v.(string)) 427 | case byte: 428 | t.writeInt32(Byte) 429 | t.writeByte(v.(byte)) 430 | case []byte: 431 | t.writeInt32(Bytes) 432 | t.writeBytes(v.([]byte)) 433 | // case time.Time: 434 | default: 435 | return fmt.Errorf("Can't convert type %T to H2 Type", kind) 436 | } 437 | return nil 438 | } 439 | func (t *transfer) writeDatetimeValue(dt time.Time, mdp h2parameter) error { 440 | L(log.DebugLevel, "Date/time type: %d", mdp.kind) 441 | var err error 442 | switch mdp.kind { 443 | case Date: 444 | t.writeInt32(Date) 445 | bin := date2bin(&dt) 446 | err = t.writeInt64(bin) 447 | if err != nil { 448 | return err 449 | } 450 | case Timestamp: 451 | t.writeInt32(Timestamp) 452 | dateBin, nsecBin := ts2bin(&dt) 453 | err = t.writeInt64(dateBin) 454 | if err != nil { 455 | return err 456 | } 457 | err = t.writeInt64(nsecBin) 458 | if err != nil { 459 | return err 460 | } 461 | case TimestampTZ: 462 | t.writeInt32(TimestampTZ) 463 | dateBin, nsecBin, offsetTZBin := tsz2bin(&dt) 464 | err = t.writeInt64(dateBin) 465 | if err != nil { 466 | return err 467 | } 468 | err = t.writeInt64(nsecBin) 469 | if err != nil { 470 | return err 471 | } 472 | err = t.writeInt32(offsetTZBin) 473 | if err != nil { 474 | return err 475 | } 476 | case Time: 477 | t.writeInt32(Time) 478 | nsecBin := time2bin(&dt) 479 | err = t.writeInt64(nsecBin) 480 | if err != nil { 481 | return 
err 482 | } 483 | case TimeTZ: 484 | t.writeInt32(TimeTZQuery) 485 | nsecBin, offsetTZBin := timetz2bin(&dt) 486 | err = t.writeInt64(nsecBin) 487 | if err != nil { 488 | return err 489 | } 490 | err = t.writeInt32(offsetTZBin) 491 | if err != nil { 492 | return err 493 | } 494 | default: 495 | return fmt.Errorf("Datatype unsupported: %d", mdp.kind) 496 | } 497 | return nil 498 | } 499 | func (t *transfer) readBytesDef(n int) ([]byte, error) { 500 | 501 | buf := make([]byte, n) 502 | n2, err := t.buff.Read(buf) 503 | if err != nil { 504 | return nil, err 505 | } 506 | if n != n2 { 507 | return nil, errors.Errorf("Read byte size differs: %d != %d", n, n2) 508 | } 509 | return buf, nil 510 | 511 | } 512 | func (t *transfer) close() error { 513 | // TODO: check close 514 | return nil 515 | } 516 | 517 | // Helpers 518 | 519 | func date2bin(dt *time.Time) int64 { 520 | return int64((dt.Year() << 9) + (int(dt.Month()) << 5) + dt.Day()) 521 | } 522 | 523 | func bin2date(n int64) time.Time { 524 | day := int(n & 0x1f) 525 | month := time.Month((n >> 5) & 0xf) 526 | year := int(n >> 9) 527 | return time.Date(year, month, day, 0, 0, 0, 0, time.UTC) 528 | } 529 | 530 | func ts2bin(dt *time.Time) (int64, int64) { 531 | var nsecBin int64 532 | dateBin := date2bin(dt) 533 | nsecBin = int64(dt.Hour()*3600 + dt.Minute()*60 + dt.Second()) 534 | nsecBin *= int64(1e9) 535 | nsecBin += int64(dt.Nanosecond()) 536 | return dateBin, nsecBin 537 | } 538 | 539 | func bin2ts(dateBin int64, nsecBin int64) time.Time { 540 | // TODO: optimization 541 | day := int(dateBin & 0x1f) 542 | month := time.Month((dateBin >> 5) & 0xf) 543 | year := int(dateBin >> 9) 544 | nsecs := int(nsecBin % int64(1e9)) 545 | nsecBin = nsecBin / int64(1e9) 546 | sec := int(nsecBin % 60) 547 | nsecBin = nsecBin / 60 548 | minute := int(nsecBin % 60) 549 | hour := int(nsecBin / 60) 550 | return time.Date(year, month, day, hour, minute, sec, nsecs, time.UTC) 551 | } 552 | 553 | func bin2tsz(dateBin int64, nsecBin 
int64, secsTZ int32) time.Time { 554 | // TODO: optimization 555 | day := int(dateBin & 0x1f) 556 | month := time.Month((dateBin >> 5) & 0xf) 557 | year := int(dateBin >> 9) 558 | nsecs := int(nsecBin % int64(1e9)) 559 | nsecBin = nsecBin / int64(1e9) 560 | sec := int(nsecBin % 60) 561 | nsecBin = nsecBin / 60 562 | minute := int(nsecBin % 60) 563 | hour := int(nsecBin / 60) 564 | tz := time.FixedZone(fmt.Sprintf("tz_%d", secsTZ), int(secsTZ)) 565 | return time.Date(year, month, day, hour, minute, sec, nsecs, tz) 566 | } 567 | 568 | func tsz2bin(dt *time.Time) (int64, int64, int32) { 569 | var nsecBin int64 570 | dateBin := date2bin(dt) 571 | nsecBin = int64(dt.Hour()*3600 + dt.Minute()*60 + dt.Second()) 572 | nsecBin *= int64(1e9) 573 | nsecBin += int64(dt.Nanosecond()) 574 | _, offsetTZ := dt.Zone() 575 | return dateBin, nsecBin, int32(offsetTZ) 576 | } 577 | 578 | func time2bin(dt *time.Time) int64 { 579 | var nsecBin int64 580 | nsecBin = int64(dt.Hour()*3600 + dt.Minute()*60 + dt.Second()) 581 | nsecBin *= int64(1e9) 582 | nsecBin += int64(dt.Nanosecond()) 583 | return nsecBin 584 | } 585 | 586 | func bin2time(nsecBin int64) time.Time { 587 | // TODO: optimization 588 | nsecs := int(nsecBin % int64(1e9)) 589 | nsecBin = nsecBin / int64(1e9) 590 | sec := int(nsecBin % 60) 591 | nsecBin = nsecBin / 60 592 | minute := int(nsecBin % 60) 593 | hour := int(nsecBin / 60) 594 | return time.Date(0, 1, 1, hour, minute, sec, nsecs, time.UTC) 595 | } 596 | 597 | func (t *transfer) readTime() (time.Time, error) { 598 | nNsecs, err := t.readInt64() 599 | if err != nil { 600 | return time.Time{}, err 601 | } 602 | date := bin2time(nNsecs) 603 | return date, nil 604 | } 605 | 606 | func (t *transfer) readTimeTZ() (time.Time, error) { 607 | nNsecs, err := t.readInt64() 608 | if err != nil { 609 | return time.Time{}, err 610 | } 611 | nDiffTZ, err := t.readInt32() 612 | if err != nil { 613 | return time.Time{}, err 614 | } 615 | date := bin2timetz(nNsecs, nDiffTZ) 616 | return 
date, nil 617 | } 618 | 619 | func bin2timetz(nsecBin int64, secsTZ int32) time.Time { 620 | // TODO: optimization 621 | nsecs := int(nsecBin % int64(1e9)) 622 | nsecBin = nsecBin / int64(1e9) 623 | sec := int(nsecBin % 60) 624 | nsecBin = nsecBin / 60 625 | minute := int(nsecBin % 60) 626 | hour := int(nsecBin / 60) 627 | tz := time.FixedZone(fmt.Sprintf("tz_%d", secsTZ), int(secsTZ)) 628 | return time.Date(0, 1, 1, hour, minute, sec, nsecs, tz) 629 | } 630 | 631 | func timetz2bin(dt *time.Time) (int64, int32) { 632 | var nsecBin int64 633 | nsecBin = int64(dt.Hour()*3600 + dt.Minute()*60 + dt.Second()) 634 | nsecBin *= int64(1e9) 635 | nsecBin += int64(dt.Nanosecond()) 636 | _, offsetTZ := dt.Zone() 637 | return nsecBin, int32(offsetTZ) 638 | } 639 | --------------------------------------------------------------------------------