├── integration_tests ├── etc │ ├── catalog │ │ └── tpch.properties │ ├── jvm.config │ ├── node.properties │ └── config.properties ├── Dockerfile └── run.sh ├── .travis.yml ├── CODE_OF_CONDUCT.md ├── go.mod ├── CONTRIBUTING.md ├── presto ├── transaction.go ├── converters.go ├── serial_test.go ├── integration_tls_test.go ├── serial.go ├── transaction_test.go ├── integration_test.go ├── presto_test.go └── presto.go ├── go.sum ├── README.md └── LICENSE /integration_tests/etc/catalog/tpch.properties: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | connector.name=tpch 3 | -------------------------------------------------------------------------------- /integration_tests/etc/jvm.config: -------------------------------------------------------------------------------- 1 | -Xmx512m 2 | -XX:+UseG1GC 3 | -XX:G1HeapRegionSize=32M 4 | -XX:+UseGCOverheadLimit 5 | -XX:+ExplicitGCInvokesConcurrent 6 | -------------------------------------------------------------------------------- /integration_tests/etc/node.properties: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | node.environment=test 3 | node.id=test 4 | node.data-dir=/var/lib/presto/data 5 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - 1.18.x 4 | services: 5 | - docker 6 | install: 7 | - go get -v golang.org/x/tools/cmd/cover 8 | - go get -v gopkg.in/jcmturner/gokrb5.v6/... 
9 | script: 10 | - ./integration_tests/run.sh 11 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /integration_tests/etc/config.properties: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | coordinator=true 3 | node-scheduler.include-coordinator=true 4 | http-server.http.port=8080 5 | discovery-server.enabled=true 6 | discovery.uri=http://localhost:8080 7 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/prestodb/presto-go-client 2 | 3 | go 1.18 4 | 5 | require gopkg.in/jcmturner/gokrb5.v6 v6.1.1 6 | 7 | require ( 8 | github.com/hashicorp/go-uuid v1.0.2 // indirect 9 | github.com/jcmturner/gofork v1.0.0 // indirect 10 | github.com/stretchr/testify v1.5.1 // indirect 11 | golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d // indirect 12 | gopkg.in/jcmturner/aescts.v1 v1.0.1 // indirect 13 | gopkg.in/jcmturner/dnsutils.v1 v1.0.1 // indirect 14 | gopkg.in/jcmturner/goidentity.v3 v3.0.0 // indirect 15 | gopkg.in/jcmturner/rpc.v1 v1.1.0 // indirect 16 | ) 17 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Presto 2 | 3 | ## Contributor License Agreement ("CLA") 4 | 5 | In order to accept 
your pull request, we need you to submit a CLA. You only need to do this once, so if you've done this for one repository in the [prestodb](https://github.com/prestodb) organization, you're good to go. If you are submitting a pull request for the first time, the communitybridge-easycla bot will notify you if you haven't signed, and will provide you with a link. If you are contributing on behalf of a company, you might want to let the person who manages your corporate CLA whitelist know they will be receiving a request from you. 6 | 7 | ## License 8 | 9 | By contributing to Presto, you agree that your contributions will be licensed under the [Apache License Version 2.0 (APLv2)](LICENSE). 10 | -------------------------------------------------------------------------------- /presto/transaction.go: -------------------------------------------------------------------------------- 1 | package presto 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "database/sql/driver" 7 | "fmt" 8 | ) 9 | 10 | type driverTx struct { 11 | conn *Conn 12 | } 13 | 14 | func (t *driverTx) Commit() error { 15 | if t.conn == nil { 16 | return driver.ErrBadConn 17 | } 18 | 19 | ctx := context.Background() 20 | stmt := &driverStmt{conn: t.conn, query: "COMMIT"} 21 | _, err := stmt.QueryContext(ctx, []driver.NamedValue{}) 22 | if err != nil { 23 | return err 24 | } 25 | 26 | t.conn = nil 27 | return nil 28 | } 29 | 30 | func (t *driverTx) Rollback() error { 31 | if t.conn == nil { 32 | return driver.ErrBadConn 33 | } 34 | 35 | ctx := context.Background() 36 | stmt := &driverStmt{conn: t.conn, query: "ROLLBACK"} 37 | _, err := stmt.QueryContext(ctx, []driver.NamedValue{}) 38 | if err != nil { 39 | return err 40 | } 41 | 42 | t.conn = nil 43 | return nil 44 | } 45 | 46 | func verifyIsolationLevel(level sql.IsolationLevel) error { 47 | switch level { 48 | case sql.LevelRepeatableRead, sql.LevelReadCommitted, sql.LevelReadUncommitted, sql.LevelSerializable: 49 | return nil 50 | default: 51 | return 
fmt.Errorf("presto: unsupported isolation level: %v", level) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /integration_tests/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM openjdk:8-jre 2 | EXPOSE 8080 3 | 4 | MAINTAINER Greg Leclercq "ggreg@fb.com" 5 | ARG PRESTO_VERSION=0.167 6 | ENV PRESTO_PKG presto-server-$PRESTO_VERSION.tar.gz 7 | ENV PRESTO_PKG_URL https://repo1.maven.org/maven2/com/facebook/presto/presto-server/$PRESTO_VERSION/$PRESTO_PKG 8 | 9 | ENV PRESTO_CLI_JAR_URL https://repo1.maven.org/maven2/com/facebook/presto/presto-cli/$PRESTO_VERSION/presto-cli-$PRESTO_VERSION-executable.jar 10 | 11 | 12 | # Install python to run the launcher script 13 | RUN apt-get update 14 | RUN apt-get install -y python less 15 | 16 | # Download Presto package 17 | # Use curl rather ADD to leverage RUN caching 18 | # Let curl show progress bar to prevent Travis from thinking the job is stalled 19 | RUN curl -o /$PRESTO_PKG $PRESTO_PKG_URL 20 | RUN tar -zxf /$PRESTO_PKG 21 | 22 | # Create directory for Presto data 23 | RUN mkdir -p /var/lib/presto/data 24 | 25 | # Add Presto configuration 26 | WORKDIR /presto-server-$PRESTO_VERSION 27 | RUN mkdir etc 28 | ADD etc/jvm.config etc/ 29 | ADD etc/config.properties etc/ 30 | ADD etc/node.properties etc/ 31 | ADD etc/catalog etc/catalog 32 | 33 | # Download Presto cli 34 | RUN mkdir -p bin 35 | RUN curl -o bin/presto-cli $PRESTO_CLI_JAR_URL 36 | RUN chmod +x bin/presto-cli 37 | 38 | CMD bin/launcher.py run 39 | -------------------------------------------------------------------------------- /integration_tests/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | LOCAL_PORT=8080 5 | IMAGE_NAME=presto/test_server 6 | 7 | cd "$( dirname "${BASH_SOURCE[0]}" )" 8 | 9 | function test_container() { 10 | echo `docker ps | grep $IMAGE_NAME | cut -d\ -f1` 11 | } 12 | 13 | function test_cleanup() { 14 | local id=`test_container` 15 | [ -n "$id" ] && docker rm -f $id 16 | #docker rmi $IMAGE_NAME 17 | } 18 | 19 | trap test_cleanup EXIT 20 | 21 | function test_build() { 22 | local image=`docker images | grep $IMAGE_NAME` 23 | [ -z "$image" ] && docker build -t $IMAGE_NAME . 24 | } 25 | 26 | function test_query() { 27 | docker exec -t -i `test_container` bin/presto-cli --server localhost:${LOCAL_PORT} --execute "$*" 28 | } 29 | 30 | test_build 31 | docker run -p ${LOCAL_PORT}:${LOCAL_PORT} --rm -d $IMAGE_NAME 32 | 33 | attempts=10 34 | while [ $attempts -gt 0 ] 35 | do 36 | attempts=`expr $attempts - 1` 37 | ready=`test_query "SHOW SESSION" | grep task_writer_count` 38 | [ ! -z "$ready" ] && break 39 | echo "waiting for presto..." 40 | sleep 2 41 | done 42 | 43 | if [ $attempts -eq 0 ] 44 | then 45 | echo "timed out waiting for presto" 46 | exit 1 47 | fi 48 | 49 | PKG=../presto 50 | DSN=http://test@localhost:${LOCAL_PORT} 51 | go test -v -cover -coverprofile=coverage.out $PKG -presto_server_dsn=$DSN $* 52 | -------------------------------------------------------------------------------- /presto/converters.go: -------------------------------------------------------------------------------- 1 | package presto 2 | 3 | import ( 4 | "database/sql/driver" 5 | "encoding/json" 6 | "fmt" 7 | ) 8 | 9 | type rowConverter struct { 10 | fields []string 11 | converters []driver.ValueConverter 12 | } 13 | 14 | func (c *rowConverter) typeName() string { 15 | return "row" 16 | } 17 | 18 | // ConvertValue implements driver.ValueConverter interface to provide 19 | // conversion for row column types. The resulting value will be a 20 | // map[string]any. 
21 | func (c *rowConverter) ConvertValue(v any) (driver.Value, error) { 22 | if v == nil { 23 | return nil, nil 24 | } 25 | vs, ok := v.([]any) 26 | if !ok { 27 | return nil, fmt.Errorf("presto: row converter needs []any and received %T", v) 28 | } 29 | if len(vs) != len(c.fields) { 30 | return nil, fmt.Errorf("presto: row converter has wrong number of elements: %d, expected: %d", len(vs), len(c.fields)) 31 | } 32 | res := make(map[string]any) 33 | for i, f := range c.fields { 34 | if vs[i] == nil { 35 | continue 36 | } 37 | 38 | sub, err := c.converters[i].ConvertValue(vs[i]) 39 | if err != nil { 40 | return nil, fmt.Errorf("presto: converting sub property of row: %w", err) 41 | } 42 | if sub != nil { 43 | res[f] = sub 44 | } 45 | } 46 | return res, nil 47 | } 48 | 49 | func newComplexConverter(ts typeSignature) (driver.ValueConverter, error) { 50 | if ts.RawType != "row" { 51 | return newTypeConverter(ts.RawType), nil 52 | } 53 | 54 | var c rowConverter 55 | // Field names. 56 | for _, fd := range ts.LiteralArguments { 57 | var fn string 58 | if err := json.Unmarshal(fd, &fn); err != nil { 59 | return nil, fmt.Errorf("presto: parsing field name for row converter: %w", err) 60 | } 61 | c.fields = append(c.fields, fn) 62 | } 63 | // Field converters. 
64 | for _, tas := range ts.TypeArguments { 65 | var fts typeSignature 66 | if err := json.Unmarshal(tas, &fts); err != nil { 67 | return nil, fmt.Errorf("presto: parsing field type for row converter: %w", err) 68 | } 69 | conv, err := newComplexConverter(fts) 70 | if err != nil { 71 | return nil, fmt.Errorf("presto: creating nested converted for row converter: %w", err) 72 | } 73 | c.converters = append(c.converters, conv) 74 | } 75 | return &c, nil 76 | } 77 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= 4 | github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= 5 | github.com/jcmturner/gofork v1.0.0 h1:J7uCkflzTEhUZ64xqKnkDxq3kzc96ajM1Gli5ktUem8= 6 | github.com/jcmturner/gofork v1.0.0/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= 7 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 8 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 9 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 10 | github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= 11 | github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= 12 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 13 | golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d h1:1ZiEyfaQIg3Qh0EoqpwAakHVhecoE5wlSg5GjnafJGw= 14 | golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 15 | golang.org/x/net 
v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 16 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 17 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 18 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 19 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 20 | gopkg.in/jcmturner/aescts.v1 v1.0.1 h1:cVVZBK2b1zY26haWB4vbBiZrfFQnfbTVrE3xZq6hrEw= 21 | gopkg.in/jcmturner/aescts.v1 v1.0.1/go.mod h1:nsR8qBOg+OucoIW+WMhB3GspUQXq9XorLnQb9XtvcOo= 22 | gopkg.in/jcmturner/dnsutils.v1 v1.0.1 h1:cIuC1OLRGZrld+16ZJvvZxVJeKPsvd5eUIvxfoN5hSM= 23 | gopkg.in/jcmturner/dnsutils.v1 v1.0.1/go.mod h1:m3v+5svpVOhtFAP/wSz+yzh4Mc0Fg7eRhxkJMWSIz9Q= 24 | gopkg.in/jcmturner/goidentity.v3 v3.0.0 h1:1duIyWiTaYvVx3YX2CYtpJbUFd7/UuPYCfgXtQ3VTbI= 25 | gopkg.in/jcmturner/goidentity.v3 v3.0.0/go.mod h1:oG2kH0IvSYNIu80dVAyu/yoefjq1mNfM5bm88whjWx4= 26 | gopkg.in/jcmturner/gokrb5.v6 v6.1.1 h1:n0KFjpbuM5pFMN38/Ay+Br3l91netGSVqHPHEXeWUqk= 27 | gopkg.in/jcmturner/gokrb5.v6 v6.1.1/go.mod h1:NFjHNLrHQiruory+EmqDXCGv6CrjkeYeA+bR9mIfNFk= 28 | gopkg.in/jcmturner/rpc.v1 v1.1.0 h1:QHIUxTX1ISuAv9dD2wJ9HWQVuWDX/Zc0PfeC2tjc4rU= 29 | gopkg.in/jcmturner/rpc.v1 v1.1.0/go.mod h1:YIdkC4XfD6GXbzje11McwsDuOlZQSb9W4vfLvuNnlv8= 30 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 31 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 32 | -------------------------------------------------------------------------------- /presto/serial_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package presto 16 | 17 | import "testing" 18 | 19 | func TestSerial(t *testing.T) { 20 | scenarios := []struct { 21 | name string 22 | value interface{} 23 | expectedError bool 24 | expectedSerial string 25 | }{ 26 | { 27 | name: "basic string", 28 | value: "hello world", 29 | expectedSerial: `'hello world'`, 30 | }, 31 | { 32 | name: "single quoted string", 33 | value: "hello world's", 34 | expectedSerial: `'hello world''s'`, 35 | }, 36 | { 37 | name: "double quoted string", 38 | value: `hello "world"`, 39 | expectedSerial: `'hello "world"'`, 40 | }, 41 | { 42 | name: "int8", 43 | value: int8(100), 44 | expectedSerial: "100", 45 | }, 46 | { 47 | name: "int16", 48 | value: int16(100), 49 | expectedSerial: "100", 50 | }, 51 | { 52 | name: "int32", 53 | value: int32(100), 54 | expectedSerial: "100", 55 | }, 56 | { 57 | name: "int", 58 | value: int(100), 59 | expectedSerial: "100", 60 | }, 61 | { 62 | name: "int64", 63 | value: int64(100), 64 | expectedSerial: "100", 65 | }, 66 | { 67 | name: "uint8", 68 | value: uint8(100), 69 | expectedError: true, 70 | }, 71 | { 72 | name: "uint16", 73 | value: uint16(100), 74 | expectedSerial: "100", 75 | }, 76 | { 77 | name: "uint32", 78 | value: uint32(100), 79 | expectedSerial: "100", 80 | }, 81 | { 82 | name: "uint", 83 | value: uint(100), 84 | expectedSerial: "100", 85 | }, 86 | { 87 | name: "uint64", 
88 | value: uint64(100), 89 | expectedSerial: "100", 90 | }, 91 | { 92 | name: "byte", 93 | value: byte('a'), 94 | expectedError: true, 95 | }, 96 | { 97 | name: "valid Numeric", 98 | value: Numeric("10"), 99 | expectedSerial: "10", 100 | }, 101 | { 102 | name: "invalid Numeric", 103 | value: Numeric("not-a-number"), 104 | expectedError: true, 105 | }, 106 | { 107 | name: "bool true", 108 | value: true, 109 | expectedSerial: "true", 110 | }, 111 | { 112 | name: "bool false", 113 | value: false, 114 | expectedSerial: "false", 115 | }, 116 | { 117 | name: "nil", 118 | value: nil, 119 | expectedError: true, 120 | }, 121 | { 122 | name: "slice typed nil", 123 | value: []interface{}(nil), 124 | expectedError: true, 125 | }, 126 | { 127 | name: "valid slice", 128 | value: []interface{}{1, 2}, 129 | expectedSerial: "ARRAY[1, 2]", 130 | }, 131 | { 132 | name: "valid empty", 133 | value: []interface{}{}, 134 | expectedSerial: "ARRAY[]", 135 | }, 136 | { 137 | name: "invalid slice contents", 138 | value: []interface{}{1, byte('a')}, 139 | expectedError: true, 140 | }, 141 | } 142 | 143 | for i := range scenarios { 144 | scenario := scenarios[i] 145 | 146 | t.Run(scenario.name, func(t *testing.T) { 147 | s, err := Serial(scenario.value) 148 | if err != nil { 149 | if scenario.expectedError { 150 | return 151 | } 152 | t.Fatal(err) 153 | } 154 | 155 | if scenario.expectedError { 156 | t.Fatal("missing an expected error") 157 | } 158 | 159 | if scenario.expectedSerial != s { 160 | t.Fatalf("mismatched serial, got %q expected %q", s, scenario.expectedSerial) 161 | } 162 | }) 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /presto/integration_tls_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // +build go1.9 16 | 17 | package presto 18 | 19 | import ( 20 | "bytes" 21 | "crypto/tls" 22 | "database/sql" 23 | "io" 24 | "io/ioutil" 25 | "net/http" 26 | "net/http/httptest" 27 | "net/url" 28 | "strings" 29 | "sync" 30 | "testing" 31 | ) 32 | 33 | func TestIntegrationTLS(t *testing.T) { 34 | proxyServer := newTLSReverseProxy(t) 35 | defer proxyServer.Close() 36 | RegisterCustomClient("test_tls", proxyServer.Client()) 37 | defer DeregisterCustomClient("test_tls") 38 | dsn := proxyServer.URL + "?custom_client=test_tls" 39 | testSimpleQuery(t, dsn) 40 | } 41 | 42 | func TestIntegrationInsecureTLS(t *testing.T) { 43 | proxyServer := newTLSReverseProxy(t) 44 | defer proxyServer.Close() 45 | RegisterCustomClient("test_insecure_tls", &http.Client{ 46 | Transport: &http.Transport{ 47 | TLSClientConfig: &tls.Config{ 48 | InsecureSkipVerify: true, 49 | }, 50 | }, 51 | }) 52 | defer DeregisterCustomClient("test_insecure_tls") 53 | dsn := proxyServer.URL + "?custom_client=test_insecure_tls" 54 | testSimpleQuery(t, dsn) 55 | } 56 | 57 | func testSimpleQuery(t *testing.T, dsn string) { 58 | db, err := sql.Open("presto", dsn) 59 | if err != nil { 60 | t.Fatal(err) 61 | } 62 | defer db.Close() 63 | row := db.QueryRow("SELECT 1") 64 | var count int 65 | if err = row.Scan(&count); err != nil { 66 | t.Fatal(err) 67 | } 68 | if count != 1 { 69 | 
t.Fatal("unexpected count=", count) 70 | } 71 | } 72 | 73 | // hax0r reverse tls proxy for integration tests 74 | 75 | // newTLSReverseProxy creates a TLS integration test server. 76 | func newTLSReverseProxy(t *testing.T) *httptest.Server { 77 | dsn := integrationServerDSN(t) 78 | prestoURL, _ := url.Parse(dsn) 79 | cproxyURL := make(chan string, 1) 80 | handler := newReverseProxyHandler(prestoURL, cproxyURL) 81 | srv := httptest.NewTLSServer(http.HandlerFunc(handler)) 82 | cproxyURL <- srv.URL 83 | close(cproxyURL) 84 | proxyURL, _ := url.Parse(srv.URL) 85 | proxyURL.User = prestoURL.User 86 | proxyURL.Path = prestoURL.Path 87 | proxyURL.RawPath = prestoURL.RawPath 88 | proxyURL.RawQuery = prestoURL.RawQuery 89 | srv.URL = proxyURL.String() 90 | return srv 91 | } 92 | 93 | // newReverseProxyHandler creates an http handler that proxies requests to the given prestoURL, and replaces URLs in responses with the first value sent to the cproxyURL channel. 94 | func newReverseProxyHandler(prestoURL *url.URL, cproxyURL chan string) http.HandlerFunc { 95 | baseURL := []byte(prestoURL.Scheme + "://" + prestoURL.Host) 96 | var proxyURL []byte 97 | var onceProxyURL sync.Once 98 | return func(w http.ResponseWriter, r *http.Request) { 99 | onceProxyURL.Do(func() { 100 | proxyURL = []byte(<-cproxyURL) 101 | }) 102 | target := *prestoURL 103 | target.User = nil 104 | target.Path = r.URL.Path 105 | target.RawPath = r.URL.RawPath 106 | target.RawQuery = r.URL.RawQuery 107 | req, err := http.NewRequest(r.Method, target.String(), r.Body) 108 | if err != nil { 109 | http.Error(w, err.Error(), http.StatusServiceUnavailable) 110 | return 111 | } 112 | for k, v := range r.Header { 113 | if strings.HasPrefix(k, "X-") { 114 | req.Header[k] = v 115 | } 116 | } 117 | client := *http.DefaultClient 118 | client.Timeout = *integrationServerQueryTimeout 119 | resp, err := client.Do(req) 120 | if err != nil { 121 | http.Error(w, err.Error(), http.StatusServiceUnavailable) 122 | return 123 | } 124 
| defer resp.Body.Close() 125 | w.WriteHeader(resp.StatusCode) 126 | pr, pw := io.Pipe() 127 | go func() { 128 | b, err := ioutil.ReadAll(resp.Body) 129 | if err != nil { 130 | pw.CloseWithError(err) 131 | return 132 | } 133 | b = bytes.Replace(b, baseURL, proxyURL, -1) 134 | pw.Write(b) 135 | pw.Close() 136 | }() 137 | io.Copy(w, pr) 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /presto/serial.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package presto 16 | 17 | import ( 18 | "encoding/json" 19 | "fmt" 20 | "reflect" 21 | "strconv" 22 | "strings" 23 | "time" 24 | ) 25 | 26 | type UnsupportedArgError struct { 27 | t string 28 | } 29 | 30 | func (e UnsupportedArgError) Error() string { 31 | return fmt.Sprintf("presto: unsupported arg type: %s", e.t) 32 | } 33 | 34 | // Numeric is a string representation of a number, such as "10", "5.5" or in scientific form 35 | // If another string format is used it will error to serialise 36 | type Numeric string 37 | 38 | // Serial converts any supported value to its equivalent string for as a presto parameter 39 | // See https://prestodb.io/docs/current/language/types.html 40 | func Serial(v interface{}) (string, error) { 41 | switch x := v.(type) { 42 | case nil: 43 | return "", UnsupportedArgError{""} 44 | 45 | // numbers convertible to int 46 | case int8: 47 | return strconv.Itoa(int(x)), nil 48 | case int16: 49 | return strconv.Itoa(int(x)), nil 50 | case int32: 51 | return strconv.Itoa(int(x)), nil 52 | case int: 53 | return strconv.Itoa(x), nil 54 | case uint16: 55 | return strconv.Itoa(int(x)), nil 56 | 57 | case int64: 58 | return strconv.FormatInt(x, 10), nil 59 | 60 | case uint32: 61 | return strconv.FormatUint(uint64(x), 10), nil 62 | case uint: 63 | return strconv.FormatUint(uint64(x), 10), nil 64 | case uint64: 65 | return strconv.FormatUint(x, 10), nil 66 | 67 | // float32, float64 not supported because digit precision will easily cause large problems 68 | case float32: 69 | return "", UnsupportedArgError{"float32"} 70 | case float64: 71 | return "", UnsupportedArgError{"float64"} 72 | 73 | case Numeric: 74 | if _, err := strconv.ParseFloat(string(x), 64); err != nil { 75 | return "", err 76 | } 77 | return string(x), nil 78 | 79 | // note byte and uint are not supported, this is because byte is an alias for uint8 80 | // if you were to use uint8 (as a number) it could be interpreted as a byte, so it is unsupported 81 | // use string 
instead of byte and any other uint/int type for uint8 82 | case byte: 83 | return "", UnsupportedArgError{"byte/uint8"} 84 | 85 | case bool: 86 | return strconv.FormatBool(x), nil 87 | 88 | case string: 89 | return "'" + strings.Replace(x, "'", "''", -1) + "'", nil 90 | 91 | // TODO - []byte should probably be matched to 'VARBINARY' in presto 92 | case []byte: 93 | return "", UnsupportedArgError{"[]byte"} 94 | 95 | // time.Time and time.Duration not supported as time and date take several different formats in presto 96 | case time.Time: 97 | return "", UnsupportedArgError{"time.Time"} 98 | case time.Duration: 99 | return "", UnsupportedArgError{"time.Duration"} 100 | 101 | // TODO - json.RawMesssage should probably be matched to 'JSON' in presto 102 | case json.RawMessage: 103 | return "", UnsupportedArgError{"json.RawMessage"} 104 | } 105 | 106 | if reflect.TypeOf(v).Kind() == reflect.Slice { 107 | x := reflect.ValueOf(v) 108 | if x.IsNil() { 109 | return "", UnsupportedArgError{"[]"} 110 | } 111 | 112 | slice := make([]interface{}, x.Len()) 113 | 114 | for i := 0; i < x.Len(); i++ { 115 | slice[i] = x.Index(i).Interface() 116 | } 117 | 118 | return serialSlice(slice) 119 | } 120 | 121 | if reflect.TypeOf(v).Kind() == reflect.Map { 122 | // are presto MAPs indifferent to order? Golang maps are, if presto aren't then the two types can't be compatible 123 | return "", UnsupportedArgError{"map"} 124 | } 125 | 126 | // TODO - consider the remaining types in https://prestodb.io/docs/current/language/types.html (Row, IP, ...) 
127 | 128 | return "", UnsupportedArgError{fmt.Sprintf("%T", v)} 129 | } 130 | 131 | func serialSlice(v []interface{}) (string, error) { 132 | ss := make([]string, len(v)) 133 | 134 | for i, x := range v { 135 | s, err := Serial(x) 136 | if err != nil { 137 | return "", err 138 | } 139 | ss[i] = s 140 | } 141 | 142 | return "ARRAY[" + strings.Join(ss, ", ") + "]", nil 143 | } 144 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Presto client 2 | 3 | A [Presto](https://prestodb.io) client for the [Go](https://golang.org) programming language. 4 | 5 | [![Build Status](https://secure.travis-ci.org/prestodb/presto-go-client.png)](http://travis-ci.org/prestodb/presto-go-client) 6 | [![GoDoc](https://godoc.org/github.com/prestodb/presto-go-client?status.svg)](https://godoc.org/github.com/prestodb/presto-go-client) 7 | 8 | ## Features 9 | 10 | * Native Go implementation 11 | * Connections over HTTP or HTTPS 12 | * HTTP Basic and Kerberos authentication 13 | * Per-query user information for access control 14 | * Support custom HTTP client (tunable conn pools, timeouts, TLS) 15 | * Supports conversion from Presto to native Go data types 16 | * `string`, `sql.NullString` 17 | * `int64`, `presto.NullInt64` 18 | * `float64`, `presto.NullFloat64` 19 | * `map`, `presto.NullMap` 20 | * `time.Time`, `presto.NullTime` 21 | * Up to 3-dimensional arrays to Go slices, of any supported type 22 | 23 | ## Requirements 24 | 25 | * Go 1.18 or newer 26 | * Presto 0.16x or newer 27 | 28 | ## Installation 29 | 30 | You need a working environment with Go installed and $GOPATH set. 31 | 32 | Download and install presto database/sql driver: 33 | 34 | ```bash 35 | go get github.com/prestodb/presto-go-client/presto 36 | ``` 37 | 38 | Make sure you have Git installed and in your $PATH. 

## Usage

This Presto client is an implementation of Go's `database/sql/driver` interface. To use it, import the package for its side effects and then use the standard [`database/sql`](https://golang.org/pkg/database/sql/) API.

Only read operations are supported, such as SHOW and SELECT.

Use `presto` as the `driverName` and a valid [DSN](#dsn-data-source-name) as the `dataSourceName`.

Example:

```go
import "database/sql"
import _ "github.com/prestodb/presto-go-client/presto"

dsn := "http://user@localhost:8080?catalog=default&schema=test"
db, err := sql.Open("presto", dsn)
```

### Authentication

HTTP Basic, Kerberos, and JWT authentication are supported.

#### HTTP Basic authentication

If the DSN contains a password, the client enables HTTP Basic authentication by setting the `Authorization` header in every request to presto.

HTTP Basic authentication **is only supported on encrypted connections over HTTPS**.

#### Kerberos authentication

This driver supports Kerberos authentication by setting up the Kerberos fields in the [Config](https://godoc.org/github.com/prestodb/presto-go-client/presto#Config) struct.

Please refer to the [Coordinator Kerberos Authentication](https://prestodb.io/docs/current/security/server.html) documentation for server-side configuration.

#### JWT authentication

This driver supports JWT authentication. Set the `AccessToken` field of the [Config](https://godoc.org/github.com/prestodb/presto-go-client/presto#Config) struct, or the corresponding DSN query parameter, to the JWT you want to use. The token is then sent as a bearer token with every HTTP request.

This authentication method has lower precedence than HTTP Basic authentication.
79 | 80 | #### System access control and per-query user information 81 | 82 | It's possible to pass user information to presto, different from the principal used to authenticate to the coordinator. See the [System Access Control](https://prestodb.io/docs/current/develop/system-access-control.html) documentation for details. 83 | 84 | In order to pass user information in queries to presto, you have to add a [NamedArg](https://godoc.org/database/sql#NamedArg) to the query parameters where the key is X-Presto-User. This parameter is used by the driver to inform presto about the user executing the query regardless of the authentication method for the actual connection, and its value is NOT passed to the query. 85 | 86 | Example: 87 | 88 | ```go 89 | db.Query("SELECT * FROM foobar WHERE id=?", 1, sql.Named("X-Presto-User", string("Alice"))) 90 | ``` 91 | 92 | The position of the X-Presto-User NamedArg is irrelevant and does not affect the query in any way. 93 | 94 | ### DSN (Data Source Name) 95 | 96 | The Data Source Name is a URL with a mandatory username, and optional query string parameters that are supported by this driver, in the following format: 97 | 98 | ``` 99 | http[s]://user[:pass]@host[:port][?parameters] 100 | ``` 101 | 102 | The easiest way to build your DSN is by using the [Config.FormatDSN](https://godoc.org/github.com/prestodb/presto-go-client/presto#Config.FormatDSN) helper function. 103 | 104 | The driver supports both HTTP and HTTPS. If you use HTTPS it's recommended that you also provide a custom `http.Client` that can validate (or skip) the security checks of the server certificate, and/or to configure TLS client authentication. 105 | 106 | #### Parameters 107 | 108 | *Parameters are case-sensitive* 109 | 110 | Refer to the [Presto Concepts](https://prestodb.io/docs/current/overview/concepts.html) documentation for more information. 
111 | 112 | ##### `source` 113 | 114 | ``` 115 | Type: string 116 | Valid values: string describing the source of the connection to presto 117 | Default: empty 118 | ``` 119 | 120 | The `source` parameter is optional, but if used, can help presto admins troubleshoot queries and trace them back to the original client. 121 | 122 | ##### `catalog` 123 | 124 | ``` 125 | Type: string 126 | Valid values: the name of a catalog configured in the presto server 127 | Default: empty 128 | ``` 129 | 130 | The `catalog` parameter defines the presto catalog where schemas exist to organize tables. 131 | 132 | ##### `schema` 133 | 134 | ``` 135 | Type: string 136 | Valid values: the name of an existing schema in the catalog 137 | Default: empty 138 | ``` 139 | 140 | The `schema` parameter defines the presto schema where tables exist. This is also known as namespace in some environments. 141 | 142 | ##### `session_properties` 143 | 144 | ``` 145 | Type: string 146 | Valid values: comma-separated list of key=value session properties 147 | Default: empty 148 | ``` 149 | 150 | The `session_properties` parameter must contain valid parameters accepted by the presto server. Run `SHOW SESSION` in presto to get the current list. 151 | 152 | ##### `custom_client` 153 | 154 | ``` 155 | Type: string 156 | Valid values: the name of a client previously registered to the driver 157 | Default: empty (defaults to http.DefaultClient) 158 | ``` 159 | 160 | The `custom_client` parameter allows the use of custom `http.Client` for the communication with presto. 
161 | 162 | Register your custom client in the driver, then refer to it by name in the DSN, on the call to `sql.Open`: 163 | 164 | ```go 165 | foobarClient := &http.Client{ 166 | Transport: &http.Transport{ 167 | Proxy: http.ProxyFromEnvironment, 168 | DialContext: (&net.Dialer{ 169 | Timeout: 30 * time.Second, 170 | KeepAlive: 30 * time.Second, 171 | DualStack: true, 172 | }).DialContext, 173 | MaxIdleConns: 100, 174 | IdleConnTimeout: 90 * time.Second, 175 | TLSHandshakeTimeout: 10 * time.Second, 176 | ExpectContinueTimeout: 1 * time.Second, 177 | TLSClientConfig: &tls.Config{ 178 | // your config here... 179 | }, 180 | }, 181 | } 182 | presto.RegisterCustomClient("foobar", foobarClient) 183 | db, err := sql.Open("presto", "https://user@localhost:8080?custom_client=foobar") 184 | ``` 185 | 186 | #### Examples 187 | 188 | ``` 189 | http://user@localhost:8080?source=hello&catalog=default&schema=foobar 190 | ``` 191 | 192 | ``` 193 | https://user@localhost:8443?session_properties=query_max_run_time=10m,query_priority=2 194 | ``` 195 | 196 | ## License 197 | 198 | As described in the [LICENSE](./LICENSE) file. 
199 | -------------------------------------------------------------------------------- /presto/transaction_test.go: -------------------------------------------------------------------------------- 1 | package presto 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "net/http/httptest" 11 | "testing" 12 | ) 13 | 14 | type queryHandler struct { // one scripted request, matched on url and body 15 | url string 16 | body string 17 | handler func(w http.ResponseWriter, r *http.Request) (string, error) // returns the nextUri to reply with, or an error for a 400 reply 18 | matched bool // set once a request has matched this entry 19 | } 20 | 21 | type testServer struct { // fake coordinator that answers only the scripted expectedQueries 22 | expectedQueries []*queryHandler 23 | } 24 | 25 | func (srv *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // replies via the first matching handler; 400 when nothing matches or the handler fails 26 | bodyBytes, err := io.ReadAll(r.Body) 27 | if err != nil { 28 | w.WriteHeader(http.StatusBadRequest) 29 | json.NewEncoder(w).Encode(&stmtResponse{ 30 | Error: stmtError{ 31 | ErrorName: "BAD QUERY", 32 | }, 33 | }) 34 | return 35 | } 36 | 37 | var nextURI string 38 | body := string(bodyBytes) 39 | err = fmt.Errorf("unexpected query %s", body) 40 | for _, query := range srv.expectedQueries { 41 | if query.url == r.RequestURI && query.body == body { 42 | query.matched = true 43 | nextURI, err = query.handler(w, r) 44 | break 45 | } 46 | } 47 | 48 | if err != nil { 49 | w.WriteHeader(http.StatusBadRequest) 50 | json.NewEncoder(w).Encode(&stmtResponse{ 51 | Error: stmtError{ 52 | ErrorName: err.Error(), 53 | }, 54 | }) 55 | return 56 | } 57 | 58 | w.WriteHeader(http.StatusOK) 59 | json.NewEncoder(w).Encode(&stmtResponse{ 60 | ID: "id", 61 | NextURI: nextURI, 62 | }) 63 | } 64 | 65 | func (srv *testServer) verifyExpectedQueries() error { // reports the first scripted query that was never requested 66 | for _, query := range srv.expectedQueries { 67 | if !query.matched { 68 | return fmt.Errorf("expected query not matched.
url: %s, body: %s", query.url, query.body) // arguments in the same order as the verbs 69 | } 70 | } 71 | 72 | return nil 73 | } 74 | 75 | func checkRequestTransactionHeader(r *http.Request, id string) error { // errors unless the transaction header on r equals id 76 | headerValue := r.Header.Get(prestoTransactionHeader) 77 | if headerValue == id { 78 | return nil 79 | } 80 | 81 | return fmt.Errorf("unexpected transaction id in header. got: %s, expected: %s", headerValue, id) 82 | } 83 | 84 | func TestTransactionCommit(t *testing.T) { // once COMMIT clears the transaction, later queries carry no transaction header 85 | server := &testServer{} 86 | ts := httptest.NewServer(server) 87 | defer ts.Close() 88 | 89 | transactionID := "123" 90 | server.expectedQueries = []*queryHandler{ 91 | { 92 | url: "/v1/statement", 93 | body: "START TRANSACTION READ ONLY, ISOLATION LEVEL Read Uncommitted", 94 | handler: func(w http.ResponseWriter, r *http.Request) (string, error) { 95 | if err := checkRequestTransactionHeader(r, "NONE"); err != nil { 96 | return "", err 97 | } 98 | 99 | return fmt.Sprintf("%s/%s", ts.URL, "start"), nil 100 | }, 101 | }, 102 | { 103 | url: "/start", 104 | body: "", 105 | handler: func(w http.ResponseWriter, r *http.Request) (string, error) { 106 | if err := checkRequestTransactionHeader(r, "NONE"); err != nil { 107 | return "", err 108 | } 109 | 110 | w.Header().Set(prestoStartedTransactionHeader, transactionID) 111 | return "", nil 112 | }, 113 | }, 114 | { 115 | url: "/v1/statement", 116 | body: "SELECT * FROM TransactionTable", 117 | handler: func(w http.ResponseWriter, r *http.Request) (string, error) { 118 | if err := checkRequestTransactionHeader(r, transactionID); err != nil { 119 | return "", err 120 | } 121 | 122 | return fmt.Sprintf("%s/%s", ts.URL, "select_transaction"), nil 123 | }, 124 | }, 125 | { 126 | url: "/select_transaction", 127 | body: "", 128 | handler: func(w http.ResponseWriter, r *http.Request) (string, error) { 129 | if err := 
checkRequestTransactionHeader(r, transactionID); err != nil { 130 | return "", err 131 | } 132 | 133 | return "", nil 134 | }, 135 | }, 136 | { 137 | url: "/v1/statement", 138 |
body: "COMMIT", 139 | handler: func(w http.ResponseWriter, r *http.Request) (string, error) { 140 | if err := checkRequestTransactionHeader(r, transactionID); err != nil { 141 | return "", err 142 | } 143 | 144 | return fmt.Sprintf("%s/%s", ts.URL, "commit"), nil 145 | }, 146 | }, 147 | { 148 | url: "/commit", 149 | body: "", 150 | handler: func(w http.ResponseWriter, r *http.Request) (string, error) { 151 | if err := checkRequestTransactionHeader(r, transactionID); err != nil { 152 | return "", err 153 | } 154 | 155 | w.Header().Set(prestoClearTransactionHeader, "true") 156 | return "", nil 157 | }, 158 | }, 159 | { 160 | url: "/v1/statement", 161 | body: "SELECT * FROM NoTransactionTable", 162 | handler: func(w http.ResponseWriter, r *http.Request) (string, error) { 163 | if err := checkRequestTransactionHeader(r, ""); err != nil { 164 | return "", err 165 | } 166 | 167 | return fmt.Sprintf("%s/%s", ts.URL, "select_no_transaction"), nil 168 | }, 169 | }, 170 | { 171 | url: "/select_no_transaction", 172 | body: "", 173 | handler: func(w http.ResponseWriter, r *http.Request) (string, error) { 174 | if err := checkRequestTransactionHeader(r, ""); err != nil { 175 | return "", err 176 | } 177 | 178 | return "", nil 179 | }, 180 | }, 181 | } 182 | 183 | db, err := sql.Open("presto", ts.URL) 184 | if err != nil { 185 | t.Fatal(err) 186 | } 187 | defer db.Close() 188 | 189 | tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true, Isolation: sql.LevelReadUncommitted}) 190 | if err != nil { 191 | t.Fatal(err.Error()) 192 | } 193 | 194 | _, err = tx.Query("SELECT * FROM TransactionTable") 195 | if err != nil { 196 | t.Fatal(err.Error()) 197 | } 198 | 199 | err = tx.Commit() 200 | if err != nil { 201 | t.Fatal(err.Error()) 202 | } 203 | 204 | _, err = db.Query("SELECT * FROM NoTransactionTable") 205 | if err != nil { 206 | t.Fatal(err.Error()) 207 | } 208 | 209 | err = server.verifyExpectedQueries() 210 | if err != nil { 211 | 
t.Fatal(err.Error()) 212 | }
213 | } 214 | 215 | func TestTransactionRollback(t *testing.T) { // once ROLLBACK clears the transaction, later queries carry no transaction header 216 | server := &testServer{} 217 | ts := httptest.NewServer(server) 218 | defer ts.Close() 219 | 220 | transactionID := "123" 221 | server.expectedQueries = []*queryHandler{ 222 | { 223 | url: "/v1/statement", 224 | body: "START TRANSACTION READ ONLY, ISOLATION LEVEL Read Uncommitted", 225 | handler: func(w http.ResponseWriter, r *http.Request) (string, error) { 226 | if err := checkRequestTransactionHeader(r, "NONE"); err != nil { 227 | return "", err 228 | } 229 | 230 | return fmt.Sprintf("%s/%s", ts.URL, "start"), nil 231 | }, 232 | }, 233 | { 234 | url: "/start", 235 | body: "", 236 | handler: func(w http.ResponseWriter, r *http.Request) (string, error) { 237 | if err := checkRequestTransactionHeader(r, "NONE"); err != nil { 238 | return "", err 239 | } 240 | 241 | w.Header().Set(prestoStartedTransactionHeader, transactionID) 242 | return "", nil 243 | }, 244 | }, 245 | { 246 | url: "/v1/statement", 247 | body: "SELECT * FROM TransactionTable", 248 | handler: func(w http.ResponseWriter, r *http.Request) (string, error) { 249 | if err := checkRequestTransactionHeader(r, transactionID); err != nil { 250 | return "", err 251 | } 252 | 253 | return fmt.Sprintf("%s/%s", ts.URL, "select_transaction"), nil 254 | }, 255 | }, 256 | { 257 | url: "/select_transaction", 258 | body: "", 259 | handler: func(w http.ResponseWriter, r *http.Request) (string, error) { 260 | if err := checkRequestTransactionHeader(r, transactionID); err != nil { 261 | return "", err 262 | } 263 | 264 | return "", nil 265 | }, 266 | }, 267 | { 268 | url: "/v1/statement", 269 | body: "ROLLBACK", 270 | handler: func(w http.ResponseWriter, r *http.Request) (string, error) { 271 | if err := checkRequestTransactionHeader(r, transactionID); err != nil { 272 | return "", err 273 | } 274 | 275 | return fmt.Sprintf("%s/%s", ts.URL, "rollback"), nil 276 | }, 277 | }, 
278 | { 279 | url: "/rollback", 280 | body: "", 281 | handler: func(w
http.ResponseWriter, r *http.Request) (string, error) { 282 | if err := checkRequestTransactionHeader(r, transactionID); err != nil { 283 | return "", err 284 | } 285 | 286 | w.Header().Set(prestoClearTransactionHeader, "true") 287 | return "", nil 288 | }, 289 | }, 290 | { 291 | url: "/v1/statement", 292 | body: "SELECT * FROM NoTransactionTable", 293 | handler: func(w http.ResponseWriter, r *http.Request) (string, error) { 294 | if err := checkRequestTransactionHeader(r, ""); err != nil { 295 | return "", err 296 | } 297 | 298 | return fmt.Sprintf("%s/%s", ts.URL, "select_no_transaction"), nil 299 | }, 300 | }, 301 | { 302 | url: "/select_no_transaction", 303 | body: "", 304 | handler: func(w http.ResponseWriter, r *http.Request) (string, error) { 305 | if err := checkRequestTransactionHeader(r, ""); err != nil { 306 | return "", err 307 | } 308 | 309 | return "", nil 310 | }, 311 | }, 312 | } 313 | 314 | db, err := sql.Open("presto", ts.URL) 315 | if err != nil { 316 | t.Fatal(err) 317 | } 318 | defer db.Close() 319 | 320 | tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true, Isolation: sql.LevelReadUncommitted}) 321 | if err != nil { 322 | t.Fatal(err.Error()) 323 | } 324 | 325 | _, err = tx.Query("SELECT * FROM TransactionTable") 326 | if err != nil { 327 | t.Fatal(err.Error()) 328 | } 329 | 330 | err = tx.Rollback() 331 | if err != nil { 332 | t.Fatal(err.Error()) 333 | } 334 | 335 | _, err = db.Query("SELECT * FROM NoTransactionTable") 336 | if err != nil { 337 | t.Fatal(err.Error()) 338 | } 339 | 340 | err = server.verifyExpectedQueries() 341 | if err != nil { 342 | t.Fatal(err.Error()) 343 | } 344 | } 345 | -------------------------------------------------------------------------------- /presto/integration_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package presto 16 | 17 | import ( 18 | "context" 19 | "database/sql" 20 | "errors" 21 | "flag" 22 | "os" 23 | "strings" 24 | "testing" 25 | "time" 26 | ) 27 | 28 | var ( 29 | integrationServerFlag = flag.String( 30 | "presto_server_dsn", 31 | os.Getenv("PRESTO_SERVER_DSN"), 32 | "dsn of the presto server used for integration tests; default disabled", 33 | ) 34 | integrationServerQueryTimeout = flag.Duration( 35 | "presto_query_timeout", 36 | 5*time.Second, 37 | "max duration for presto queries to run before giving up", 38 | ) 39 | ) 40 | 41 | func init() { 42 | flag.Parse() // NOTE(review): since Go 1.13 testing flags are registered after package init, so parsing here may reject -test.* flags; consider moving this into TestMain 43 | DefaultQueryTimeout = *integrationServerQueryTimeout 44 | DefaultCancelQueryTimeout = *integrationServerQueryTimeout 45 | } 46 | 47 | // integrationServerDSN returns the DSN of the integration test server, skipping the test when none is configured. 48 | func integrationServerDSN(t *testing.T) string { 49 | if dsn := *integrationServerFlag; dsn != "" { 50 | return dsn 51 | } 52 | t.Skip() 53 | return "" 54 | } 55 | 56 | // integrationOpen opens a *sql.DB for the integration test server (or for dsn[0] when given), skipping the test when no server is configured.
57 | func integrationOpen(t *testing.T, dsn ...string) *sql.DB { 58 | target := integrationServerDSN(t) 59 | if len(dsn) > 0 { 60 | target = dsn[0] 61 | } 62 | db, err := sql.Open("presto", target) 63 | if err != nil { 64 | t.Fatal(err) 65 | } 66 | return db 67 | } 68 | 69 | // integration tests based on python tests: 70 | // https://github.com/prestodb/presto-python-client/tree/master/integration_tests 71 | 72 | func TestIntegrationEnabled(t *testing.T) { 73 | dsn := *integrationServerFlag 74 | if dsn == "" { 75 | example := "http://test@localhost:8080" 76 | t.Skip("integration tests not enabled; use e.g. -presto_server_dsn=" + example) 77 | } 78 | } 79 | 80 | type nodesRow struct { // columns of system.runtime.nodes, in SELECT * order 81 | NodeID string 82 | HTTPURI string 83 | NodeVersion string 84 | Coordinator bool 85 | State string 86 | } 87 | 88 | func TestIntegrationSelectQueryIterator(t *testing.T) { 89 | db := integrationOpen(t) 90 | defer db.Close() 91 | rows, err := db.Query("SELECT * FROM system.runtime.nodes") 92 | if err != nil { 93 | t.Fatal(err) 94 | } 95 | defer rows.Close() 96 | count := 0 97 | for rows.Next() { 98 | count++ 99 | var col nodesRow 100 | err = rows.Scan( 101 | &col.NodeID, 102 | &col.HTTPURI, 103 | &col.NodeVersion, 104 | &col.Coordinator, 105 | &col.State, 106 | ) 107 | if err != nil { 108 | t.Fatal(err) 109 | } 110 | if col.NodeID != "test" { 111 | t.Fatal("node_id != test") 112 | } 113 | } 114 | if err = rows.Err(); err != nil { 115 | t.Fatal(err) 116 | } 117 | if count < 1 { 118 | t.Fatal("no rows returned") 119 | } 120 | } 121 | 122 | func TestIntegrationSelectQueryNoResult(t *testing.T) { 123 | db := integrationOpen(t) 124 | defer db.Close() 125 | row := db.QueryRow("SELECT * FROM system.runtime.nodes where false") 126 | var col nodesRow 127 | err := row.Scan( 128 | &col.NodeID, 129 | &col.HTTPURI, 130 | &col.NodeVersion, 131 | &col.Coordinator, 132 | &col.State, 133 | ) 134 | if !errors.Is(err, sql.ErrNoRows) { // a failing query must not pass as an 
empty result 135 | t.Fatalf("expected sql.ErrNoRows, got error %v and data: %+v", err, col) 136 | } 137 | } 138 | 139 |
func TestIntegrationSelectFailedQuery(t *testing.T) { 140 | db := integrationOpen(t) 141 | defer db.Close() 142 | rows, err := db.Query("SELECT * FROM catalog.schema.do_not_exist") 143 | if err == nil { 144 | rows.Close() 145 | t.Fatal("query to invalid catalog succeeded") 146 | } 147 | ok := errors.As(err, new(*ErrQueryFailed)) // also accepts a wrapped *ErrQueryFailed 148 | if !ok { 149 | t.Fatal("unexpected error:", err) 150 | } 151 | } 152 | 153 | type tpchRow struct { // columns of tpch.sf1.customer, in SELECT * order 154 | CustKey int 155 | Name string 156 | Address string 157 | NationKey int 158 | Phone string 159 | AcctBal float64 160 | MktSegment string 161 | Comment string 162 | } 163 | 164 | func TestIntegrationSelectTpch1000(t *testing.T) { 165 | db := integrationOpen(t) 166 | defer db.Close() 167 | rows, err := db.Query("SELECT * FROM tpch.sf1.customer LIMIT 1000") 168 | if err != nil { 169 | t.Fatal(err) 170 | } 171 | defer rows.Close() 172 | count := 0 173 | for rows.Next() { 174 | count++ 175 | var col tpchRow 176 | err = rows.Scan( 177 | &col.CustKey, 178 | &col.Name, 179 | &col.Address, 180 | &col.NationKey, 181 | &col.Phone, 182 | &col.AcctBal, 183 | &col.MktSegment, 184 | &col.Comment, 185 | ) 186 | if err != nil { 187 | t.Fatal(err) 188 | } 189 | /* 190 | if col.CustKey == 1 && col.AcctBal != 711.56 { 191 | t.Fatal("unexpected acctbal for custkey=1:", col.AcctBal) 192 | } 193 | */ 194 | } 195 | if err = rows.Err(); err != nil { // report the iteration error, not the stale Scan result 196 | t.Fatal(err) 197 | } 198 | if count != 1000 { 199 | t.Fatal("not enough rows returned:", count) 200 | } 201 | } 202 | 203 | func TestIntegrationSelectCancelQuery(t *testing.T) { 204 | db := integrationOpen(t) 205 | defer db.Close() 206 | deadline := time.Now().Add(200 * time.Millisecond) 207 | ctx, cancel := context.WithDeadline(context.Background(), deadline) 208 | defer cancel() 209 | rows, err := db.QueryContext(ctx, "SELECT * FROM tpch.sf1.customer") 210 | if err != nil { 211 | goto 
handleErr 212 | } 213 | defer rows.Close() 214 | for rows.Next() { 215 | var col tpchRow 216 | err = rows.Scan( 217 | &col.CustKey, 218 |
&col.Name, 219 | &col.Address, 220 | &col.NationKey, 221 | &col.Phone, 222 | &col.AcctBal, 223 | &col.MktSegment, 224 | &col.Comment, 225 | ) 226 | if err != nil { 227 | break 228 | } 229 | } 230 | if err = rows.Err(); err == nil { 231 | t.Fatal("unexpected query with deadline succeeded") 232 | } 233 | handleErr: 234 | errmsg := err.Error() 235 | for _, msg := range []string{"cancel", "deadline"} { 236 | if strings.Contains(errmsg, msg) { 237 | return 238 | } 239 | } 240 | t.Fatal("unexpected error:", err) 241 | } 242 | 243 | func TestIntegrationSessionProperties(t *testing.T) { 244 | dsn := integrationServerDSN(t) 245 | dsn += "?session_properties=query_max_run_time=10m,query_priority=2" 246 | db := integrationOpen(t, dsn) 247 | defer db.Close() 248 | rows, err := db.Query("SHOW SESSION") 249 | if err != nil { 250 | t.Fatal(err) 251 | } 252 | for rows.Next() { 253 | col := struct { 254 | Name string 255 | Value string 256 | Default string 257 | Type string 258 | Description string 259 | }{} 260 | err = rows.Scan( 261 | &col.Name, 262 | &col.Value, 263 | &col.Default, 264 | &col.Type, 265 | &col.Description, 266 | ) 267 | if err != nil { 268 | t.Fatal(err) 269 | } 270 | switch { 271 | case col.Name == "query_max_run_time" && col.Value != "10m": 272 | t.Fatal("unexpected value for query_max_run_time:", col.Value) 273 | case col.Name == "query_priority" && col.Value != "2": 274 | t.Fatal("unexpected value for query_priority:", col.Value) 275 | } 276 | } 277 | if err = rows.Err(); err != nil { 278 | t.Fatal(err) 279 | } 280 | } 281 | 282 | func TestIntegrationTypeConversion(t *testing.T) { 283 | db := integrationOpen(t) 284 | var ( 285 | goTime time.Time 286 | nullTime NullTime 287 | goString string 288 | nullString sql.NullString 289 | nullStringSlice NullSliceString 290 | nullStringSlice2 NullSlice2String 291 | nullStringSlice3 NullSlice3String 292 | nullInt64Slice NullSliceInt64 293 | nullInt64Slice2 NullSlice2Int64 294 | nullInt64Slice3 NullSlice3Int64 295 |
nullFloat64Slice NullSliceFloat64 296 | nullFloat64Slice2 NullSlice2Float64 297 | nullFloat64Slice3 NullSlice3Float64 298 | goMap map[string]interface{} 299 | nullMap NullMap 300 | ) 301 | err := db.QueryRow(` 302 | SELECT 303 | TIMESTAMP '2017-07-10 01:02:03.004 UTC', 304 | CAST(NULL AS TIMESTAMP), 305 | CAST('string' AS VARCHAR), 306 | CAST(NULL AS VARCHAR), 307 | ARRAY['A', 'B', NULL], 308 | ARRAY[ARRAY['A'], NULL], 309 | ARRAY[ARRAY[ARRAY['A'], NULL], NULL], 310 | ARRAY[1, 2, NULL], 311 | ARRAY[ARRAY[1, 1, 1], NULL], 312 | ARRAY[ARRAY[ARRAY[1, 1, 1], NULL], NULL], 313 | ARRAY[1.0, 2.0, NULL], 314 | ARRAY[ARRAY[1.1, 1.1, 1.1], NULL], 315 | ARRAY[ARRAY[ARRAY[1.1, 1.1, 1.1], NULL], NULL], 316 | MAP(ARRAY['a', 'b'], ARRAY['c', 'd']), 317 | CAST(NULL AS MAP(ARRAY(INTEGER), ARRAY(INTEGER))) 318 | `).Scan( 319 | &goTime, 320 | &nullTime, 321 | &goString, 322 | &nullString, 323 | &nullStringSlice, 324 | &nullStringSlice2, 325 | &nullStringSlice3, 326 | &nullInt64Slice, 327 | &nullInt64Slice2, 328 | &nullInt64Slice3, 329 | &nullFloat64Slice, 330 | &nullFloat64Slice2, 331 | &nullFloat64Slice3, 332 | &goMap, 333 | &nullMap, 334 | ) 335 | if err != nil { 336 | t.Fatal(err) 337 | } 338 | } 339 | 340 | func TestIntegrationNoResults(t *testing.T) { 341 | db := integrationOpen(t) 342 | rows, err := db.Query("SELECT 1 LIMIT 0") 343 | if err != nil { 344 | t.Fatal(err) 345 | } 346 | for rows.Next() { 347 | t.Fatal(errors.New("Rows returned")) 348 | } 349 | if err = rows.Err(); err != nil { 350 | t.Fatal(err) 351 | } 352 | } 353 | 354 | func TestIntegrationQueryParametersSelect(t *testing.T) { 355 | scenarios := []struct { 356 | name string 357 | query string 358 | args []interface{} 359 | expectedError bool 360 | expectedRows int 361 | }{ 362 | { 363 | name: "valid string as varchar", 364 | query: "SELECT * FROM system.runtime.nodes WHERE system.runtime.nodes.node_id=?", 365 | args: []interface{}{"test"}, 366 | expectedRows: 1, 367 | }, 368 | { 369 | name: "valid int as bigint",
370 | query: "SELECT * FROM tpch.sf1.customer WHERE custkey=? LIMIT 2", 371 | args: []interface{}{int(1)}, 372 | expectedRows: 1, 373 | }, 374 | { 375 | name: "invalid string as bigint", 376 | query: "SELECT * FROM tpch.sf1.customer WHERE custkey=? LIMIT 2", 377 | args: []interface{}{"1"}, 378 | expectedError: true, 379 | }, 380 | { 381 | name: "valid string as date", 382 | query: "SELECT * FROM tpch.sf1.lineitem WHERE shipdate=? LIMIT 2", 383 | args: []interface{}{"1995-01-27"}, 384 | expectedError: true, 385 | }, 386 | } 387 | 388 | for i := range scenarios { 389 | scenario := scenarios[i] 390 | 391 | t.Run(scenario.name, func(t *testing.T) { 392 | db := integrationOpen(t) 393 | defer db.Close() 394 | 395 | rows, err := db.Query(scenario.query, scenario.args...) 396 | if err != nil { 397 | if scenario.expectedError { 398 | return 399 | } 400 | t.Fatal(err) 401 | } 402 | defer rows.Close() 403 | 404 | if scenario.expectedError { 405 | t.Fatal("missing expected error") 406 | } 407 | 408 | var count int 409 | for rows.Next() { 410 | count++ 411 | } 412 | if err = rows.Err(); err != nil { 413 | t.Fatal(err) 414 | } 415 | if count != scenario.expectedRows { 416 | t.Fatalf("expecting %d rows, got %d", scenario.expectedRows, count) 417 | } 418 | }) 419 | } 420 | } 421 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License.
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /presto/presto_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package presto 16 | 17 | import ( 18 | "context" 19 | "database/sql" 20 | "encoding/json" 21 | "errors" 22 | "fmt" 23 | "net/http" 24 | "net/http/httptest" 25 | "reflect" 26 | "strings" 27 | "testing" 28 | "time" 29 | ) 30 | 31 | func TestConfig(t *testing.T) { 32 | c := &Config{ 33 | PrestoURI: "http://foobar@localhost:8080", 34 | SessionProperties: map[string]string{"query_priority": "1"}, 35 | } 36 | dsn, err := c.FormatDSN() 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | want := "http://foobar@localhost:8080?session_properties=query_priority%3D1&source=presto-go-client" 41 | if dsn != want { 42 | t.Fatal("unexpected dsn:", dsn) 43 | } 44 | } 45 | 46 | func TestConfigSSLCertPath(t *testing.T) { 47 | c := &Config{ 48 | PrestoURI: "https://foobar@localhost:8080", 49 | SessionProperties: map[string]string{"query_priority": "1"}, 50 | SSLCertPath: "cert.pem", 51 | } 52 | dsn, err := c.FormatDSN() 53 | if err != nil { 54 | t.Fatal(err) 55 | } 56 | want := "https://foobar@localhost:8080?SSLCertPath=cert.pem&session_properties=query_priority%3D1&source=presto-go-client" 57 | if dsn != want { 58 | t.Fatal("unexpected dsn:", dsn) 59 | } 60 | } 61 | 62 | func TestConfigWithoutSSLCertPath(t *testing.T) { 63 | c := &Config{ 64 | PrestoURI: "https://foobar@localhost:8080", 65 | SessionProperties: map[string]string{"query_priority": "1"}, 66 | } 67 | dsn, err := c.FormatDSN() 68 | if err != nil { 69 | t.Fatal(err) 70 | } 71 | want := "https://foobar@localhost:8080?session_properties=query_priority%3D1&source=presto-go-client" 72 | if dsn != want { 73 | t.Fatal("unexpected dsn:", dsn) 74 | } 75 | } 76 | 77 | func TestKerberosConfig(t *testing.T) { 78 | c := &Config{ 79 | PrestoURI: "https://foobar@localhost:8090", 80 | SessionProperties: map[string]string{"query_priority": "1"}, 81 | KerberosEnabled: "true", 82 | KerberosKeytabPath: "/opt/test.keytab", 83 | KerberosPrincipal: "presto/testhost", 84 | KerberosRealm: "example.com", 85 | KerberosConfigPath: 
"/etc/krb5.conf", 86 | SSLCertPath: "/tmp/test.cert", 87 | } 88 | dsn, err := c.FormatDSN() 89 | if err != nil { 90 | t.Fatal(err) 91 | } 92 | 93 | want := "https://foobar@localhost:8090?KerberosConfigPath=%2Fetc%2Fkrb5.conf&KerberosEnabled=true&KerberosKeytabPath=%2Fopt%2Ftest.keytab&KerberosPrincipal=presto%2Ftesthost&KerberosRealm=example.com&SSLCertPath=%2Ftmp%2Ftest.cert&session_properties=query_priority%3D1&source=presto-go-client" 94 | if dsn != want { 95 | t.Fatal("unexpected dsn:", dsn) 96 | } 97 | } 98 | 99 | func TestInvalidKerberosConfig(t *testing.T) { 100 | c := &Config{ 101 | PrestoURI: "http://foobar@localhost:8090", 102 | KerberosEnabled: "true", 103 | } 104 | _, err := c.FormatDSN() 105 | if err == nil { 106 | t.Fatal("dsn generated from invalid secure url, since kerberos enabled must has SSL enabled") 107 | } 108 | } 109 | 110 | func TestJWTConfig(t *testing.T) { 111 | c := &Config{ 112 | PrestoURI: "https://foobar@localhost:8090", 113 | SessionProperties: map[string]string{"query_priority": "1"}, 114 | AccessToken: "test_token", 115 | } 116 | dsn, err := c.FormatDSN() 117 | if err != nil { 118 | t.Fatal(err) 119 | } 120 | 121 | want := "https://foobar@localhost:8090?AccessToken=test_token&session_properties=query_priority%3D1&source=presto-go-client" 122 | if dsn != want { 123 | t.Fatal("unexpected dsn:", dsn) 124 | } 125 | } 126 | 127 | func TestConfigWithMalformedURL(t *testing.T) { 128 | _, err := (&Config{PrestoURI: ":("}).FormatDSN() 129 | if err == nil { 130 | t.Fatal("dsn generated from malformed url") 131 | } 132 | } 133 | 134 | func TestConnErrorDSN(t *testing.T) { 135 | testcases := []struct { 136 | Name string 137 | DSN string 138 | }{ 139 | {Name: "malformed", DSN: "://"}, 140 | {Name: "unknown_client", DSN: "http://localhost?custom_client=unknown"}, 141 | } 142 | for _, tc := range testcases { 143 | t.Run(tc.Name, func(t *testing.T) { 144 | db, err := sql.Open("presto", tc.DSN) 145 | if err != nil { 146 | t.Fatal(err) 147 | } 148 | 
if _, err = db.Query("SELECT 1"); err == nil { 149 | db.Close() 150 | t.Fatal("test dsn is supposed to fail:", tc.DSN) 151 | } 152 | }) 153 | } 154 | } 155 | 156 | func TestRegisterCustomClientReserved(t *testing.T) { 157 | for _, tc := range []string{"true", "false"} { 158 | t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) { 159 | err := RegisterCustomClient(tc, &http.Client{}) 160 | if err == nil { 161 | t.Fatal("client key name supposed to fail:", tc) 162 | } 163 | }) 164 | } 165 | } 166 | 167 | func TestRoundTripRetryQueryError(t *testing.T) { 168 | count := 0 169 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 170 | if count == 0 { 171 | count++ 172 | w.WriteHeader(http.StatusServiceUnavailable) 173 | return 174 | } 175 | w.WriteHeader(http.StatusOK) 176 | json.NewEncoder(w).Encode(&stmtResponse{ 177 | Error: stmtError{ 178 | ErrorName: "TEST", 179 | }, 180 | }) 181 | })) 182 | defer ts.Close() 183 | db, err := sql.Open("presto", ts.URL) 184 | if err != nil { 185 | t.Fatal(err) 186 | } 187 | defer db.Close() 188 | _, err = db.Query("SELECT 1") 189 | if _, ok := err.(*ErrQueryFailed); !ok { 190 | t.Fatal("unexpected error:", err) 191 | } 192 | } 193 | 194 | func TestRoundTripCancellation(t *testing.T) { 195 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 196 | w.WriteHeader(http.StatusServiceUnavailable) 197 | })) 198 | defer ts.Close() 199 | db, err := sql.Open("presto", ts.URL) 200 | if err != nil { 201 | t.Fatal(err) 202 | } 203 | defer db.Close() 204 | ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) 205 | defer cancel() 206 | _, err = db.QueryContext(ctx, "SELECT 1") 207 | if err == nil { 208 | t.Fatal("unexpected query with cancelled context succeeded") 209 | } 210 | } 211 | 212 | func TestQueryContextCancellation(t *testing.T) { 213 | var qr = queryResponse{ 214 | NextURI: "", 215 | Columns: []queryColumn{}, 216 | Data: []queryData{}, 217 | 
Stats: stmtStats{ 218 | State: "RUNNING", 219 | }, 220 | } 221 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 222 | w.WriteHeader(http.StatusOK) 223 | json.NewEncoder(w).Encode(&qr) 224 | })) 225 | qr.NextURI = ts.URL 226 | defer ts.Close() 227 | db, err := sql.Open("presto", ts.URL) 228 | if err != nil { 229 | t.Fatal(err) 230 | } 231 | defer db.Close() 232 | ctx, cancel := context.WithCancel(context.Background()) 233 | errChannel := make(chan error) 234 | done := make(chan bool) 235 | go func() { 236 | _, err := db.QueryContext(ctx, "SELECT 1") 237 | if err != nil { 238 | errChannel <- err 239 | } else { 240 | close(done) 241 | } 242 | }() 243 | cancel() 244 | var err1 error 245 | select { 246 | case <-done: 247 | t.Fatal("unexpected query with cancelled context succeeded") 248 | break 249 | case err1 = <-errChannel: 250 | close(errChannel) 251 | if err1.Error() != "context canceled" { 252 | t.Fatal("query should have been cancelled with error message context cancelled") 253 | } 254 | } 255 | } 256 | 257 | func TestAuthFailure(t *testing.T) { 258 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 259 | w.WriteHeader(http.StatusUnauthorized) 260 | })) 261 | defer ts.Close() 262 | db, err := sql.Open("presto", ts.URL) 263 | if err != nil { 264 | t.Fatal(err) 265 | } 266 | defer db.Close() 267 | } 268 | 269 | func TestQueryCancellation(t *testing.T) { 270 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 271 | w.WriteHeader(http.StatusOK) 272 | json.NewEncoder(w).Encode(&stmtResponse{ 273 | Error: stmtError{ 274 | ErrorName: "USER_CANCELLED", 275 | }, 276 | }) 277 | })) 278 | defer ts.Close() 279 | db, err := sql.Open("presto", ts.URL) 280 | if err != nil { 281 | t.Fatal(err) 282 | } 283 | defer db.Close() 284 | _, err = db.Query("SELECT 1") 285 | if err != ErrQueryCancelled { 286 | t.Fatal("unexpected error:", err) 287 | } 288 | } 289 | 290 | 
func TestQueryFailure(t *testing.T) { 291 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 292 | w.WriteHeader(http.StatusInternalServerError) 293 | })) 294 | defer ts.Close() 295 | db, err := sql.Open("presto", ts.URL) 296 | if err != nil { 297 | t.Fatal(err) 298 | } 299 | defer db.Close() 300 | _, err = db.Query("SELECT 1") 301 | if _, ok := err.(*ErrQueryFailed); !ok { 302 | t.Fatal("unexpected error:", err) 303 | } 304 | } 305 | 306 | func TestSSLCertPath(t *testing.T) { 307 | db, err := sql.Open("presto", "https://localhost:9?SSLCertPath=/tmp/invalid_test.cert") 308 | if err != nil { 309 | t.Fatal(err) 310 | } 311 | defer db.Close() 312 | 313 | want := "Error loading SSL Cert File" 314 | if err := db.Ping(); err == nil { 315 | t.Fatal(err) 316 | } else if !strings.Contains(err.Error(), want) { 317 | t.Fatalf("want: %q, got: %v", want, err) 318 | } 319 | } 320 | 321 | func TestWithoutSSLCertPath(t *testing.T) { 322 | db, err := sql.Open("presto", "https://localhost:9") 323 | if err != nil { 324 | t.Fatal(err) 325 | } 326 | defer db.Close() 327 | 328 | if err := db.Ping(); err != nil { 329 | t.Fatal(err) 330 | } 331 | } 332 | 333 | func TestUnsupportedExec(t *testing.T) { 334 | db, err := sql.Open("presto", "http://localhost:9") 335 | if err != nil { 336 | t.Fatal(err) 337 | } 338 | defer db.Close() 339 | if _, err := db.Exec("CREATE TABLE foobar (V VARCHAR)"); err == nil { 340 | t.Fatal("unsupported exec succeeded with no error") 341 | } 342 | } 343 | 344 | func TestJWTAuthHeader(t *testing.T) { 345 | // this test ensures that the JWT token is passed as a Bearer token within the Authorization header 346 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 347 | // validate the Authorization header is JWT token 348 | if r.Header.Get("Authorization") != "Bearer test_token" { 349 | w.WriteHeader(http.StatusUnauthorized) 350 | } else { 351 | w.WriteHeader(http.StatusOK) 352 | } 353 | 
})) 354 | defer ts.Close() 355 | 356 | db, err := sql.Open("presto", ts.URL+"?AccessToken=test_token") 357 | if err != nil { 358 | t.Fatal(err) 359 | } 360 | defer db.Close() 361 | _, err = db.Query("SELECT 1") 362 | if err.Error() != "presto: EOF" { 363 | t.Fatal("expected query to return EOF", err) 364 | } 365 | } 366 | 367 | func TestTypeConversion(t *testing.T) { 368 | utc, err := time.LoadLocation("UTC") 369 | if err != nil { 370 | t.Fatal(err) 371 | } 372 | testcases := []struct { 373 | PrestoType string 374 | PrestoResponseUnmarshalledSample interface{} 375 | ExpectedGoValue interface{} 376 | }{ 377 | { 378 | PrestoType: "boolean", 379 | PrestoResponseUnmarshalledSample: true, 380 | ExpectedGoValue: true, 381 | }, 382 | { 383 | PrestoType: "varchar(1)", 384 | PrestoResponseUnmarshalledSample: "hello", 385 | ExpectedGoValue: "hello", 386 | }, 387 | { 388 | PrestoType: "bigint", 389 | PrestoResponseUnmarshalledSample: json.Number("1234516165077230279"), 390 | ExpectedGoValue: int64(1234516165077230279), 391 | }, 392 | { 393 | PrestoType: "double", 394 | PrestoResponseUnmarshalledSample: json.Number("1.0"), 395 | ExpectedGoValue: float64(1), 396 | }, 397 | { 398 | PrestoType: "date", 399 | PrestoResponseUnmarshalledSample: "2017-07-10", 400 | ExpectedGoValue: time.Date(2017, 7, 10, 0, 0, 0, 0, time.Local), 401 | }, 402 | { 403 | PrestoType: "time", 404 | PrestoResponseUnmarshalledSample: "01:02:03.000", 405 | ExpectedGoValue: time.Date(0, 1, 1, 1, 2, 3, 0, time.Local), 406 | }, 407 | { 408 | PrestoType: "time with time zone", 409 | PrestoResponseUnmarshalledSample: "01:02:03.000 UTC", 410 | ExpectedGoValue: time.Date(0, 1, 1, 1, 2, 3, 0, utc), 411 | }, 412 | { 413 | PrestoType: "timestamp", 414 | PrestoResponseUnmarshalledSample: "2017-07-10 01:02:03.000", 415 | ExpectedGoValue: time.Date(2017, 7, 10, 1, 2, 3, 0, time.Local), 416 | }, 417 | { 418 | PrestoType: "timestamp with time zone", 419 | PrestoResponseUnmarshalledSample: "2017-07-10 01:02:03.000 UTC", 420 
| ExpectedGoValue: time.Date(2017, 7, 10, 1, 2, 3, 0, utc), 421 | }, 422 | { 423 | PrestoType: "map", 424 | PrestoResponseUnmarshalledSample: nil, 425 | ExpectedGoValue: nil, 426 | }, 427 | { 428 | // arrays return data as-is for slice scanners 429 | PrestoType: "array", 430 | PrestoResponseUnmarshalledSample: nil, 431 | ExpectedGoValue: nil, 432 | }, 433 | } 434 | for _, tc := range testcases { 435 | converter := newTypeConverter(tc.PrestoType) 436 | 437 | t.Run(tc.PrestoType+":nil", func(t *testing.T) { 438 | if _, err := converter.ConvertValue(nil); err != nil { 439 | t.Fatal(err) 440 | } 441 | }) 442 | 443 | t.Run(tc.PrestoType+":bogus", func(t *testing.T) { 444 | if _, err := converter.ConvertValue(struct{}{}); err == nil { 445 | t.Fatal("bogus data scanned with no error") 446 | } 447 | }) 448 | t.Run(tc.PrestoType+":sample", func(t *testing.T) { 449 | v, err := converter.ConvertValue(tc.PrestoResponseUnmarshalledSample) 450 | if err != nil { 451 | t.Fatal(err) 452 | } 453 | if !reflect.DeepEqual(v, tc.ExpectedGoValue) { 454 | t.Fatalf("unexpected data from sample:\nhave %+v\nwant %+v", v, tc.ExpectedGoValue) 455 | } 456 | }) 457 | } 458 | } 459 | 460 | func TestSliceTypeConversion(t *testing.T) { 461 | testcases := []struct { 462 | GoType string 463 | Scanner sql.Scanner 464 | PrestoResponseUnmarshalledSample interface{} 465 | TestScanner func(t *testing.T, s sql.Scanner) 466 | }{ 467 | { 468 | GoType: "[]bool", 469 | Scanner: &NullSliceBool{}, 470 | PrestoResponseUnmarshalledSample: []interface{}{true}, 471 | TestScanner: func(t *testing.T, s sql.Scanner) { 472 | v, _ := s.(*NullSliceBool) 473 | if !v.Valid { 474 | t.Fatal("scanner failed") 475 | } 476 | }, 477 | }, 478 | { 479 | GoType: "[]string", 480 | Scanner: &NullSliceString{}, 481 | PrestoResponseUnmarshalledSample: []interface{}{"hello"}, 482 | TestScanner: func(t *testing.T, s sql.Scanner) { 483 | v, _ := s.(*NullSliceString) 484 | if !v.Valid { 485 | t.Fatal("scanner failed") 486 | } 487 | }, 488 | 
}, 489 | { 490 | GoType: "[]int64", 491 | Scanner: &NullSliceInt64{}, 492 | PrestoResponseUnmarshalledSample: []interface{}{json.Number("1")}, 493 | TestScanner: func(t *testing.T, s sql.Scanner) { 494 | v, _ := s.(*NullSliceInt64) 495 | if !v.Valid { 496 | t.Fatal("scanner failed") 497 | } 498 | }, 499 | }, 500 | 501 | { 502 | GoType: "[]float64", 503 | Scanner: &NullSliceFloat64{}, 504 | PrestoResponseUnmarshalledSample: []interface{}{json.Number("1.0")}, 505 | TestScanner: func(t *testing.T, s sql.Scanner) { 506 | v, _ := s.(*NullSliceFloat64) 507 | if !v.Valid { 508 | t.Fatal("scanner failed") 509 | } 510 | }, 511 | }, 512 | { 513 | GoType: "[]time.Time", 514 | Scanner: &NullSliceTime{}, 515 | PrestoResponseUnmarshalledSample: []interface{}{"2017-07-01"}, 516 | TestScanner: func(t *testing.T, s sql.Scanner) { 517 | v, _ := s.(*NullSliceTime) 518 | if !v.Valid { 519 | t.Fatal("scanner failed") 520 | } 521 | }, 522 | }, 523 | { 524 | GoType: "[]map[string]interface{}", 525 | Scanner: &NullSliceMap{}, 526 | PrestoResponseUnmarshalledSample: []interface{}{map[string]interface{}{"hello": "world"}}, 527 | TestScanner: func(t *testing.T, s sql.Scanner) { 528 | v, _ := s.(*NullSliceMap) 529 | if !v.Valid { 530 | t.Fatal("scanner failed") 531 | } 532 | }, 533 | }, 534 | } 535 | for _, tc := range testcases { 536 | t.Run(tc.GoType+":nil", func(t *testing.T) { 537 | if err := tc.Scanner.Scan(nil); err != nil { 538 | t.Error(err) 539 | } 540 | }) 541 | 542 | t.Run(tc.GoType+":bogus", func(t *testing.T) { 543 | if err := tc.Scanner.Scan(struct{}{}); err == nil { 544 | t.Error("bogus data scanned with no error") 545 | } 546 | if err := tc.Scanner.Scan([]interface{}{struct{}{}}); err == nil { 547 | t.Error("bogus data scanned with no error") 548 | } 549 | }) 550 | 551 | t.Run(tc.GoType+":sample", func(t *testing.T) { 552 | if err := tc.Scanner.Scan(tc.PrestoResponseUnmarshalledSample); err != nil { 553 | t.Error(err) 554 | } 555 | tc.TestScanner(t, tc.Scanner) 556 | }) 557 | 
} 558 | } 559 | 560 | func TestSlice2TypeConversion(t *testing.T) { 561 | testcases := []struct { 562 | GoType string 563 | Scanner sql.Scanner 564 | PrestoResponseUnmarshalledSample interface{} 565 | TestScanner func(t *testing.T, s sql.Scanner) 566 | }{ 567 | { 568 | GoType: "[][]bool", 569 | Scanner: &NullSlice2Bool{}, 570 | PrestoResponseUnmarshalledSample: []interface{}{[]interface{}{true}}, 571 | TestScanner: func(t *testing.T, s sql.Scanner) { 572 | v, _ := s.(*NullSlice2Bool) 573 | if !v.Valid { 574 | t.Fatal("scanner failed") 575 | } 576 | }, 577 | }, 578 | { 579 | GoType: "[][]string", 580 | Scanner: &NullSlice2String{}, 581 | PrestoResponseUnmarshalledSample: []interface{}{[]interface{}{"hello"}}, 582 | TestScanner: func(t *testing.T, s sql.Scanner) { 583 | v, _ := s.(*NullSlice2String) 584 | if !v.Valid { 585 | t.Fatal("scanner failed") 586 | } 587 | }, 588 | }, 589 | { 590 | GoType: "[][]int64", 591 | Scanner: &NullSlice2Int64{}, 592 | PrestoResponseUnmarshalledSample: []interface{}{[]interface{}{json.Number("1")}}, 593 | TestScanner: func(t *testing.T, s sql.Scanner) { 594 | v, _ := s.(*NullSlice2Int64) 595 | if !v.Valid { 596 | t.Fatal("scanner failed") 597 | } 598 | }, 599 | }, 600 | { 601 | GoType: "[][]float64", 602 | Scanner: &NullSlice2Float64{}, 603 | PrestoResponseUnmarshalledSample: []interface{}{[]interface{}{json.Number("1.0")}}, 604 | TestScanner: func(t *testing.T, s sql.Scanner) { 605 | v, _ := s.(*NullSlice2Float64) 606 | if !v.Valid { 607 | t.Fatal("scanner failed") 608 | } 609 | }, 610 | }, 611 | { 612 | GoType: "[][]time.Time", 613 | Scanner: &NullSlice2Time{}, 614 | PrestoResponseUnmarshalledSample: []interface{}{[]interface{}{"2017-07-01"}}, 615 | TestScanner: func(t *testing.T, s sql.Scanner) { 616 | v, _ := s.(*NullSlice2Time) 617 | if !v.Valid { 618 | t.Fatal("scanner failed") 619 | } 620 | }, 621 | }, 622 | { 623 | GoType: "[][]map[string]interface{}", 624 | Scanner: &NullSlice2Map{}, 625 | PrestoResponseUnmarshalledSample: 
[]interface{}{[]interface{}{map[string]interface{}{"hello": "world"}}}, 626 | TestScanner: func(t *testing.T, s sql.Scanner) { 627 | v, _ := s.(*NullSlice2Map) 628 | if !v.Valid { 629 | t.Fatal("scanner failed") 630 | } 631 | }, 632 | }, 633 | } 634 | for _, tc := range testcases { 635 | t.Run(tc.GoType+":nil", func(t *testing.T) { 636 | if err := tc.Scanner.Scan(nil); err != nil { 637 | t.Error(err) 638 | } 639 | if err := tc.Scanner.Scan([]interface{}{nil}); err != nil { 640 | t.Error(err) 641 | } 642 | }) 643 | 644 | t.Run(tc.GoType+":bogus", func(t *testing.T) { 645 | if err := tc.Scanner.Scan(struct{}{}); err == nil { 646 | t.Error("bogus data scanned with no error") 647 | } 648 | if err := tc.Scanner.Scan([]interface{}{struct{}{}}); err == nil { 649 | t.Error("bogus data scanned with no error") 650 | } 651 | if err := tc.Scanner.Scan([]interface{}{[]interface{}{struct{}{}}}); err == nil { 652 | t.Error("bogus data scanned with no error") 653 | } 654 | }) 655 | 656 | t.Run(tc.GoType+":sample", func(t *testing.T) { 657 | if err := tc.Scanner.Scan(tc.PrestoResponseUnmarshalledSample); err != nil { 658 | t.Error(err) 659 | } 660 | tc.TestScanner(t, tc.Scanner) 661 | }) 662 | } 663 | } 664 | 665 | func TestSlice3TypeConversion(t *testing.T) { 666 | testcases := []struct { 667 | GoType string 668 | Scanner sql.Scanner 669 | PrestoResponseUnmarshalledSample interface{} 670 | TestScanner func(t *testing.T, s sql.Scanner) 671 | }{ 672 | { 673 | GoType: "[][][]bool", 674 | Scanner: &NullSlice3Bool{}, 675 | PrestoResponseUnmarshalledSample: []interface{}{[]interface{}{[]interface{}{true}}}, 676 | TestScanner: func(t *testing.T, s sql.Scanner) { 677 | v, _ := s.(*NullSlice3Bool) 678 | if !v.Valid { 679 | t.Fatal("scanner failed") 680 | } 681 | }, 682 | }, 683 | { 684 | GoType: "[][][]string", 685 | Scanner: &NullSlice3String{}, 686 | PrestoResponseUnmarshalledSample: []interface{}{[]interface{}{[]interface{}{"hello"}}}, 687 | TestScanner: func(t *testing.T, s 
sql.Scanner) { 688 | v, _ := s.(*NullSlice3String) 689 | if !v.Valid { 690 | t.Fatal("scanner failed") 691 | } 692 | }, 693 | }, 694 | { 695 | GoType: "[][][]int64", 696 | Scanner: &NullSlice3Int64{}, 697 | PrestoResponseUnmarshalledSample: []interface{}{[]interface{}{[]interface{}{json.Number("1")}}}, 698 | TestScanner: func(t *testing.T, s sql.Scanner) { 699 | v, _ := s.(*NullSlice3Int64) 700 | if !v.Valid { 701 | t.Fatal("scanner failed") 702 | } 703 | }, 704 | }, 705 | { 706 | GoType: "[][][]float64", 707 | Scanner: &NullSlice3Float64{}, 708 | PrestoResponseUnmarshalledSample: []interface{}{[]interface{}{[]interface{}{json.Number("1.0")}}}, 709 | TestScanner: func(t *testing.T, s sql.Scanner) { 710 | v, _ := s.(*NullSlice3Float64) 711 | if !v.Valid { 712 | t.Fatal("scanner failed") 713 | } 714 | }, 715 | }, 716 | { 717 | GoType: "[][][]time.Time", 718 | Scanner: &NullSlice3Time{}, 719 | PrestoResponseUnmarshalledSample: []interface{}{[]interface{}{[]interface{}{"2017-07-01"}}}, 720 | TestScanner: func(t *testing.T, s sql.Scanner) { 721 | v, _ := s.(*NullSlice3Time) 722 | if !v.Valid { 723 | t.Fatal("scanner failed") 724 | } 725 | }, 726 | }, 727 | { 728 | GoType: "[][][]map[string]interface{}", 729 | Scanner: &NullSlice3Map{}, 730 | PrestoResponseUnmarshalledSample: []interface{}{[]interface{}{[]interface{}{map[string]interface{}{"hello": "world"}}}}, 731 | TestScanner: func(t *testing.T, s sql.Scanner) { 732 | v, _ := s.(*NullSlice3Map) 733 | if !v.Valid { 734 | t.Fatal("scanner failed") 735 | } 736 | }, 737 | }, 738 | } 739 | for _, tc := range testcases { 740 | t.Run(tc.GoType+":nil", func(t *testing.T) { 741 | if err := tc.Scanner.Scan(nil); err != nil { 742 | t.Fatal(err) 743 | } 744 | if err := tc.Scanner.Scan([]interface{}{[]interface{}{nil}}); err != nil { 745 | t.Fatal(err) 746 | } 747 | }) 748 | 749 | t.Run(tc.GoType+":bogus", func(t *testing.T) { 750 | if err := tc.Scanner.Scan(struct{}{}); err == nil { 751 | t.Error("bogus data scanned with no 
error") 752 | } 753 | if err := tc.Scanner.Scan([]interface{}{[]interface{}{struct{}{}}}); err == nil { 754 | t.Error("bogus data scanned with no error") 755 | } 756 | if err := tc.Scanner.Scan([]interface{}{[]interface{}{[]interface{}{struct{}{}}}}); err == nil { 757 | t.Error("bogus data scanned with no error") 758 | } 759 | }) 760 | 761 | t.Run(tc.GoType+":sample", func(t *testing.T) { 762 | if err := tc.Scanner.Scan(tc.PrestoResponseUnmarshalledSample); err != nil { 763 | t.Error(err) 764 | } 765 | tc.TestScanner(t, tc.Scanner) 766 | }) 767 | } 768 | } 769 | func TestNamedArgAndQueryId(t *testing.T) { 770 | db, err := sql.Open("presto", "http://localhost:9") 771 | if err != nil { 772 | t.Fatal(err) 773 | } 774 | defer db.Close() 775 | 776 | rows, err := db.Query("select 1 ", sql.Named("X-Presto-Client-Tags", "userName=root"), sql.Named("X-Presto-Client-Info", "{\"submitTime\":\"2022-05-223 10:22:03\",\"userName\":\"root\"}")) 777 | if err != nil { 778 | t.Fatal(err) 779 | } 780 | 781 | var testId string 782 | for rows.Next() { 783 | err := rows.Scan(&testId) 784 | if err != nil { 785 | t.Fatal(err) 786 | } 787 | } 788 | 789 | var e *EOF 790 | if errors.As(rows.Err(), &e) { 791 | t.Logf("sucess to get query ID: %s", e.QueryID) 792 | } 793 | } 794 | -------------------------------------------------------------------------------- /presto/presto.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // This file contains code that was borrowed from prestgo, mainly some 16 | // data type definitions. 17 | // 18 | // See https://github.com/avct/prestgo for copyright information. 19 | // 20 | // The MIT License (MIT) 21 | // 22 | // Copyright (c) 2015 Avocet Systems Ltd. 23 | // 24 | // Permission is hereby granted, free of charge, to any person obtaining a copy 25 | // of this software and associated documentation files (the "Software"), to deal 26 | // in the Software without restriction, including without limitation the rights 27 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 28 | // copies of the Software, and to permit persons to whom the Software is 29 | // furnished to do so, subject to the following conditions: 30 | // 31 | // The above copyright notice and this permission notice shall be included in all 32 | // copies or substantial portions of the Software. 33 | // 34 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 35 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 36 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 37 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 38 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 39 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 40 | // SOFTWARE. 41 | 42 | // Package presto provides a database/sql driver for Facebook's Presto. 
43 | // 44 | // The driver should be used via the database/sql package: 45 | // 46 | // import "database/sql" 47 | // import _ "github.com/prestodb/presto-go-client/presto" 48 | // 49 | // dsn := "http://user@localhost:8080?catalog=default&schema=test" 50 | // db, err := sql.Open("presto", dsn) 51 | package presto 52 | 53 | import ( 54 | "context" 55 | "crypto/tls" 56 | "crypto/x509" 57 | "database/sql" 58 | "database/sql/driver" 59 | "encoding/json" 60 | "errors" 61 | "fmt" 62 | "io" 63 | "io/ioutil" 64 | "math" 65 | "net/http" 66 | "net/url" 67 | "os" 68 | "regexp" 69 | "strconv" 70 | "strings" 71 | "sync" 72 | "time" 73 | "unicode" 74 | 75 | "gopkg.in/jcmturner/gokrb5.v6/client" 76 | "gopkg.in/jcmturner/gokrb5.v6/config" 77 | "gopkg.in/jcmturner/gokrb5.v6/keytab" 78 | ) 79 | 80 | func init() { 81 | sql.Register("presto", &sqldriver{}) 82 | } 83 | 84 | var ( 85 | // DefaultQueryTimeout is the default timeout for queries executed without a context. 86 | DefaultQueryTimeout = 60 * time.Second 87 | 88 | // DefaultCancelQueryTimeout is the timeout for the request to cancel queries in presto. 89 | DefaultCancelQueryTimeout = 30 * time.Second 90 | 91 | // ErrOperationNotSupported indicates that a database operation is not supported. 92 | ErrOperationNotSupported = errors.New("presto: operation not supported") 93 | 94 | // ErrQueryCancelled indicates that a query has been cancelled. 
95 | ErrQueryCancelled = errors.New("presto: query cancelled") 96 | ) 97 | 98 | const ( 99 | preparedStatementHeader = "X-Presto-Prepared-Statement" 100 | preparedStatementName = "_presto_go" 101 | prestoUserHeader = "X-Presto-User" 102 | prestoSourceHeader = "X-Presto-Source" 103 | prestoCatalogHeader = "X-Presto-Catalog" 104 | prestoSchemaHeader = "X-Presto-Schema" 105 | prestoSessionHeader = "X-Presto-Session" 106 | prestoTransactionHeader = "X-Presto-Transaction-Id" 107 | prestoStartedTransactionHeader = "X-Presto-Started-Transaction-Id" 108 | prestoClearTransactionHeader = "X-Presto-Clear-Transaction-Id" 109 | prestoClientTagsHeader = "X-Presto-Client-Tags" 110 | prestoClientInfoHeader = "X-Presto-Client-Info" 111 | 112 | kerberosEnabledConfig = "KerberosEnabled" 113 | kerberosKeytabPathConfig = "KerberosKeytabPath" 114 | kerberosPrincipalConfig = "KerberosPrincipal" 115 | kerberosRealmConfig = "KerberosRealm" 116 | kerberosConfigPathConfig = "KerberosConfigPath" 117 | sSLCertPathConfig = "SSLCertPath" 118 | 119 | accessTokenConfig = "AccessToken" 120 | ) 121 | 122 | type sqldriver struct{} 123 | 124 | func (d *sqldriver) Open(name string) (driver.Conn, error) { 125 | return newConn(name) 126 | } 127 | 128 | var _ driver.Driver = &sqldriver{} 129 | 130 | // Config is a configuration that can be encoded to a DSN string. 131 | type Config struct { 132 | PrestoURI string // URI of the Presto server, e.g. 
http://user@localhost:8080 133 | Source string // Source of the connection (optional) 134 | Catalog string // Catalog (optional) 135 | Schema string // Schema (optional) 136 | SessionProperties map[string]string // Session properties (optional) 137 | CustomClientName string // Custom client name (optional) 138 | KerberosEnabled string // KerberosEnabled (optional, default is false) 139 | KerberosKeytabPath string // Kerberos Keytab Path (optional) 140 | KerberosPrincipal string // Kerberos Principal used to authenticate to KDC (optional) 141 | KerberosRealm string // The Kerberos Realm (optional) 142 | KerberosConfigPath string // The krb5 config path (optional) 143 | SSLCertPath string // The SSL cert path for TLS verification (optional) 144 | AccessToken string // The JWT access token for authentication (optional) 145 | } 146 | 147 | // FormatDSN returns a DSN string from the configuration. 148 | func (c *Config) FormatDSN() (string, error) { 149 | prestoURL, err := url.Parse(c.PrestoURI) 150 | if err != nil { 151 | return "", err 152 | } 153 | var sessionkv []string 154 | if c.SessionProperties != nil { 155 | for k, v := range c.SessionProperties { 156 | sessionkv = append(sessionkv, k+"="+v) 157 | } 158 | } 159 | source := c.Source 160 | if source == "" { 161 | source = "presto-go-client" 162 | } 163 | query := make(url.Values) 164 | query.Add("source", source) 165 | 166 | KerberosEnabled, _ := strconv.ParseBool(c.KerberosEnabled) 167 | isSSL := prestoURL.Scheme == "https" 168 | 169 | if isSSL && c.SSLCertPath != "" { 170 | query.Add(sSLCertPathConfig, c.SSLCertPath) 171 | } 172 | 173 | if KerberosEnabled { 174 | query.Add(kerberosEnabledConfig, "true") 175 | query.Add(kerberosKeytabPathConfig, c.KerberosKeytabPath) 176 | query.Add(kerberosPrincipalConfig, c.KerberosPrincipal) 177 | query.Add(kerberosRealmConfig, c.KerberosRealm) 178 | query.Add(kerberosConfigPathConfig, c.KerberosConfigPath) 179 | if !isSSL { 180 | return "", fmt.Errorf("presto: client 
configuration error, SSL must be enabled for secure env") 181 | } 182 | } 183 | 184 | if c.AccessToken != "" { 185 | query.Add(accessTokenConfig, c.AccessToken) 186 | } 187 | 188 | for k, v := range map[string]string{ 189 | "catalog": c.Catalog, 190 | "schema": c.Schema, 191 | "session_properties": strings.Join(sessionkv, ","), 192 | "custom_client": c.CustomClientName, 193 | } { 194 | if v != "" { 195 | query[k] = []string{v} 196 | } 197 | } 198 | prestoURL.RawQuery = query.Encode() 199 | return prestoURL.String(), nil 200 | } 201 | 202 | // Conn is a presto connection. 203 | type Conn struct { 204 | baseURL string 205 | auth *url.Userinfo 206 | httpClient http.Client 207 | httpHeaders http.Header 208 | kerberosClient client.Client 209 | kerberosEnabled bool 210 | } 211 | 212 | var ( 213 | _ driver.Conn = &Conn{} 214 | _ driver.ConnPrepareContext = &Conn{} 215 | _ driver.ConnBeginTx = &Conn{} 216 | ) 217 | 218 | func newConn(dsn string) (*Conn, error) { 219 | prestoURL, err := url.Parse(dsn) 220 | if err != nil { 221 | return nil, fmt.Errorf("presto: malformed dsn: %v", err) 222 | } 223 | 224 | prestoQuery := prestoURL.Query() 225 | 226 | kerberosEnabled, _ := strconv.ParseBool(prestoQuery.Get(kerberosEnabledConfig)) 227 | 228 | var kerberosClient client.Client 229 | 230 | if kerberosEnabled { 231 | kt, err := keytab.Load(prestoQuery.Get(kerberosKeytabPathConfig)) 232 | if err != nil { 233 | return nil, fmt.Errorf("presto: Error loading Keytab: %v", err) 234 | } 235 | 236 | kerberosClient = client.NewClientWithKeytab(prestoQuery.Get(kerberosPrincipalConfig), prestoQuery.Get(kerberosRealmConfig), kt) 237 | conf, err := config.Load(prestoQuery.Get(kerberosConfigPathConfig)) 238 | if err != nil { 239 | return nil, fmt.Errorf("presto: Error loading krb config: %v", err) 240 | } 241 | 242 | kerberosClient.WithConfig(conf) 243 | 244 | loginErr := kerberosClient.Login() 245 | if loginErr != nil { 246 | return nil, fmt.Errorf("presto: Error login to KDC: %v", loginErr) 
247 | } 248 | } 249 | 250 | var httpClient = http.DefaultClient 251 | if clientKey := prestoQuery.Get("custom_client"); clientKey != "" { 252 | httpClient = getCustomClient(clientKey) 253 | if httpClient == nil { 254 | return nil, fmt.Errorf("presto: custom client not registered: %q", clientKey) 255 | } 256 | } else if certPath := prestoQuery.Get(sSLCertPathConfig); certPath != "" && prestoURL.Scheme == "https" { 257 | cert, err := os.ReadFile(certPath) 258 | if err != nil { 259 | return nil, fmt.Errorf("presto: Error loading SSL Cert File: %v", err) 260 | } 261 | certPool := x509.NewCertPool() 262 | certPool.AppendCertsFromPEM(cert) 263 | 264 | httpClient = &http.Client{ 265 | Transport: &http.Transport{ 266 | TLSClientConfig: &tls.Config{ 267 | RootCAs: certPool, 268 | }, 269 | }, 270 | } 271 | } 272 | 273 | c := &Conn{ 274 | baseURL: prestoURL.Scheme + "://" + prestoURL.Host, 275 | httpClient: *httpClient, 276 | httpHeaders: make(http.Header), 277 | kerberosClient: kerberosClient, 278 | kerberosEnabled: kerberosEnabled, 279 | } 280 | 281 | var user string 282 | if prestoURL.User != nil { 283 | user = prestoURL.User.Username() 284 | pass, _ := prestoURL.User.Password() 285 | if pass != "" && prestoURL.Scheme == "https" { 286 | c.auth = prestoURL.User 287 | } 288 | } 289 | 290 | for k, v := range map[string]string{ 291 | prestoUserHeader: user, 292 | prestoSourceHeader: prestoQuery.Get("source"), 293 | prestoCatalogHeader: prestoQuery.Get("catalog"), 294 | prestoSchemaHeader: prestoQuery.Get("schema"), 295 | prestoSessionHeader: prestoQuery.Get("session_properties"), 296 | } { 297 | if v != "" { 298 | c.httpHeaders.Add(k, v) 299 | } 300 | } 301 | 302 | // if a JWT access token is provided, add an Authorization header with Bearer token 303 | if token := prestoQuery.Get(accessTokenConfig); token != "" { 304 | c.httpHeaders.Set("Authorization", "Bearer "+token) 305 | } 306 | 307 | return c, nil 308 | } 309 | 310 | // registry for custom http clients 311 | var 
customClientRegistry = struct { 312 | sync.RWMutex 313 | Index map[string]http.Client 314 | }{ 315 | Index: make(map[string]http.Client), 316 | } 317 | 318 | // RegisterCustomClient associates a client to a key in the driver's registry. 319 | // 320 | // Register your custom client in the driver, then refer to it by name in the DSN, on the call to sql.Open: 321 | // 322 | // foobarClient := &http.Client{ 323 | // Transport: &http.Transport{ 324 | // Proxy: http.ProxyFromEnvironment, 325 | // DialContext: (&net.Dialer{ 326 | // Timeout: 30 * time.Second, 327 | // KeepAlive: 30 * time.Second, 328 | // DualStack: true, 329 | // }).DialContext, 330 | // MaxIdleConns: 100, 331 | // IdleConnTimeout: 90 * time.Second, 332 | // TLSHandshakeTimeout: 10 * time.Second, 333 | // ExpectContinueTimeout: 1 * time.Second, 334 | // TLSClientConfig: &tls.Config{ 335 | // // your config here... 336 | // }, 337 | // }, 338 | // } 339 | // presto.RegisterCustomClient("foobar", foobarClient) 340 | // db, err := sql.Open("presto", "https://user@localhost:8080?custom_client=foobar") 341 | func RegisterCustomClient(key string, client *http.Client) error { 342 | if _, err := strconv.ParseBool(key); err == nil { 343 | return fmt.Errorf("presto: custom client key %q is reserved", key) 344 | } 345 | customClientRegistry.Lock() 346 | customClientRegistry.Index[key] = *client 347 | customClientRegistry.Unlock() 348 | return nil 349 | } 350 | 351 | // DeregisterCustomClient removes the client associated to the key. 352 | func DeregisterCustomClient(key string) { 353 | customClientRegistry.Lock() 354 | delete(customClientRegistry.Index, key) 355 | customClientRegistry.Unlock() 356 | } 357 | 358 | func getCustomClient(key string) *http.Client { 359 | customClientRegistry.RLock() 360 | defer customClientRegistry.RUnlock() 361 | if client, ok := customClientRegistry.Index[key]; ok { 362 | return &client 363 | } 364 | return nil 365 | } 366 | 367 | // Begin implements the driver.Conn interface. 
368 | func (c *Conn) Begin() (driver.Tx, error) { 369 | return nil, ErrOperationNotSupported 370 | } 371 | 372 | func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 373 | args := []string{} 374 | if opts.ReadOnly { 375 | args = append(args, "READ ONLY") 376 | } 377 | 378 | level := sql.IsolationLevel(opts.Isolation) 379 | if level != sql.LevelDefault { 380 | err := verifyIsolationLevel(level) 381 | if err != nil { 382 | return nil, err 383 | } 384 | args = append(args, fmt.Sprintf("ISOLATION LEVEL %s", level.String())) 385 | } 386 | 387 | query := fmt.Sprintf("START TRANSACTION %s", strings.Join(args, ", ")) 388 | c.httpHeaders.Set(prestoTransactionHeader, "NONE") 389 | stmt := &driverStmt{conn: c, query: query} 390 | _, err := stmt.QueryContext(ctx, []driver.NamedValue{}) 391 | if err != nil { 392 | c.httpHeaders.Del(prestoTransactionHeader) 393 | return nil, err 394 | } 395 | 396 | return &driverTx{conn: c}, nil 397 | } 398 | 399 | // Prepare implements the driver.Conn interface. 400 | func (c *Conn) Prepare(query string) (driver.Stmt, error) { 401 | return nil, driver.ErrSkip 402 | } 403 | 404 | // PrepareContext implements the driver.ConnPrepareContext interface. 405 | func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 406 | return &driverStmt{conn: c, query: query}, nil 407 | } 408 | 409 | // Close implements the driver.Conn interface. 
410 | func (c *Conn) Close() error { 411 | return nil 412 | } 413 | 414 | func (c *Conn) newRequest(method, url string, body io.Reader, hs http.Header) (*http.Request, error) { 415 | req, err := http.NewRequest(method, url, body) 416 | if err != nil { 417 | return nil, fmt.Errorf("presto: %v", err) 418 | } 419 | 420 | if c.kerberosEnabled { 421 | err = c.kerberosClient.SetSPNEGOHeader(req, "presto/"+req.URL.Hostname()) 422 | if err != nil { 423 | return nil, fmt.Errorf("error setting client SPNEGO header: %v", err) 424 | } 425 | } 426 | 427 | for k, v := range c.httpHeaders { 428 | req.Header[k] = v 429 | } 430 | for k, v := range hs { 431 | req.Header[k] = v 432 | } 433 | 434 | if c.auth != nil { 435 | pass, _ := c.auth.Password() 436 | req.SetBasicAuth(c.auth.Username(), pass) 437 | } 438 | return req, nil 439 | } 440 | 441 | func (c *Conn) roundTrip(ctx context.Context, req *http.Request) (*http.Response, error) { 442 | delay := 100 * time.Millisecond 443 | const maxDelayBetweenRequests = float64(15 * time.Second) 444 | timer := time.NewTimer(0) 445 | defer timer.Stop() 446 | for { 447 | select { 448 | case <-ctx.Done(): 449 | return nil, ctx.Err() 450 | case <-timer.C: 451 | timeout := DefaultQueryTimeout 452 | if deadline, ok := ctx.Deadline(); ok { 453 | timeout = deadline.Sub(time.Now()) 454 | } 455 | client := c.httpClient 456 | client.Timeout = timeout 457 | resp, err := client.Do(req) 458 | if err != nil { 459 | return nil, &ErrQueryFailed{Reason: err} 460 | } 461 | switch resp.StatusCode { 462 | case http.StatusOK: 463 | if id := resp.Header.Get(prestoStartedTransactionHeader); id != "" { 464 | c.httpHeaders.Set(prestoTransactionHeader, id) 465 | } else if resp.Header.Get(prestoClearTransactionHeader) == "true" { 466 | c.httpHeaders.Del(prestoTransactionHeader) 467 | } 468 | 469 | return resp, nil 470 | case http.StatusServiceUnavailable: 471 | resp.Body.Close() 472 | timer.Reset(delay) 473 | delay = time.Duration(math.Min( 474 | float64(delay)*math.Phi, 
475 | maxDelayBetweenRequests, 476 | )) 477 | continue 478 | default: 479 | return nil, newErrQueryFailedFromResponse(resp) 480 | } 481 | } 482 | } 483 | } 484 | 485 | // ErrQueryFailed indicates that a query to presto failed. 486 | type ErrQueryFailed struct { 487 | StatusCode int 488 | Reason error 489 | } 490 | 491 | // Error implements the error interface. 492 | func (e *ErrQueryFailed) Error() string { 493 | return fmt.Sprintf("presto: query failed (%d %s): %q", 494 | e.StatusCode, http.StatusText(e.StatusCode), e.Reason) 495 | } 496 | 497 | func newErrQueryFailedFromResponse(resp *http.Response) *ErrQueryFailed { 498 | const maxBytes = 8 * 1024 499 | defer resp.Body.Close() 500 | qf := &ErrQueryFailed{StatusCode: resp.StatusCode} 501 | b, err := ioutil.ReadAll(io.LimitReader(resp.Body, maxBytes)) 502 | if err != nil { 503 | qf.Reason = err 504 | return qf 505 | } 506 | reason := string(b) 507 | if resp.ContentLength > maxBytes { 508 | reason += "..." 509 | } 510 | qf.Reason = errors.New(reason) 511 | return qf 512 | } 513 | 514 | type driverStmt struct { 515 | conn *Conn 516 | query string 517 | user string 518 | } 519 | 520 | var ( 521 | _ driver.Stmt = &driverStmt{} 522 | _ driver.StmtQueryContext = &driverStmt{} 523 | ) 524 | 525 | func (st *driverStmt) Close() error { 526 | return nil 527 | } 528 | 529 | func (st *driverStmt) NumInput() int { 530 | return -1 531 | } 532 | 533 | func (st *driverStmt) Exec(args []driver.Value) (driver.Result, error) { 534 | return nil, ErrOperationNotSupported 535 | } 536 | 537 | type stmtResponse struct { 538 | ID string `json:"id"` 539 | InfoURI string `json:"infoUri"` 540 | NextURI string `json:"nextUri"` 541 | Stats stmtStats `json:"stats"` 542 | Error stmtError `json:"error"` 543 | } 544 | 545 | type stmtStats struct { 546 | State string `json:"state"` 547 | Scheduled bool `json:"scheduled"` 548 | Nodes int `json:"nodes"` 549 | TotalSplits int `json:"totalSplits"` 550 | QueuesSplits int `json:"queuedSplits"` 551 | 
RunningSplits int `json:"runningSplits"` 552 | CompletedSplits int `json:"completedSplits"` 553 | UserTimeMillis int `json:"userTimeMillis"` 554 | CPUTimeMillis int `json:"cpuTimeMillis"` 555 | WallTimeMillis int `json:"wallTimeMillis"` 556 | ProcessedRows int `json:"processedRows"` 557 | ProcessedBytes int `json:"processedBytes"` 558 | RootStage stmtStage `json:"rootStage"` 559 | } 560 | 561 | type stmtError struct { 562 | Message string `json:"message"` 563 | ErrorName string `json:"errorName"` 564 | ErrorCode int `json:"errorCode"` 565 | ErrorLocation stmtErrorLocation `json:"errorLocation"` 566 | FailureInfo stmtErrorFailureInfo `json:"failureInfo"` 567 | // Other fields omitted 568 | } 569 | 570 | type stmtErrorLocation struct { 571 | LineNumber int `json:"lineNumber"` 572 | ColumnNumber int `json:"columnNumber"` 573 | } 574 | 575 | type stmtErrorFailureInfo struct { 576 | Type string `json:"type"` 577 | // Other fields omitted 578 | } 579 | 580 | func (e stmtError) Error() string { 581 | return e.FailureInfo.Type + ": " + e.Message 582 | } 583 | 584 | type stmtStage struct { 585 | StageID string `json:"stageId"` 586 | State string `json:"state"` 587 | Done bool `json:"done"` 588 | Nodes int `json:"nodes"` 589 | TotalSplits int `json:"totalSplits"` 590 | QueuedSplits int `json:"queuedSplits"` 591 | RunningSplits int `json:"runningSplits"` 592 | CompletedSplits int `json:"completedSplits"` 593 | UserTimeMillis int `json:"userTimeMillis"` 594 | CPUTimeMillis int `json:"cpuTimeMillis"` 595 | WallTimeMillis int `json:"wallTimeMillis"` 596 | ProcessedRows int `json:"processedRows"` 597 | ProcessedBytes int `json:"processedBytes"` 598 | SubStages []stmtStage `json:"subStages"` 599 | } 600 | 601 | // EOF indicates the server has returned io.EOF for the given QueryID. 602 | type EOF struct { 603 | QueryID string 604 | } 605 | 606 | // Error implements the error interface. 
607 | func (e *EOF) Error() string { 608 | return e.QueryID 609 | } 610 | 611 | func (st *driverStmt) Query(args []driver.Value) (driver.Rows, error) { 612 | return nil, driver.ErrSkip 613 | } 614 | 615 | func (st *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 616 | query := st.query 617 | var hs http.Header 618 | 619 | if len(args) > 0 { 620 | hs = make(http.Header) 621 | var ss []string 622 | for _, arg := range args { 623 | s, err := Serial(arg.Value) 624 | if err != nil { 625 | return nil, err 626 | } 627 | if arg.Name == prestoUserHeader { 628 | st.user = s 629 | hs.Add(prestoUserHeader, st.user) 630 | } else if arg.Name == prestoClientTagsHeader { 631 | hs.Add(prestoClientTagsHeader, s) 632 | } else if arg.Name == prestoClientInfoHeader { 633 | hs.Add(prestoClientInfoHeader, s) 634 | } else { 635 | ss = append(ss, s) 636 | } 637 | } 638 | 639 | if len(ss) > 0 { 640 | if hs.Get(preparedStatementHeader) == "" { 641 | hs.Add(preparedStatementHeader, preparedStatementName+"="+url.QueryEscape(st.query)) 642 | } 643 | query = "EXECUTE " + preparedStatementName + " USING " + strings.Join(ss, ", ") 644 | } 645 | } 646 | 647 | req, err := st.conn.newRequest("POST", st.conn.baseURL+"/v1/statement", strings.NewReader(query), hs) 648 | if err != nil { 649 | return nil, err 650 | } 651 | 652 | resp, err := st.conn.roundTrip(ctx, req) 653 | if err != nil { 654 | return nil, err 655 | } 656 | defer resp.Body.Close() 657 | var sr stmtResponse 658 | d := json.NewDecoder(resp.Body) 659 | d.UseNumber() 660 | err = d.Decode(&sr) 661 | if err != nil { 662 | return nil, fmt.Errorf("presto: %v", err) 663 | } 664 | err = handleResponseError(resp.StatusCode, sr.Error) 665 | if err != nil { 666 | return nil, err 667 | } 668 | rows := &driverRows{ 669 | ctx: ctx, 670 | stmt: st, 671 | nextURI: sr.NextURI, 672 | id: sr.ID, 673 | } 674 | completedChannel := make(chan struct{}) 675 | defer close(completedChannel) 676 | go func() { 677 | 
select { 678 | case <-ctx.Done(): 679 | err := rows.Close() 680 | if err != nil { 681 | return 682 | } 683 | case <-completedChannel: 684 | return 685 | } 686 | }() 687 | if err = rows.fetch(false); err != nil { 688 | return nil, err 689 | } 690 | return rows, nil 691 | } 692 | 693 | type rowsColumn struct { 694 | name string 695 | dbType string 696 | vc driver.ValueConverter 697 | } 698 | 699 | type driverRows struct { 700 | ctx context.Context 701 | stmt *driverStmt 702 | nextURI string 703 | id string 704 | 705 | err error 706 | rowindex int 707 | columns []rowsColumn 708 | data []queryData 709 | } 710 | 711 | var _ driver.Rows = &driverRows{} 712 | 713 | func (qr *driverRows) Close() error { 714 | if qr.nextURI != "" { 715 | hs := make(http.Header) 716 | hs.Add(prestoUserHeader, qr.stmt.user) 717 | req, err := qr.stmt.conn.newRequest("DELETE", qr.nextURI, nil, hs) 718 | if err != nil { 719 | return err 720 | } 721 | ctx, cancel := context.WithDeadline( 722 | context.Background(), 723 | time.Now().Add(DefaultCancelQueryTimeout), 724 | ) 725 | defer cancel() 726 | resp, err := qr.stmt.conn.roundTrip(ctx, req) 727 | if err != nil { 728 | qferr, ok := err.(*ErrQueryFailed) 729 | if ok && qferr.StatusCode == http.StatusNoContent { 730 | qr.nextURI = "" 731 | return nil 732 | } 733 | return err 734 | } 735 | resp.Body.Close() 736 | } 737 | return qr.err 738 | } 739 | 740 | func (qr *driverRows) Columns() []string { 741 | if qr.err != nil { 742 | return []string{} 743 | } 744 | if qr.columns == nil { 745 | if err := qr.fetch(false); err != nil { 746 | qr.err = err 747 | return []string{} 748 | } 749 | } 750 | res := make([]string, len(qr.columns)) 751 | for i, c := range qr.columns { 752 | res[i] = c.name 753 | } 754 | return res 755 | } 756 | 757 | var coltypeLengthSuffix = regexp.MustCompile(`\(\d+\)$`) 758 | 759 | func (qr *driverRows) ColumnTypeDatabaseTypeName(index int) string { 760 | name := qr.columns[index].dbType 761 | if m := 
coltypeLengthSuffix.FindStringSubmatch(name); m != nil { 762 | name = name[0 : len(name)-len(m[0])] 763 | } 764 | return name 765 | } 766 | 767 | func (qr *driverRows) Next(dest []driver.Value) error { 768 | if qr.err != nil { 769 | return qr.err 770 | } 771 | if qr.columns == nil || qr.rowindex >= len(qr.data) { 772 | if qr.nextURI == "" { 773 | qr.err = io.EOF 774 | return &EOF{QueryID: qr.id} 775 | } 776 | if err := qr.fetch(true); err != nil { 777 | qr.err = err 778 | if qr.err == io.EOF { 779 | return &EOF{QueryID: qr.id} 780 | } 781 | return qr.err 782 | } 783 | } 784 | if len(qr.columns) == 0 { 785 | qr.err = sql.ErrNoRows 786 | return qr.err 787 | } 788 | for i, v := range qr.columns { 789 | vv, err := v.vc.ConvertValue(qr.data[qr.rowindex][i]) 790 | if err != nil { 791 | qr.err = err 792 | return err 793 | } 794 | dest[i] = vv 795 | } 796 | qr.rowindex++ 797 | return nil 798 | } 799 | 800 | type queryResponse struct { 801 | ID string `json:"id"` 802 | InfoURI string `json:"infoUri"` 803 | PartialCancelURI string `json:"partialCancelUri"` 804 | NextURI string `json:"nextUri"` 805 | Columns []queryColumn `json:"columns"` 806 | Data []queryData `json:"data"` 807 | Stats stmtStats `json:"stats"` 808 | Error stmtError `json:"error"` 809 | } 810 | 811 | type queryColumn struct { 812 | Name string `json:"name"` 813 | Type string `json:"type"` 814 | TypeSignature typeSignature `json:"typeSignature"` 815 | } 816 | 817 | type queryData []interface{} 818 | 819 | type typeSignature struct { 820 | RawType string `json:"rawType"` 821 | TypeArguments []json.RawMessage `json:"typeArguments"` 822 | LiteralArguments []json.RawMessage `json:"literalArguments"` 823 | } 824 | 825 | type infoResponse struct { 826 | QueryID string `json:"queryId"` 827 | State string `json:"state"` 828 | } 829 | 830 | func handleResponseError(status int, respErr stmtError) error { 831 | switch respErr.ErrorName { 832 | case "": 833 | return nil 834 | case "USER_CANCELLED": 835 | return 
ErrQueryCancelled 836 | default: 837 | return &ErrQueryFailed{ 838 | StatusCode: status, 839 | Reason: &respErr, 840 | } 841 | } 842 | } 843 | 844 | func (qr *driverRows) fetch(allowEOF bool) error { 845 | hs := make(http.Header) 846 | hs.Add(prestoUserHeader, qr.stmt.user) 847 | req, err := qr.stmt.conn.newRequest("GET", qr.nextURI, nil, hs) 848 | if err != nil { 849 | return err 850 | } 851 | resp, err := qr.stmt.conn.roundTrip(qr.ctx, req) 852 | if err != nil { 853 | return err 854 | } 855 | defer resp.Body.Close() 856 | var qresp queryResponse 857 | d := json.NewDecoder(resp.Body) 858 | d.UseNumber() 859 | err = d.Decode(&qresp) 860 | if err != nil { 861 | return fmt.Errorf("presto: %v", err) 862 | } 863 | err = handleResponseError(resp.StatusCode, qresp.Error) 864 | if err != nil { 865 | return err 866 | } 867 | qr.rowindex = 0 868 | qr.data = qresp.Data 869 | qr.nextURI = qresp.NextURI 870 | if len(qr.data) == 0 { 871 | if qr.nextURI != "" { 872 | return qr.fetch(allowEOF) 873 | } 874 | if allowEOF { 875 | return io.EOF 876 | } 877 | } 878 | if qr.columns == nil && len(qresp.Columns) > 0 { 879 | return qr.initColumns(&qresp) 880 | } 881 | return nil 882 | } 883 | 884 | func (qr *driverRows) initColumns(resp *queryResponse) error { 885 | qr.columns = make([]rowsColumn, len(resp.Columns)) 886 | for i, col := range resp.Columns { 887 | vc, err := newComplexConverter(col.TypeSignature) 888 | if err != nil { 889 | return fmt.Errorf("presto: creating complex converter for %s: %w", col.Name, err) 890 | } 891 | qr.columns[i] = rowsColumn{ 892 | name: col.Name, 893 | dbType: col.Type, 894 | vc: vc, 895 | } 896 | } 897 | return nil 898 | } 899 | 900 | type typeConverter struct { 901 | typeName string 902 | parsedType []string // e.g. 
array, array, varchar, for [][]string 903 | } 904 | 905 | func newTypeConverter(typeName string) driver.ValueConverter { 906 | return &typeConverter{ 907 | typeName: typeName, 908 | parsedType: parseType(typeName), 909 | } 910 | } 911 | 912 | // parses presto types, e.g. array(varchar(10)) to "array", "varchar" 913 | // TODO: Use queryColumn.TypeSignature instead. 914 | func parseType(name string) []string { 915 | parts := strings.Split(name, "(") 916 | if len(parts) == 1 { 917 | return parts 918 | } 919 | last := len(parts) - 1 920 | parts[last] = strings.TrimRight(parts[last], ")") 921 | if len(parts[last]) > 0 { 922 | if _, err := strconv.Atoi(parts[last]); err == nil { 923 | parts = parts[:last] 924 | } 925 | } 926 | return parts 927 | } 928 | 929 | // ConvertValue implements the driver.ValueConverter interface. 930 | func (c *typeConverter) ConvertValue(v interface{}) (driver.Value, error) { 931 | switch strings.ToLower(c.parsedType[0]) { 932 | case "boolean": 933 | vv, err := scanNullBool(v) 934 | if !vv.Valid { 935 | return nil, err 936 | } 937 | return vv.Bool, err 938 | case "json", "char", "varchar", "varbinary", "interval year to month", "interval day to second", "decimal", "ipaddress", "unknown": 939 | vv, err := scanNullString(v) 940 | if !vv.Valid { 941 | return nil, err 942 | } 943 | return vv.String, err 944 | case "tinyint", "smallint", "integer", "bigint": 945 | vv, err := scanNullInt64(v) 946 | if !vv.Valid { 947 | return nil, err 948 | } 949 | return vv.Int64, err 950 | case "real", "double": 951 | vv, err := scanNullFloat64(v) 952 | if !vv.Valid { 953 | return nil, err 954 | } 955 | return vv.Float64, err 956 | case "date", "time", "time with time zone", "timestamp", "timestamp with time zone": 957 | vv, err := scanNullTime(v) 958 | if !vv.Valid { 959 | return nil, err 960 | } 961 | return vv.Time, err 962 | case "map": 963 | if err := validateMap(v); err != nil { 964 | return nil, err 965 | } 966 | return v, nil 967 | case "array": 968 | if 
err := validateSlice(v); err != nil { 969 | return nil, err 970 | } 971 | return v, nil 972 | default: 973 | return nil, fmt.Errorf("type not supported: %q", c.typeName) 974 | } 975 | } 976 | 977 | func validateMap(v interface{}) error { 978 | if v == nil { 979 | return nil 980 | } 981 | if _, ok := v.(map[string]interface{}); !ok { 982 | return fmt.Errorf("cannot convert %v (%T) to map", v, v) 983 | } 984 | return nil 985 | } 986 | 987 | func validateSlice(v interface{}) error { 988 | if v == nil { 989 | return nil 990 | } 991 | if _, ok := v.([]interface{}); !ok { 992 | return fmt.Errorf("cannot convert %v (%T) to slice", v, v) 993 | } 994 | return nil 995 | } 996 | 997 | func scanNullBool(v interface{}) (sql.NullBool, error) { 998 | if v == nil { 999 | return sql.NullBool{}, nil 1000 | } 1001 | vv, ok := v.(bool) 1002 | if !ok { 1003 | return sql.NullBool{}, 1004 | fmt.Errorf("cannot convert %v (%T) to bool", v, v) 1005 | } 1006 | return sql.NullBool{Valid: true, Bool: vv}, nil 1007 | } 1008 | 1009 | // NullSliceBool represents a slice of bool that may be null. 1010 | type NullSliceBool struct { 1011 | SliceBool []sql.NullBool 1012 | Valid bool 1013 | } 1014 | 1015 | // Scan implements the sql.Scanner interface. 1016 | func (s *NullSliceBool) Scan(value interface{}) error { 1017 | if value == nil { 1018 | return nil 1019 | } 1020 | vs, ok := value.([]interface{}) 1021 | if !ok { 1022 | return fmt.Errorf("presto: cannot convert %v (%T) to []bool", value, value) 1023 | } 1024 | slice := make([]sql.NullBool, len(vs)) 1025 | for i := range vs { 1026 | v, err := scanNullBool(vs[i]) 1027 | if err != nil { 1028 | return err 1029 | } 1030 | slice[i] = v 1031 | } 1032 | s.SliceBool = slice 1033 | s.Valid = true 1034 | return nil 1035 | } 1036 | 1037 | // NullSlice2Bool represents a two-dimensional slice of bool that may be null. 
1038 | type NullSlice2Bool struct { 1039 | Slice2Bool [][]sql.NullBool 1040 | Valid bool 1041 | } 1042 | 1043 | // Scan implements the sql.Scanner interface. 1044 | func (s *NullSlice2Bool) Scan(value interface{}) error { 1045 | if value == nil { 1046 | return nil 1047 | } 1048 | vs, ok := value.([]interface{}) 1049 | if !ok { 1050 | return fmt.Errorf("presto: cannot convert %v (%T) to [][]bool", value, value) 1051 | } 1052 | slice := make([][]sql.NullBool, len(vs)) 1053 | for i := range vs { 1054 | var ss NullSliceBool 1055 | if err := ss.Scan(vs[i]); err != nil { 1056 | return err 1057 | } 1058 | slice[i] = ss.SliceBool 1059 | } 1060 | s.Slice2Bool = slice 1061 | s.Valid = true 1062 | return nil 1063 | } 1064 | 1065 | // NullSlice3Bool implements a three-dimensional slice of bool that may be null. 1066 | type NullSlice3Bool struct { 1067 | Slice3Bool [][][]sql.NullBool 1068 | Valid bool 1069 | } 1070 | 1071 | // Scan implements the sql.Scanner interface. 1072 | func (s *NullSlice3Bool) Scan(value interface{}) error { 1073 | if value == nil { 1074 | return nil 1075 | } 1076 | vs, ok := value.([]interface{}) 1077 | if !ok { 1078 | return fmt.Errorf("presto: cannot convert %v (%T) to [][][]bool", value, value) 1079 | } 1080 | slice := make([][][]sql.NullBool, len(vs)) 1081 | for i := range vs { 1082 | var ss NullSlice2Bool 1083 | if err := ss.Scan(vs[i]); err != nil { 1084 | return err 1085 | } 1086 | slice[i] = ss.Slice2Bool 1087 | } 1088 | s.Slice3Bool = slice 1089 | s.Valid = true 1090 | return nil 1091 | } 1092 | 1093 | func scanNullString(v interface{}) (sql.NullString, error) { 1094 | if v == nil { 1095 | return sql.NullString{}, nil 1096 | } 1097 | vv, ok := v.(string) 1098 | if !ok { 1099 | return sql.NullString{}, 1100 | fmt.Errorf("cannot convert %v (%T) to string", v, v) 1101 | } 1102 | return sql.NullString{Valid: true, String: vv}, nil 1103 | } 1104 | 1105 | // NullSliceString represents a slice of string that may be null. 
1106 | type NullSliceString struct { 1107 | SliceString []sql.NullString 1108 | Valid bool 1109 | } 1110 | 1111 | // Scan implements the sql.Scanner interface. 1112 | func (s *NullSliceString) Scan(value interface{}) error { 1113 | if value == nil { 1114 | return nil 1115 | } 1116 | vs, ok := value.([]interface{}) 1117 | if !ok { 1118 | return fmt.Errorf("presto: cannot convert %v (%T) to []string", value, value) 1119 | } 1120 | slice := make([]sql.NullString, len(vs)) 1121 | for i := range vs { 1122 | v, err := scanNullString(vs[i]) 1123 | if err != nil { 1124 | return err 1125 | } 1126 | slice[i] = v 1127 | } 1128 | s.SliceString = slice 1129 | s.Valid = true 1130 | return nil 1131 | } 1132 | 1133 | // NullSlice2String represents a two-dimensional slice of string that may be null. 1134 | type NullSlice2String struct { 1135 | Slice2String [][]sql.NullString 1136 | Valid bool 1137 | } 1138 | 1139 | // Scan implements the sql.Scanner interface. 1140 | func (s *NullSlice2String) Scan(value interface{}) error { 1141 | if value == nil { 1142 | return nil 1143 | } 1144 | vs, ok := value.([]interface{}) 1145 | if !ok { 1146 | return fmt.Errorf("presto: cannot convert %v (%T) to [][]string", value, value) 1147 | } 1148 | slice := make([][]sql.NullString, len(vs)) 1149 | for i := range vs { 1150 | var ss NullSliceString 1151 | if err := ss.Scan(vs[i]); err != nil { 1152 | return err 1153 | } 1154 | slice[i] = ss.SliceString 1155 | } 1156 | s.Slice2String = slice 1157 | s.Valid = true 1158 | return nil 1159 | } 1160 | 1161 | // NullSlice3String implements a three-dimensional slice of string that may be null. 1162 | type NullSlice3String struct { 1163 | Slice3String [][][]sql.NullString 1164 | Valid bool 1165 | } 1166 | 1167 | // Scan implements the sql.Scanner interface. 
1168 | func (s *NullSlice3String) Scan(value interface{}) error { 1169 | if value == nil { 1170 | return nil 1171 | } 1172 | vs, ok := value.([]interface{}) 1173 | if !ok { 1174 | return fmt.Errorf("presto: cannot convert %v (%T) to [][][]string", value, value) 1175 | } 1176 | slice := make([][][]sql.NullString, len(vs)) 1177 | for i := range vs { 1178 | var ss NullSlice2String 1179 | if err := ss.Scan(vs[i]); err != nil { 1180 | return err 1181 | } 1182 | slice[i] = ss.Slice2String 1183 | } 1184 | s.Slice3String = slice 1185 | s.Valid = true 1186 | return nil 1187 | } 1188 | 1189 | func scanNullInt64(v interface{}) (sql.NullInt64, error) { 1190 | if v == nil { 1191 | return sql.NullInt64{}, nil 1192 | } 1193 | vNumber, ok := v.(json.Number) 1194 | if !ok { 1195 | return sql.NullInt64{}, 1196 | fmt.Errorf("cannot convert %v (%T) to int64", v, v) 1197 | } 1198 | vv, err := vNumber.Int64() 1199 | if err != nil { 1200 | return sql.NullInt64{}, 1201 | fmt.Errorf("cannot convert %v (%T) to int64", v, v) 1202 | } 1203 | return sql.NullInt64{Valid: true, Int64: vv}, nil 1204 | } 1205 | 1206 | // NullSliceInt64 represents a slice of int64 that may be null. 1207 | type NullSliceInt64 struct { 1208 | SliceInt64 []sql.NullInt64 1209 | Valid bool 1210 | } 1211 | 1212 | // Scan implements the sql.Scanner interface. 1213 | func (s *NullSliceInt64) Scan(value interface{}) error { 1214 | if value == nil { 1215 | return nil 1216 | } 1217 | vs, ok := value.([]interface{}) 1218 | if !ok { 1219 | return fmt.Errorf("presto: cannot convert %v (%T) to []int64", value, value) 1220 | } 1221 | slice := make([]sql.NullInt64, len(vs)) 1222 | for i := range vs { 1223 | v, err := scanNullInt64(vs[i]) 1224 | if err != nil { 1225 | return err 1226 | } 1227 | slice[i] = v 1228 | } 1229 | s.SliceInt64 = slice 1230 | s.Valid = true 1231 | return nil 1232 | } 1233 | 1234 | // NullSlice2Int64 represents a two-dimensional slice of int64 that may be null. 
1235 | type NullSlice2Int64 struct { 1236 | Slice2Int64 [][]sql.NullInt64 1237 | Valid bool 1238 | } 1239 | 1240 | // Scan implements the sql.Scanner interface. 1241 | func (s *NullSlice2Int64) Scan(value interface{}) error { 1242 | if value == nil { 1243 | return nil 1244 | } 1245 | vs, ok := value.([]interface{}) 1246 | if !ok { 1247 | return fmt.Errorf("presto: cannot convert %v (%T) to [][]int64", value, value) 1248 | } 1249 | slice := make([][]sql.NullInt64, len(vs)) 1250 | for i := range vs { 1251 | var ss NullSliceInt64 1252 | if err := ss.Scan(vs[i]); err != nil { 1253 | return err 1254 | } 1255 | slice[i] = ss.SliceInt64 1256 | } 1257 | s.Slice2Int64 = slice 1258 | s.Valid = true 1259 | return nil 1260 | } 1261 | 1262 | // NullSlice3Int64 implements a three-dimensional slice of int64 that may be null. 1263 | type NullSlice3Int64 struct { 1264 | Slice3Int64 [][][]sql.NullInt64 1265 | Valid bool 1266 | } 1267 | 1268 | // Scan implements the sql.Scanner interface. 1269 | func (s *NullSlice3Int64) Scan(value interface{}) error { 1270 | if value == nil { 1271 | return nil 1272 | } 1273 | vs, ok := value.([]interface{}) 1274 | if !ok { 1275 | return fmt.Errorf("presto: cannot convert %v (%T) to [][][]int64", value, value) 1276 | } 1277 | slice := make([][][]sql.NullInt64, len(vs)) 1278 | for i := range vs { 1279 | var ss NullSlice2Int64 1280 | if err := ss.Scan(vs[i]); err != nil { 1281 | return err 1282 | } 1283 | slice[i] = ss.Slice2Int64 1284 | } 1285 | s.Slice3Int64 = slice 1286 | s.Valid = true 1287 | return nil 1288 | } 1289 | 1290 | func scanNullFloat64(v interface{}) (sql.NullFloat64, error) { 1291 | if v == nil { 1292 | return sql.NullFloat64{}, nil 1293 | } 1294 | vNumber, ok := v.(json.Number) 1295 | if ok { 1296 | vFloat, err := vNumber.Float64() 1297 | if err != nil { 1298 | return sql.NullFloat64{}, fmt.Errorf("cannot convert %v (%T) to float64", vNumber, vNumber) 1299 | } 1300 | return sql.NullFloat64{Valid: true, Float64: vFloat}, nil 1301 | } 
1302 | switch v { 1303 | case "NaN": 1304 | return sql.NullFloat64{Valid: true, Float64: math.NaN()}, nil 1305 | case "Infinity": 1306 | return sql.NullFloat64{Valid: true, Float64: math.Inf(+1)}, nil 1307 | case "-Infinity": 1308 | return sql.NullFloat64{Valid: true, Float64: math.Inf(-1)}, nil 1309 | default: 1310 | return sql.NullFloat64{}, fmt.Errorf("cannot convert %v (%T) to float64", v, v) 1311 | } 1312 | } 1313 | 1314 | // NullSliceFloat64 represents a slice of float64 that may be null. 1315 | type NullSliceFloat64 struct { 1316 | SliceFloat64 []sql.NullFloat64 1317 | Valid bool 1318 | } 1319 | 1320 | // Scan implements the sql.Scanner interface. 1321 | func (s *NullSliceFloat64) Scan(value interface{}) error { 1322 | if value == nil { 1323 | return nil 1324 | } 1325 | vs, ok := value.([]interface{}) 1326 | if !ok { 1327 | return fmt.Errorf("presto: cannot convert %v (%T) to []float64", value, value) 1328 | } 1329 | slice := make([]sql.NullFloat64, len(vs)) 1330 | for i := range vs { 1331 | v, err := scanNullFloat64(vs[i]) 1332 | if err != nil { 1333 | return err 1334 | } 1335 | slice[i] = v 1336 | } 1337 | s.SliceFloat64 = slice 1338 | s.Valid = true 1339 | return nil 1340 | } 1341 | 1342 | // NullSlice2Float64 represents a two-dimensional slice of float64 that may be null. 1343 | type NullSlice2Float64 struct { 1344 | Slice2Float64 [][]sql.NullFloat64 1345 | Valid bool 1346 | } 1347 | 1348 | // Scan implements the sql.Scanner interface. 
1349 | func (s *NullSlice2Float64) Scan(value interface{}) error { 1350 | if value == nil { 1351 | return nil 1352 | } 1353 | vs, ok := value.([]interface{}) 1354 | if !ok { 1355 | return fmt.Errorf("presto: cannot convert %v (%T) to [][]float64", value, value) 1356 | } 1357 | slice := make([][]sql.NullFloat64, len(vs)) 1358 | for i := range vs { 1359 | var ss NullSliceFloat64 1360 | if err := ss.Scan(vs[i]); err != nil { 1361 | return err 1362 | } 1363 | slice[i] = ss.SliceFloat64 1364 | } 1365 | s.Slice2Float64 = slice 1366 | s.Valid = true 1367 | return nil 1368 | } 1369 | 1370 | // NullSlice3Float64 represents a three-dimensional slice of float64 that may be null. 1371 | type NullSlice3Float64 struct { 1372 | Slice3Float64 [][][]sql.NullFloat64 1373 | Valid bool 1374 | } 1375 | 1376 | // Scan implements the sql.Scanner interface. 1377 | func (s *NullSlice3Float64) Scan(value interface{}) error { 1378 | if value == nil { 1379 | return nil 1380 | } 1381 | vs, ok := value.([]interface{}) 1382 | if !ok { 1383 | return fmt.Errorf("presto: cannot convert %v (%T) to [][][]float64", value, value) 1384 | } 1385 | slice := make([][][]sql.NullFloat64, len(vs)) 1386 | for i := range vs { 1387 | var ss NullSlice2Float64 1388 | if err := ss.Scan(vs[i]); err != nil { 1389 | return err 1390 | } 1391 | slice[i] = ss.Slice2Float64 1392 | } 1393 | s.Slice3Float64 = slice 1394 | s.Valid = true 1395 | return nil 1396 | } 1397 | 1398 | var timeLayouts = []string{ 1399 | "2006-01-02", 1400 | "15:04:05.000", 1401 | "2006-01-02 15:04:05.000", 1402 | } 1403 | 1404 | func scanNullTime(v interface{}) (NullTime, error) { 1405 | if v == nil { 1406 | return NullTime{}, nil 1407 | } 1408 | vv, ok := v.(string) 1409 | if !ok { 1410 | return NullTime{}, fmt.Errorf("cannot convert %v (%T) to time string", v, v) 1411 | } 1412 | vparts := strings.Split(vv, " ") 1413 | if len(vparts) > 1 && !unicode.IsDigit(rune(vparts[len(vparts)-1][0])) { 1414 | return parseNullTimeWithLocation(vv) 1415 | } 1416 
| return parseNullTime(vv) 1417 | } 1418 | 1419 | func parseNullTime(v string) (NullTime, error) { 1420 | var t time.Time 1421 | var err error 1422 | for _, layout := range timeLayouts { 1423 | t, err = time.ParseInLocation(layout, v, time.Local) 1424 | if err == nil { 1425 | return NullTime{Valid: true, Time: t}, nil 1426 | } 1427 | } 1428 | return NullTime{}, err 1429 | } 1430 | 1431 | func parseNullTimeWithLocation(v string) (NullTime, error) { 1432 | idx := strings.LastIndex(v, " ") 1433 | if idx == -1 { 1434 | return NullTime{}, fmt.Errorf("cannot convert %v (%T) to time+zone", v, v) 1435 | } 1436 | stamp, location := v[:idx], v[idx+1:] 1437 | loc, err := time.LoadLocation(location) 1438 | if err != nil { 1439 | return NullTime{}, fmt.Errorf("cannot load timezone %q: %v", location, err) 1440 | } 1441 | var t time.Time 1442 | for _, layout := range timeLayouts { 1443 | t, err = time.ParseInLocation(layout, stamp, loc) 1444 | if err == nil { 1445 | return NullTime{Valid: true, Time: t}, nil 1446 | } 1447 | } 1448 | return NullTime{}, err 1449 | } 1450 | 1451 | // NullTime represents a time.Time value that can be null. 1452 | // The NullTime supports presto's Date, Time and Timestamp data types, 1453 | // with or without time zone. 1454 | type NullTime struct { 1455 | Time time.Time 1456 | Valid bool 1457 | } 1458 | 1459 | // Scan implements the sql.Scanner interface. 1460 | func (s *NullTime) Scan(value interface{}) error { 1461 | switch value.(type) { 1462 | case time.Time: 1463 | s.Time, s.Valid = value.(time.Time) 1464 | case NullTime: 1465 | *s = value.(NullTime) 1466 | } 1467 | return nil 1468 | } 1469 | 1470 | // NullSliceTime represents a slice of time.Time that may be null. 1471 | type NullSliceTime struct { 1472 | SliceTime []NullTime 1473 | Valid bool 1474 | } 1475 | 1476 | // Scan implements the sql.Scanner interface. 
1477 | func (s *NullSliceTime) Scan(value interface{}) error { 1478 | if value == nil { 1479 | return nil 1480 | } 1481 | vs, ok := value.([]interface{}) 1482 | if !ok { 1483 | return fmt.Errorf("presto: cannot convert %v (%T) to []time.Time", value, value) 1484 | } 1485 | slice := make([]NullTime, len(vs)) 1486 | for i := range vs { 1487 | v, err := scanNullTime(vs[i]) 1488 | if err != nil { 1489 | return err 1490 | } 1491 | slice[i] = v 1492 | } 1493 | s.SliceTime = slice 1494 | s.Valid = true 1495 | return nil 1496 | } 1497 | 1498 | // NullSlice2Time represents a two-dimensional slice of time.Time that may be null. 1499 | type NullSlice2Time struct { 1500 | Slice2Time [][]NullTime 1501 | Valid bool 1502 | } 1503 | 1504 | // Scan implements the sql.Scanner interface. 1505 | func (s *NullSlice2Time) Scan(value interface{}) error { 1506 | if value == nil { 1507 | return nil 1508 | } 1509 | vs, ok := value.([]interface{}) 1510 | if !ok { 1511 | return fmt.Errorf("presto: cannot convert %v (%T) to [][]time.Time", value, value) 1512 | } 1513 | slice := make([][]NullTime, len(vs)) 1514 | for i := range vs { 1515 | var ss NullSliceTime 1516 | if err := ss.Scan(vs[i]); err != nil { 1517 | return err 1518 | } 1519 | slice[i] = ss.SliceTime 1520 | } 1521 | s.Slice2Time = slice 1522 | s.Valid = true 1523 | return nil 1524 | } 1525 | 1526 | // NullSlice3Time represents a three-dimensional slice of time.Time that may be null. 1527 | type NullSlice3Time struct { 1528 | Slice3Time [][][]NullTime 1529 | Valid bool 1530 | } 1531 | 1532 | // Scan implements the sql.Scanner interface. 
1533 | func (s *NullSlice3Time) Scan(value interface{}) error { 1534 | if value == nil { 1535 | return nil 1536 | } 1537 | vs, ok := value.([]interface{}) 1538 | if !ok { 1539 | return fmt.Errorf("presto: cannot convert %v (%T) to [][][]time.Time", value, value) 1540 | } 1541 | slice := make([][][]NullTime, len(vs)) 1542 | for i := range vs { 1543 | var ss NullSlice2Time 1544 | if err := ss.Scan(vs[i]); err != nil { 1545 | return err 1546 | } 1547 | slice[i] = ss.Slice2Time 1548 | } 1549 | s.Slice3Time = slice 1550 | s.Valid = true 1551 | return nil 1552 | } 1553 | 1554 | // NullMap represents a map type that may be null. 1555 | type NullMap struct { 1556 | Map map[string]interface{} 1557 | Valid bool 1558 | } 1559 | 1560 | // Scan implements the sql.Scanner interface. 1561 | func (m *NullMap) Scan(v interface{}) error { 1562 | if v == nil { 1563 | return nil 1564 | } 1565 | m.Map, m.Valid = v.(map[string]interface{}) 1566 | return nil 1567 | } 1568 | 1569 | // NullSliceMap represents a slice of NullMap that may be null. 1570 | type NullSliceMap struct { 1571 | SliceMap []NullMap 1572 | Valid bool 1573 | } 1574 | 1575 | // Scan implements the sql.Scanner interface. 1576 | func (s *NullSliceMap) Scan(value interface{}) error { 1577 | if value == nil { 1578 | return nil 1579 | } 1580 | vs, ok := value.([]interface{}) 1581 | if !ok { 1582 | return fmt.Errorf("presto: cannot convert %v (%T) to []NullMap", value, value) 1583 | } 1584 | slice := make([]NullMap, len(vs)) 1585 | for i := range vs { 1586 | if err := validateMap(vs[i]); err != nil { 1587 | return fmt.Errorf("cannot convert %v (%T) to []NullMap", value, value) 1588 | } 1589 | m := NullMap{} 1590 | m.Scan(vs[i]) 1591 | slice[i] = m 1592 | } 1593 | s.SliceMap = slice 1594 | s.Valid = true 1595 | return nil 1596 | } 1597 | 1598 | // NullSlice2Map represents a two-dimensional slice of NullMap that may be null. 
1599 | type NullSlice2Map struct { 1600 | Slice2Map [][]NullMap 1601 | Valid bool 1602 | } 1603 | 1604 | // Scan implements the sql.Scanner interface. 1605 | func (s *NullSlice2Map) Scan(value interface{}) error { 1606 | if value == nil { 1607 | return nil 1608 | } 1609 | vs, ok := value.([]interface{}) 1610 | if !ok { 1611 | return fmt.Errorf("presto: cannot convert %v (%T) to [][]NullMap", value, value) 1612 | } 1613 | slice := make([][]NullMap, len(vs)) 1614 | for i := range vs { 1615 | var ss NullSliceMap 1616 | if err := ss.Scan(vs[i]); err != nil { 1617 | return err 1618 | } 1619 | slice[i] = ss.SliceMap 1620 | } 1621 | s.Slice2Map = slice 1622 | s.Valid = true 1623 | return nil 1624 | } 1625 | 1626 | // NullSlice3Map represents a three-dimensional slice of NullMap that may be null. 1627 | type NullSlice3Map struct { 1628 | Slice3Map [][][]NullMap 1629 | Valid bool 1630 | } 1631 | 1632 | // Scan implements the sql.Scanner interface. 1633 | func (s *NullSlice3Map) Scan(value interface{}) error { 1634 | if value == nil { 1635 | return nil 1636 | } 1637 | vs, ok := value.([]interface{}) 1638 | if !ok { 1639 | return fmt.Errorf("presto: cannot convert %v (%T) to [][][]NullMap", value, value) 1640 | } 1641 | slice := make([][][]NullMap, len(vs)) 1642 | for i := range vs { 1643 | var ss NullSlice2Map 1644 | if err := ss.Scan(vs[i]); err != nil { 1645 | return err 1646 | } 1647 | slice[i] = ss.Slice2Map 1648 | } 1649 | s.Slice3Map = slice 1650 | s.Valid = true 1651 | return nil 1652 | } 1653 | --------------------------------------------------------------------------------