├── changelog.tmpl ├── smoketest ├── drill-override.conf ├── setup-and-run.sh ├── Dockerfile ├── drill-override-kerberos.conf ├── docker-compose.yml └── storage-plugins-override.conf ├── .release.yml ├── .gitignore ├── internal ├── data │ ├── vector.go │ ├── type_traits_numeric.gen_test.go.tmpl │ ├── numeric_vec_typemap.gen.go.tmpl │ ├── numeric_vec_typemap.gen_test.go.tmpl │ ├── type_traits_numeric.gen.go.tmpl │ ├── numeric.tmpldata │ ├── arrow_numeric.gen_test.go.tmpl │ ├── vector_numeric.gen.go.tmpl │ ├── vector_numeric.gen_test.go.tmpl │ ├── numeric_vec_typemap.gen.go │ ├── decimal_utils.go │ ├── arrow.go │ ├── type_traits_numeric.gen_test.go │ ├── date_time_vectors.go │ ├── data_vector.go │ └── numeric_vec_typemap.gen_test.go ├── log │ └── log.go ├── cmd │ ├── tmpl │ │ └── main.go │ └── drillProto │ │ └── main.go └── rpc │ └── proto │ └── exec │ ├── SchemaDef.pb.go │ └── bit │ └── ExecutionProtos.pb.go ├── go.mod ├── LICENSE ├── .github └── workflows │ ├── go.yml │ ├── release.yml │ └── smoketest.yml ├── doc.go ├── sasl ├── krb_client.go ├── gssapi_test.go ├── sasl.go └── gssapi.go ├── driver ├── rows.go ├── prepared.go ├── connector.go ├── conn.go ├── connector_test.go ├── prepared_test.go ├── conn_test.go └── rows_test.go ├── options.go ├── smoke_test.go ├── kerberos_smoke_test.go ├── zk_handler.go ├── zk_handler_test.go ├── README.md ├── meta_requests.go ├── auth.go ├── utils.go ├── utils_test.go └── auth_test.go /changelog.tmpl: -------------------------------------------------------------------------------- 1 | {{ .Commits -}} 2 | -------------------------------------------------------------------------------- /smoketest/drill-override.conf: -------------------------------------------------------------------------------- 1 | drill.exec: { 2 | cluster-id: "drillbits1", 3 | zk.connect: "localhost:2181,localhost:2182,localhost:2183" 4 | } 5 | -------------------------------------------------------------------------------- /.release.yml: -------------------------------------------------------------------------------- 1 | branch: 2 | master: release 3 | release: 'github' 4 | github: 5 | repo: "go-drill" 6 | user: "factset" 7 | commitFormat: angular 8 | changelog: 9 | printAll: true 10 | templatePath: "./changelog.tmpl" 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | 17 | protobuf/*.proto 18 | internal/cmd/tester 19 | .vscode 20 | .version 21 | -------------------------------------------------------------------------------- /internal/data/vector.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | type vector struct { 4 | rawData []byte 5 | } 6 | 7 | func (v *vector) GetRawBytes() []byte { 8 | return v.rawData 9 | } 10 | 11 | type nullByteMap struct { 12 | byteMap []byte 13 | } 14 | 15 | func (n *nullByteMap) IsNull(index uint) bool { 16 | return n.byteMap[index] == 0 17 | } 18 | 19 | func (n *nullByteMap) GetNullBytemap() []byte { 20 | return n.byteMap 21 | } 22 | -------------------------------------------------------------------------------- 
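The vector and nullByteMap types above are the building blocks for Drill's nullable value vectors: a nullable vector's raw buffer carries one status byte per value up front (0 marks NULL), followed by the packed value data, which is how the Nullable*Vector constructors later in internal/data slice their input. A minimal standalone sketch of that layout, assuming little-endian int32 values (splitNullable is an illustrative helper for this note, not part of the package API):

package main

import (
	"encoding/binary"
	"fmt"
)

// splitNullable mirrors the layout used by the Nullable*Vector constructors:
// the first valueCount bytes are the null bytemap (0 == NULL), and the
// remainder is the packed value data.
func splitNullable(raw []byte, valueCount int) (bytemap, values []byte) {
	return raw[:valueCount], raw[valueCount:]
}

func main() {
	// two int32 values, the first marked NULL
	raw := []byte{0, 1, 0, 0, 0, 0, 42, 0, 0, 0}
	bytemap, values := splitNullable(raw, 2)
	for i, b := range bytemap {
		if b == 0 {
			fmt.Printf("row %d: NULL\n", i)
			continue
		}
		fmt.Printf("row %d: %d\n", i, int32(binary.LittleEndian.Uint32(values[i*4:])))
	}
}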
/smoketest/setup-and-run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cat << EOF > /opt/drill/conf/drill-override.conf 4 | drill.exec: { 5 | cluster-id: "drillbits1", 6 | zk.connect: "$DRILL_ZK_CLUSTER" 7 | } 8 | EOF 9 | 10 | cleanup() { 11 | /opt/drill/bin/drillbit.sh graceful_stop 12 | } 13 | 14 | trap 'cleanup' SIGTERM 15 | 16 | /opt/drill/bin/drillbit.sh --config /opt/drill/conf start 17 | 18 | tail -f /opt/drill/log/drillbit.log 19 | -------------------------------------------------------------------------------- /smoketest/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM openjdk:8u232-jdk 2 | 3 | RUN mkdir /opt/drill 4 | 5 | RUN wget -O - http://apache.mirrors.hoobly.com/drill/drill-1.17.0/apache-drill-1.17.0.tar.gz | tar -xzC /opt 6 | 7 | RUN mv /opt/apache-drill-1.17.0/* /opt/drill/ 8 | 9 | WORKDIR /opt/drill 10 | 11 | COPY ./setup-and-run.sh /opt/drill/setup-and-run.sh 12 | COPY ./storage-plugins-override.conf /opt/drill/conf/storage-plugins-override.conf 13 | RUN chmod +x /opt/drill/setup-and-run.sh 14 | 15 | ENTRYPOINT /opt/drill/setup-and-run.sh 16 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/factset/go-drill 2 | 3 | go 1.15 4 | 5 | require ( 6 | github.com/apache/arrow/go/v7 v7.0.0 7 | github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815 8 | github.com/go-zookeeper/zk v1.0.2 9 | github.com/golang/protobuf v1.5.2 10 | github.com/google/go-github/v32 v32.1.0 11 | github.com/jcmturner/gofork v1.0.0 12 | github.com/jcmturner/gokrb5/v8 v8.4.1 13 | github.com/rs/zerolog v1.21.0 14 | github.com/stretchr/testify v1.7.0 15 | golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a // indirect 16 | google.golang.org/protobuf v1.27.1 17 | ) 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2020 FactSet Research Systems, Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License.
14 | -------------------------------------------------------------------------------- /smoketest/drill-override-kerberos.conf: -------------------------------------------------------------------------------- 1 | drill.exec: { 2 | cluster-id: "drillbits1", 3 | zk.connect: "localhost:2181,localhost:2182,localhost:2183", 4 | impersonation: { 5 | enabled: true, 6 | max_chained_user_hops: 3 7 | } 8 | security.auth: { 9 | mechanisms: ["KERBEROS", "PLAIN"], 10 | principal: "drill/_host@EXAMPLE.COM", 11 | keytab: "/tmp/drill.keytab" 12 | } 13 | security.user: { 14 | auth.enabled: true, 15 | auth.packages += "org.apache.drill.exec.rpc.user.security", 16 | auth.impl: "pam4j", 17 | auth.pam_profiles: ["sudo", "login"], 18 | encryption.sasl.enabled: true, 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /internal/log/log.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "os" 5 | "time" 6 | 7 | "github.com/rs/zerolog" 8 | "github.com/rs/zerolog/log" 9 | ) 10 | 11 | var Logger zerolog.Logger 12 | 13 | func init() { 14 | loglevel := os.Getenv("GO_DRILL_LOG_LEVEL") 15 | lvl, err := zerolog.ParseLevel(loglevel) 16 | if err != nil { 17 | log.Printf("invalid value '%s' given for GO_DRILL_LOG_LEVEL. ignoring", loglevel) 18 | } 19 | 20 | Logger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC822}).Level(lvl).With().Timestamp().Logger() 21 | } 22 | 23 | func Printf(format string, v ...interface{}) { 24 | Logger.Printf(format, v...) 25 | } 26 | 27 | func Print(v ...interface{}) { 28 | Logger.Print(v...) 29 | } 30 | -------------------------------------------------------------------------------- /internal/data/type_traits_numeric.gen_test.go.tmpl: -------------------------------------------------------------------------------- 1 | package data_test 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/factset/go-drill/internal/data" 8 | ) 9 | 10 | {{- range .In}} 11 | 12 | func Test{{.Name}}Traits(t *testing.T) { 13 | const N = 10 14 | b1 := data.{{.Name}}Traits.CastToBytes([]{{.Type}}{ 15 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16 | }) 17 | 18 | v1 := data.{{.Name}}Traits.CastFromBytes(b1) 19 | for i, v := range v1 { 20 | if got, want := v, {{.Type}}(i); got != want { 21 | t.Fatalf("invalid value[%d]. got=%v, want=%v", i, got, want) 22 | } 23 | } 24 | 25 | v2 := make([]{{.Type}}, N) 26 | data.{{.Name}}Traits.Copy(v2, v1) 27 | 28 | if !reflect.DeepEqual(v1, v2) { 29 | t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) 30 | } 31 | } 32 | {{end}} 33 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | name: Build 12 | runs-on: ubuntu-latest 13 | steps: 14 | 15 | - name: Set up Go 1.x 16 | uses: actions/setup-go@v2 17 | with: 18 | go-version: ^1.16 19 | id: go 20 | 21 | - name: Check out code into the Go module directory 22 | uses: actions/checkout@v2 23 | 24 | - name: Get dependencies 25 | run: go mod download 26 | 27 | - name: Build 28 | run: go build -v . 29 | 30 | - name: Test 31 | run: | 32 | go install github.com/ory/go-acc@latest 33 | go-acc -o coverage.out ./... 
-- -race -v 34 | 35 | - name: Codecov 36 | uses: codecov/codecov-action@v1.0.12 37 | with: 38 | file: coverage.out 39 | flags: unittests 40 | -------------------------------------------------------------------------------- /internal/data/numeric_vec_typemap.gen.go.tmpl: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import ( 4 | "github.com/factset/go-drill/internal/rpc/proto/common" 5 | "github.com/factset/go-drill/internal/rpc/proto/exec/shared" 6 | ) 7 | 8 | func NewNumericValueVec(rawData []byte, meta *shared.SerializedField) DataVector { 9 | if meta.GetMajorType().GetMode() == common.DataMode_OPTIONAL { 10 | switch meta.GetMajorType().GetMinorType() { 11 | {{range .In}} 12 | {{ if .ProtoType -}} 13 | case common.MinorType_{{.ProtoType}}: 14 | return NewNullable{{.Name}}Vector(rawData, meta) 15 | {{end}} 16 | {{end}} 17 | default: 18 | return nil 19 | } 20 | } 21 | 22 | switch meta.GetMajorType().GetMinorType() { 23 | {{range .In}} 24 | {{ if .ProtoType -}} 25 | case common.MinorType_{{.ProtoType}}: 26 | return New{{.Name}}Vector(rawData, meta) 27 | {{end}} 28 | {{end}} 29 | default: 30 | return nil 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package drill is a highly efficient Pure Go client and driver for Apache Drill 2 | // and Dremio. 3 | // 4 | // A driver for the database/sql package is also provided via the driver subpackage. 5 | // This can be used like so: 6 | // 7 | // import ( 8 | // "database/sql" 9 | // 10 | // _ "github.com/factset/go-drill/driver" 11 | // ) 12 | // 13 | // func main() { 14 | // props := []string{ 15 | // "zk=zookeeper1,zookeeper2,zookeeper3", 16 | // "auth=kerberos", 17 | // "service=", 18 | // "cluster=", 19 | // } 20 | // 21 | // db, err := sql.Open("drill", strings.Join(props, ";")) 22 | // } 23 | // 24 | // Also, currently logging of the internals can be turned on via the environment 25 | // variable GO_DRILL_LOG_LEVEL. This uses github.com/rs/zerolog to do the logging 26 | // so anything that is valid to pass to the zerolog.ParseLevel function is valid 27 | // as a value for the environment variable. 28 | package drill 29 | -------------------------------------------------------------------------------- /smoketest/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | services: 3 | zoo1: 4 | image: zookeeper 5 | restart: always 6 | hostname: zoo1 7 | ports: 8 | - 2181:2181 9 | environment: 10 | ZOO_MY_ID: 1 11 | ZOO_SERVERS: server.1=0.0.0.0:2888:3888;2181 server.2=zoo2:2888:3888;2181 server.3=zoo3:2888:3888;2181 12 | 13 | zoo2: 14 | image: zookeeper 15 | restart: always 16 | hostname: zoo2 17 | ports: 18 | - 2182:2181 19 | environment: 20 | ZOO_MY_ID: 2 21 | ZOO_SERVERS: server.1=zoo1:2888:3888;2181 server.2=0.0.0.0:2888:3888;2181 server.3=zoo3:2888:3888;2181 22 | 23 | zoo3: 24 | image: zookeeper 25 | restart: always 26 | hostname: zoo3 27 | ports: 28 | - 2183:2181 29 | environment: 30 | ZOO_MY_ID: 3 31 | ZOO_SERVERS: server.1=zoo1:2888:3888;2181 server.2=zoo2:2888:3888;2181 server.3=0.0.0.0:2888:3888;2181 32 | 33 | drill: 34 | build: . 
35 | image: drill:test 36 | depends_on: 37 | - zoo1 38 | - zoo2 39 | - zoo3 40 | restart: always 41 | hostname: drill 42 | ports: 43 | - 8047:8047 44 | - 31010:31010 45 | environment: 46 | DRILL_ZK_CLUSTER: zoo1:2181,zoo2:2181,zoo3:2181 47 | -------------------------------------------------------------------------------- /sasl/krb_client.go: -------------------------------------------------------------------------------- 1 | package sasl 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "os/user" 7 | "strings" 8 | 9 | "github.com/jcmturner/gokrb5/v8/client" 10 | "github.com/jcmturner/gokrb5/v8/config" 11 | "github.com/jcmturner/gokrb5/v8/credentials" 12 | ) 13 | 14 | // getKrbClient is an internal helper for checking the KRB5_CONFIG and KRB5CCNAME 15 | // environment variables in order to get the current kerberos cached ticket 16 | // 17 | // this does not currently support performing authentication itself and assumes you 18 | // already have cached credentials 19 | func getKrbClient(principal string) (*client.Client, error) { 20 | configPath := os.Getenv("KRB5_CONFIG") 21 | if configPath == "" { 22 | configPath = "/etc/krb5.conf" 23 | } 24 | 25 | cfg, err := config.Load(configPath) 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | ccachePath := os.Getenv("KRB5CCNAME") 31 | if strings.Contains(ccachePath, ":") { 32 | if strings.HasPrefix(ccachePath, "FILE:") { 33 | ccachePath = strings.SplitN(ccachePath, ":", 2)[1] 34 | } else { 35 | return nil, fmt.Errorf("unusable cache: %s", ccachePath) 36 | } 37 | } else if ccachePath == "" { 38 | u, err := user.Current() 39 | if err != nil { 40 | return nil, err 41 | } 42 | ccachePath = fmt.Sprintf("/tmp/krb5cc_%s", u.Uid) 43 | } 44 | 45 | ccache, _ := credentials.LoadCCache(ccachePath) 46 | 47 | return client.NewFromCCache(ccache, cfg) 48 | } 49 | -------------------------------------------------------------------------------- /internal/data/numeric_vec_typemap.gen_test.go.tmpl: -------------------------------------------------------------------------------- 1 | package data_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/factset/go-drill/internal/data" 7 | "github.com/factset/go-drill/internal/rpc/proto/common" 8 | "github.com/factset/go-drill/internal/rpc/proto/exec/shared" 9 | "github.com/stretchr/testify/assert" 10 | "google.golang.org/protobuf/proto" 11 | ) 12 | 13 | {{- range .In}} 14 | {{ if .ProtoType -}} 15 | func TestNewNumericVecRequired{{.Name}}(t *testing.T) { 16 | const N = 10 17 | b := data.{{.Name}}Traits.CastToBytes([]{{.Type}}{ 18 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 19 | }) 20 | 21 | meta := &shared.SerializedField{ 22 | MajorType: &common.MajorType{ 23 | MinorType: common.MinorType_{{.ProtoType}}.Enum(), 24 | Mode: common.DataMode_REQUIRED.Enum(), 25 | }, 26 | } 27 | 28 | dv := data.NewValueVec(b, meta) 29 | assert.IsType(t, (*data.{{.Name}}Vector)(nil), dv) 30 | } 31 | 32 | func TestNewNumericVecOptional{{.Name}}(t *testing.T) { 33 | const N = 10 34 | b := data.{{.Name}}Traits.CastToBytes([]{{.Type}}{ 35 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 36 | }) 37 | 38 | bytemap := []byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1} 39 | 40 | meta := &shared.SerializedField{ 41 | ValueCount: proto.Int32(0), 42 | MajorType: &common.MajorType{ 43 | MinorType: common.MinorType_{{.ProtoType}}.Enum(), 44 | Mode: common.DataMode_OPTIONAL.Enum(), 45 | }, 46 | } 47 | 48 | dv := data.NewValueVec(append(bytemap, b...), meta) 49 | assert.IsType(t, (*data.Nullable{{.Name}}Vector)(nil), dv) 50 | } 51 | {{end}} 52 | {{end}} 53 | 
-------------------------------------------------------------------------------- /driver/rows.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "database/sql/driver" 5 | "io" 6 | "reflect" 7 | 8 | "github.com/factset/go-drill" 9 | ) 10 | 11 | type rows struct { 12 | handle drill.DataHandler 13 | curRow int 14 | } 15 | 16 | func (r *rows) Close() error { 17 | return r.handle.Close() 18 | } 19 | 20 | func (r *rows) Columns() []string { 21 | return r.handle.GetCols() 22 | } 23 | 24 | func (r *rows) Next(dest []driver.Value) error { 25 | rb := r.handle.GetRecordBatch() 26 | if rb == nil { 27 | return io.EOF 28 | } 29 | 30 | if int32(r.curRow) >= rb.NumRows() { 31 | var err error 32 | rb, err = r.handle.Next() 33 | if err != nil { 34 | return err 35 | } 36 | 37 | r.curRow = 0 38 | } 39 | 40 | src := rb.GetVectors() 41 | for i := range dest { 42 | dest[i] = src[i].Value(uint(r.curRow)) 43 | } 44 | 45 | r.curRow++ 46 | return nil 47 | } 48 | 49 | func (r *rows) ColumnTypeScanType(index int) reflect.Type { 50 | return r.handle.GetRecordBatch().GetVectors()[index].Type() 51 | } 52 | 53 | func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) { 54 | return r.handle.GetRecordBatch().IsNullable(index), true 55 | } 56 | 57 | func (r *rows) ColumnTypeDatabaseTypeName(index int) string { 58 | return r.handle.GetRecordBatch().TypeName(index) 59 | } 60 | 61 | func (r *rows) ColumnTypeLength(index int) (int64, bool) { 62 | return r.handle.GetRecordBatch().GetVectors()[index].TypeLen() 63 | } 64 | 65 | func (r *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { 66 | return r.handle.GetRecordBatch().PrecisionScale(index) 67 | } 68 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | 7 | jobs: 8 | release: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Wait For Build Check 12 | uses: lewagon/wait-on-check-action@v0.1 13 | with: 14 | ref: master 15 | check-name: "Build" 16 | repo-token: ${{ secrets.GITHUB_TOKEN }} 17 | wait-interval: 10 # seconds 18 | 19 | - name: Wait For ZK SmokeTest (true) 20 | uses: lewagon/wait-on-check-action@v0.1 21 | with: 22 | ref: master 23 | check-name: "ZK Drill Test (true)" 24 | repo-token: ${{ secrets.GITHUB_TOKEN }} 25 | wait-interval: 10 # seconds 26 | 27 | - name: Wait For ZK SmokeTest (false) 28 | uses: lewagon/wait-on-check-action@v0.1 29 | with: 30 | ref: master 31 | check-name: "ZK Drill Test (false)" 32 | repo-token: ${{ secrets.GITHUB_TOKEN }} 33 | wait-interval: 10 # seconds 34 | 35 | - name: Setup Go 36 | uses: actions/setup-go@v2 37 | with: 38 | go-version: ^1.14 39 | id: go 40 | 41 | - uses: actions/checkout@v2 42 | - run: | 43 | git fetch --prune --unshallow --tags 44 | 45 | - name: Get go-semantic-release 46 | env: 47 | GO111MODULE: "on" 48 | run: go get github.com/Nightapes/go-semantic-release/cmd/go-semantic-release 49 | 50 | - name: Do release 51 | env: 52 | CI: "true" 53 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 54 | run: go-semantic-release release --loglevel trace 55 | -------------------------------------------------------------------------------- /internal/data/type_traits_numeric.gen.go.tmpl: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import ( 4 | "reflect" 5 | "unsafe" 
6 | ) 7 | 8 | var ( 9 | {{ range .In }} 10 | {{.Name}}Traits {{.name}}Traits 11 | {{- end}} 12 | ) 13 | 14 | {{range .In}} 15 | // {{ .Name}} traits 16 | 17 | const ( 18 | // {{.Name}}SizeBytes specifies the number of bytes required to store a single {{.Type}} in memory 19 | {{.Name}}SizeBytes = int(unsafe.Sizeof({{.Type}}({{.Default}}))) 20 | ) 21 | 22 | type {{.name}}Traits struct{} 23 | 24 | // BytesRequired returns the number of bytes required to store n elements in memory 25 | func ({{.name}}Traits) BytesRequired(n int) int { return {{.Name}}SizeBytes * n } 26 | 27 | // CastFromBytes reinterprets the slice b to a slice of type {{.Type}} 28 | // 29 | // NOTE: len(b) must be a multiple of {{.Name}}SizeBytes 30 | func ({{.name}}Traits) CastFromBytes(b []byte) []{{.Type}} { 31 | h := (*reflect.SliceHeader)(unsafe.Pointer(&b)) 32 | 33 | var res []{{.Type}} 34 | s := (*reflect.SliceHeader)(unsafe.Pointer(&res)) 35 | s.Data = h.Data 36 | s.Len = h.Len/{{.Name}}SizeBytes 37 | s.Cap = h.Cap/{{.Name}}SizeBytes 38 | 39 | return res 40 | } 41 | 42 | // CastToBytes reinterprets the slice b to a slice of bytes. 43 | func ({{.name}}Traits) CastToBytes(b []{{.Type}}) []byte { 44 | h := (*reflect.SliceHeader)(unsafe.Pointer(&b)) 45 | 46 | var res []byte 47 | s := (*reflect.SliceHeader)(unsafe.Pointer(&res)) 48 | s.Data = h.Data 49 | s.Len = h.Len*{{.Name}}SizeBytes 50 | s.Cap = h.Cap*{{.Name}}SizeBytes 51 | 52 | return res 53 | } 54 | 55 | // Copy copies src to dst 56 | func ({{.name}}Traits) Copy(dst, src []{{.Type}}) { copy(dst, src) } 57 | {{end}} 58 | -------------------------------------------------------------------------------- /internal/data/numeric.tmpldata: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "Name": "Int64", 4 | "name": "int64", 5 | "Type": "int64", 6 | "Default": "0", 7 | "Size": "8", 8 | "ProtoType": "BIGINT" 9 | }, 10 | { 11 | "Name": "Int32", 12 | "name": "int32", 13 | "Type": "int32", 14 | "Default": "0", 15 | "Size": "4", 16 | "ProtoType": "INT" 17 | }, 18 | { 19 | "Name": "Float64", 20 | "name": "float64", 21 | "Type": "float64", 22 | "Default": "0", 23 | "Size": "8", 24 | "ProtoType": "FLOAT8" 25 | }, 26 | { 27 | "Name": "Uint64", 28 | "name": "uint64", 29 | "Type": "uint64", 30 | "Default": "0", 31 | "Size": "8", 32 | "ProtoType": "UINT8" 33 | }, 34 | { 35 | "Name": "Uint32", 36 | "name": "uint32", 37 | "Type": "uint32", 38 | "Default": "0", 39 | "Size": "4", 40 | "ProtoType": "UINT4" 41 | }, 42 | { 43 | "Name": "Float32", 44 | "name": "float32", 45 | "Type": "float32", 46 | "Default": "0", 47 | "Size": "4", 48 | "ProtoType": "FLOAT4" 49 | }, 50 | { 51 | "Name": "Int16", 52 | "name": "int16", 53 | "Type": "int16", 54 | "Default": "0", 55 | "Size": "2", 56 | "ProtoType": "SMALLINT" 57 | }, 58 | { 59 | "Name": "Uint16", 60 | "name": "uint16", 61 | "Type": "uint16", 62 | "Default": "0", 63 | "Size": "2", 64 | "ProtoType": "UINT2" 65 | }, 66 | { 67 | "Name": "Int8", 68 | "name": "int8", 69 | "Type": "int8", 70 | "Default": "0", 71 | "Size": "1", 72 | "ProtoType": "TINYINT" 73 | }, 74 | { 75 | "Name": "Uint8", 76 | "name": "uint8", 77 | "Type": "uint8", 78 | "Default": "0", 79 | "Size": "1", 80 | "ProtoType": "UINT1" 81 | } 82 | ] 83 | -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package drill 2 | 3 | import "time" 4 | 5 | const drillRPCVersion int32 = 5 6 | const clientName = "Apache Drill 
Golang Client" 7 | const drillVersion = "1.17.0" 8 | const drillMajorVersion = 1 9 | const drillMinorVersion = 17 10 | const drillPatchVersion = 0 11 | const defaultHeartbeatFreq = 15 * time.Second 12 | 13 | // Options for a Drill Connection 14 | type Options struct { 15 | // the default Schema to use 16 | Schema string 17 | // true if expected to use encryption for communication 18 | SaslEncrypt bool 19 | // the HOST portion to use for the spn to authenticate with, if _HOST or 20 | // empty, will use the address of the drillbit that is connected to 21 | ServiceHost string 22 | // the krb service name to use for authentication 23 | ServiceName string 24 | // what authentication mechanism to use, currently only supports kerberos 25 | // or no auth 26 | Auth string 27 | // the Drill clusters name which is used by ZooKeeper to store the endpoint 28 | // information 29 | ClusterName string 30 | // use this instead of ClusterName to fully specify the Zookeeper path instead 31 | // of using the /drill prefix 32 | ZKPath string 33 | // whether or not the server should support complex types such as List 34 | SupportComplexTypes bool 35 | // what Application Name to use for connecting to the server 36 | ApplicationName string 37 | // the username to authenticate as 38 | User string 39 | // Password to use for PLAIN auth 40 | Passwd string 41 | // the heartbeatfrequency to use, if nil then will use the default (15 seconds) 42 | // set to 0 to disable it. 43 | HeartbeatFreq *time.Duration 44 | // UseArrow controls whether the raw data in the results is underlined by arrow 45 | // arrays (if true) or not (if false, default) 46 | UseArrow bool 47 | } 48 | -------------------------------------------------------------------------------- /internal/data/arrow_numeric.gen_test.go.tmpl: -------------------------------------------------------------------------------- 1 | package data_test 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/apache/arrow/go/arrow" 8 | "github.com/apache/arrow/go/arrow/array" 9 | "github.com/factset/go-drill/internal/data" 10 | "github.com/factset/go-drill/internal/rpc/proto/common" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | {{- range .In}} 15 | 16 | func Test{{.Name}}Arrow(t *testing.T) { 17 | const N = 10 18 | b := data.{{.Name}}Traits.CastToBytes([]{{.Type}}{ 19 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 20 | }) 21 | 22 | meta := createMetaField(common.DataMode_REQUIRED, common.MinorType_{{.ProtoType}}.Enum(), int32(N), int32(len(b))) 23 | 24 | arr := data.NewArrowArray(b, meta) 25 | assert.NotNil(t, arr) 26 | assert.IsType(t, arrow.PrimitiveTypes.{{.Name}}, arr.DataType()) 27 | assert.Equal(t, reflect.TypeOf({{.Type}}({{.Default}})), data.ArrowTypeToReflect(arr.DataType())) 28 | assert.Equal(t, N, arr.Len()) 29 | assert.Zero(t, arr.NullN()) 30 | 31 | for i := 0; i < N; i++ { 32 | assert.True(t, arr.IsValid(i)) 33 | assert.Equal(t, {{.name}}(i), arr.(*array.{{.Name}}).Value(i)) 34 | } 35 | } 36 | 37 | func TestOptional{{.Name}}Arrow(t *testing.T) { 38 | const N = 10 39 | b := data.{{.Name}}Traits.CastToBytes([]{{.Type}}{ 40 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 41 | }) 42 | 43 | bytemap := []byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1} 44 | meta := createMetaField(common.DataMode_OPTIONAL, common.MinorType_{{.ProtoType}}.Enum(), int32(N), int32(len(b))) 45 | 46 | rawData := append(bytemap, b...) 
47 | 48 | arr := data.NewArrowArray(rawData, meta) 49 | assert.NotNil(t, arr) 50 | assert.IsType(t, arrow.PrimitiveTypes.{{.Name}}, arr.DataType()) 51 | assert.Equal(t, N, arr.Len()) 52 | assert.Equal(t, N/2, arr.NullN()) 53 | 54 | for i := 0; i < N; i++ { 55 | assert.Exactly(t, i%2 == 0, arr.IsNull(i)) 56 | if i%2 == 1 { 57 | assert.Exactly(t, {{.Type}}(i), arr.(*array.{{.Name}}).Value(i)) 58 | } 59 | } 60 | } 61 | {{end}} 62 | -------------------------------------------------------------------------------- /internal/data/vector_numeric.gen.go.tmpl: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import ( 4 | "reflect" 5 | 6 | "github.com/factset/go-drill/internal/rpc/proto/exec/shared" 7 | ) 8 | 9 | {{- range .In}} 10 | // {{ .Name}} vector 11 | 12 | type {{.Name}}Vector struct { 13 | vector 14 | values []{{.name}} 15 | meta *shared.SerializedField 16 | } 17 | 18 | func ({{.Name}}Vector) Type() reflect.Type { 19 | return reflect.TypeOf({{.name}}({{.Default}})) 20 | } 21 | 22 | func ({{.Name}}Vector) TypeLen() (int64, bool) { 23 | return 0, false 24 | } 25 | 26 | func (v *{{.Name}}Vector) Len() int { 27 | return int(len(v.values)) 28 | } 29 | 30 | func (v *{{.Name}}Vector) Get(index uint) {{.name}} { 31 | return v.values[index] 32 | } 33 | 34 | func (v *{{.Name}}Vector) Value(index uint) interface{} { 35 | return v.Get(index) 36 | } 37 | 38 | func New{{.Name}}Vector(data []byte, meta *shared.SerializedField) *{{.Name}}Vector { 39 | return &{{.Name}}Vector{ 40 | vector: vector{rawData: data}, 41 | values: {{.Name}}Traits.CastFromBytes(data), 42 | meta: meta, 43 | } 44 | } 45 | 46 | type Nullable{{.Name}}Vector struct { 47 | *{{.Name}}Vector 48 | 49 | nullByteMap 50 | {{/* byteMap []byte */}} 51 | } 52 | 53 | {{/* func (nv *Nullable{{.Name}}Vector) IsNull(index uint) bool { 54 | return nv.byteMap[index] == 0 55 | } */}} 56 | 57 | func (nv *Nullable{{.Name}}Vector) Get(index uint) *{{.name}} { 58 | if nv.IsNull(index) { return nil } 59 | 60 | return &nv.values[index] 61 | } 62 | 63 | func (nv *Nullable{{.Name}}Vector) Value(index uint) interface{} { 64 | val := nv.Get(index) 65 | if val != nil { 66 | return *val 67 | } 68 | 69 | return val 70 | } 71 | 72 | func NewNullable{{.Name}}Vector(data []byte, meta *shared.SerializedField) *Nullable{{.Name}}Vector { 73 | byteMap := data[:meta.GetValueCount()] 74 | remaining := data[meta.GetValueCount():] 75 | 76 | return &Nullable{{.Name}}Vector{ 77 | New{{.Name}}Vector(remaining, meta), 78 | nullByteMap{byteMap}, 79 | {{/* byteMap, */}} 80 | } 81 | } 82 | 83 | {{end}} 84 | -------------------------------------------------------------------------------- /smoke_test.go: -------------------------------------------------------------------------------- 1 | // +build smoke 2 | 3 | package drill_test 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "log" 9 | 10 | "github.com/factset/go-drill" 11 | ) 12 | 13 | func Example() { 14 | cl := drill.NewClient(drill.Options{Schema: "dfs.sample", ClusterName: "drillbits1"}, "localhost:2181", "localhost:2182", "localhost:2183") 15 | 16 | err := cl.Connect(context.Background()) 17 | if err != nil { 18 | log.Fatal(err) 19 | } 20 | defer cl.Close() 21 | 22 | dh, err := cl.SubmitQuery(drill.TypeSQL, "SELECT * FROM `nation.parquet`") 23 | if err != nil { 24 | log.Fatal(err) 25 | } 26 | 27 | batch, err := dh.Next() 28 | for ; err == nil; batch, err = dh.Next() { 29 | for i := int32(0); i < batch.NumRows(); i++ { 30 | for _, v := range batch.GetVectors() { 31 | val := 
v.Value(uint(i)) 32 | switch t := val.(type) { 33 | case []byte: 34 | fmt.Print("|", string(t)) 35 | default: 36 | fmt.Print("|", t) 37 | } 38 | } 39 | fmt.Println("|") 40 | } 41 | } 42 | 43 | // Output: 44 | // |0|ALGERIA|0| haggle. carefully f| 45 | // |1|ARGENTINA|1|al foxes promise sly| 46 | // |2|BRAZIL|1|y alongside of the p| 47 | // |3|CANADA|1|eas hang ironic, sil| 48 | // |4|EGYPT|4|y above the carefull| 49 | // |5|ETHIOPIA|0|ven packages wake qu| 50 | // |6|FRANCE|3|refully final reques| 51 | // |7|GERMANY|3|l platelets. regular| 52 | // |8|INDIA|2|ss excuses cajole sl| 53 | // |9|INDONESIA|2| slyly express asymp| 54 | // |10|IRAN|4|efully alongside of | 55 | // |11|IRAQ|4|nic deposits boost a| 56 | // |12|JAPAN|2|ously. final, expres| 57 | // |13|JORDAN|4|ic deposits are blit| 58 | // |14|KENYA|0| pending excuses hag| 59 | // |15|MOROCCO|0|rns. blithely bold c| 60 | // |16|MOZAMBIQUE|0|s. ironic, unusual a| 61 | // |17|PERU|1|platelets. blithely | 62 | // |18|CHINA|2|c dependencies. furi| 63 | // |19|ROMANIA|3|ular asymptotes are | 64 | // |20|SAUDI ARABIA|4|ts. silent requests | 65 | // |21|VIETNAM|2|hely enticingly expr| 66 | // |22|RUSSIA|3| requests against th| 67 | // |23|UNITED KINGDOM|3|eans boost carefully| 68 | // |24|UNITED STATES|1|y final packages. sl| 69 | } 70 | -------------------------------------------------------------------------------- /internal/data/vector_numeric.gen_test.go.tmpl: -------------------------------------------------------------------------------- 1 | package data_test 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/factset/go-drill/internal/data" 8 | "github.com/factset/go-drill/internal/rpc/proto/exec/shared" 9 | "github.com/stretchr/testify/assert" 10 | "google.golang.org/protobuf/proto" 11 | ) 12 | 13 | {{- range .In}} 14 | 15 | func Test{{.Name}}Vector(t *testing.T) { 16 | const N = 10 17 | b := data.{{.Name}}Traits.CastToBytes([]{{.Type}}{ 18 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 19 | }) 20 | 21 | vec := data.New{{.Name}}Vector(b, nil) 22 | 23 | assert.Equal(t, reflect.TypeOf({{.name}}({{.Default}})), vec.Type()) 24 | 25 | l, ok := vec.TypeLen() 26 | assert.Zero(t, l) 27 | assert.False(t, ok) 28 | assert.Equal(t, N, vec.Len()) 29 | 30 | for i := 0; i < N; i++ { 31 | assert.Exactly(t, {{.name}}(i), vec.Get(uint(i))) 32 | assert.Exactly(t, {{.name}}(i), vec.Value(uint(i))) 33 | } 34 | 35 | assert.Same(t, &b[0], &vec.GetRawBytes()[0]) 36 | } 37 | 38 | func TestNullable{{.Name}}Vector(t *testing.T) { 39 | const N = 10 40 | b := data.{{.Name}}Traits.CastToBytes([]{{.Type}}{ 41 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 42 | }) 43 | 44 | bytemap := []byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1} 45 | meta := &shared.SerializedField{ValueCount: proto.Int32(10)} 46 | 47 | vec := data.NewNullable{{.Name}}Vector(append(bytemap, b...), meta) 48 | assert.Equal(t, reflect.TypeOf({{.name}}({{.Default}})), vec.Type()) 49 | 50 | assert.EqualValues(t, b, vec.GetRawBytes()) 51 | assert.EqualValues(t, bytemap, vec.GetNullBytemap()) 52 | 53 | l, ok := vec.TypeLen() 54 | assert.Zero(t, l) 55 | assert.False(t, ok) 56 | assert.Equal(t, N, vec.Len()) 57 | 58 | for i := 0; i < N; i++ { 59 | assert.Equal(t, i%2 == 0, vec.IsNull(uint(i))) 60 | if i%2 == 1 { 61 | val := new({{.name}}) 62 | *val = {{.name}}(i) 63 | 64 | assert.Exactly(t, val, vec.Get(uint(i))) 65 | assert.Exactly(t, {{.name}}(i), vec.Value(uint(i))) 66 | } else { 67 | assert.Nil(t, vec.Get(uint(i))) 68 | assert.Nil(t, vec.Value(uint(i))) 69 | } 70 | } 71 | } 72 | {{end}} 73 | 
-------------------------------------------------------------------------------- /driver/prepared.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | "io" 7 | 8 | "github.com/factset/go-drill" 9 | ) 10 | 11 | type prepared struct { 12 | stmt drill.PreparedHandle 13 | 14 | client drill.Conn 15 | } 16 | 17 | func (p *prepared) Close() error { 18 | p.stmt = nil 19 | p.client = nil 20 | return nil 21 | } 22 | 23 | func (p *prepared) NumInput() int { 24 | return 0 25 | } 26 | 27 | func (p *prepared) Exec(args []driver.Value) (driver.Result, error) { 28 | return driver.ResultNoRows, driver.ErrSkip 29 | } 30 | 31 | func (p *prepared) Query(args []driver.Value) (driver.Rows, error) { 32 | return nil, driver.ErrSkip 33 | } 34 | 35 | type result struct { 36 | rowsAffected int64 37 | rowsError error 38 | } 39 | 40 | func (r result) LastInsertId() (int64, error) { 41 | return driver.ResultNoRows.LastInsertId() 42 | } 43 | 44 | func (r result) RowsAffected() (int64, error) { 45 | return r.rowsAffected, r.rowsError 46 | } 47 | 48 | func (p *prepared) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 49 | if len(args) > 0 { 50 | return nil, errNoPrepSupport 51 | } 52 | 53 | handle, err := p.client.ExecuteStmt(p.stmt) 54 | if err != nil { 55 | return nil, driver.ErrBadConn 56 | } 57 | 58 | var affectedRows int64 = 0 59 | err = processWithCtx(ctx, handle, func(h drill.DataHandler) error { 60 | var err error 61 | var batch drill.RowBatch 62 | for batch, err = h.Next(); err == nil; batch, err = h.Next() { 63 | affectedRows += int64(batch.AffectedRows()) 64 | } 65 | 66 | return err 67 | }) 68 | 69 | if err == io.EOF { 70 | err = nil 71 | } 72 | 73 | return result{rowsAffected: affectedRows, rowsError: err}, nil 74 | } 75 | 76 | func (p *prepared) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 77 | if len(args) > 0 { 78 | return nil, errNoPrepSupport 79 | } 80 | 81 | handle, err := p.client.ExecuteStmt(p.stmt) 82 | if err != nil { 83 | return nil, driver.ErrBadConn 84 | } 85 | 86 | r := &rows{handle: handle} 87 | return r, processWithCtx(ctx, handle, func(h drill.DataHandler) error { 88 | _, err := h.Next() 89 | return err 90 | }) 91 | } 92 | -------------------------------------------------------------------------------- /kerberos_smoke_test.go: -------------------------------------------------------------------------------- 1 | // +build smoke,kerberos 2 | 3 | package drill_test 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "log" 9 | 10 | "github.com/factset/go-drill" 11 | ) 12 | 13 | func Example_kerberos() { 14 | cl := drill.NewClient(drill.Options{Schema: "dfs.sample", Auth: "kerberos", ClusterName: "drillbits1", SaslEncrypt: true, ServiceName: "drill"}, "localhost:2181", "localhost:2182", "localhost:2183") 15 | 16 | err := cl.Connect(context.Background()) 17 | if err != nil { 18 | log.Fatal(err) 19 | } 20 | defer cl.Close() 21 | 22 | dh, err := cl.SubmitQuery(drill.TypeSQL, "SELECT * FROM `nation.parquet`") 23 | if err != nil { 24 | log.Fatal(err) 25 | } 26 | 27 | batch, err := dh.Next() 28 | for ; err == nil; batch, err = dh.Next() { 29 | for i := int32(0); i < batch.NumRows(); i++ { 30 | for _, v := range batch.GetVectors() { 31 | val := v.Value(uint(i)) 32 | switch t := val.(type) { 33 | case []byte: 34 | fmt.Print("|", string(t)) 35 | default: 36 | fmt.Print("|", t) 37 | } 38 | } 39 | fmt.Println("|") 40 | } 41 | } 42 | 
43 | // Output: 44 | // |0|ALGERIA|0| haggle. carefully f| 45 | // |1|ARGENTINA|1|al foxes promise sly| 46 | // |2|BRAZIL|1|y alongside of the p| 47 | // |3|CANADA|1|eas hang ironic, sil| 48 | // |4|EGYPT|4|y above the carefull| 49 | // |5|ETHIOPIA|0|ven packages wake qu| 50 | // |6|FRANCE|3|refully final reques| 51 | // |7|GERMANY|3|l platelets. regular| 52 | // |8|INDIA|2|ss excuses cajole sl| 53 | // |9|INDONESIA|2| slyly express asymp| 54 | // |10|IRAN|4|efully alongside of | 55 | // |11|IRAQ|4|nic deposits boost a| 56 | // |12|JAPAN|2|ously. final, expres| 57 | // |13|JORDAN|4|ic deposits are blit| 58 | // |14|KENYA|0| pending excuses hag| 59 | // |15|MOROCCO|0|rns. blithely bold c| 60 | // |16|MOZAMBIQUE|0|s. ironic, unusual a| 61 | // |17|PERU|1|platelets. blithely | 62 | // |18|CHINA|2|c dependencies. furi| 63 | // |19|ROMANIA|3|ular asymptotes are | 64 | // |20|SAUDI ARABIA|4|ts. silent requests | 65 | // |21|VIETNAM|2|hely enticingly expr| 66 | // |22|RUSSIA|3| requests against th| 67 | // |23|UNITED KINGDOM|3|eans boost carefully| 68 | // |24|UNITED STATES|1|y final packages. sl| 69 | } 70 | -------------------------------------------------------------------------------- /internal/data/numeric_vec_typemap.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by numeric_vec_typemap.gen.go.tmpl. DO NOT EDIT. 2 | 3 | package data 4 | 5 | import ( 6 | "github.com/factset/go-drill/internal/rpc/proto/common" 7 | "github.com/factset/go-drill/internal/rpc/proto/exec/shared" 8 | ) 9 | 10 | func NewNumericValueVec(rawData []byte, meta *shared.SerializedField) DataVector { 11 | if meta.GetMajorType().GetMode() == common.DataMode_OPTIONAL { 12 | switch meta.GetMajorType().GetMinorType() { 13 | 14 | case common.MinorType_BIGINT: 15 | return NewNullableInt64Vector(rawData, meta) 16 | 17 | case common.MinorType_INT: 18 | return NewNullableInt32Vector(rawData, meta) 19 | 20 | case common.MinorType_FLOAT8: 21 | return NewNullableFloat64Vector(rawData, meta) 22 | 23 | case common.MinorType_UINT8: 24 | return NewNullableUint64Vector(rawData, meta) 25 | 26 | case common.MinorType_UINT4: 27 | return NewNullableUint32Vector(rawData, meta) 28 | 29 | case common.MinorType_FLOAT4: 30 | return NewNullableFloat32Vector(rawData, meta) 31 | 32 | case common.MinorType_SMALLINT: 33 | return NewNullableInt16Vector(rawData, meta) 34 | 35 | case common.MinorType_UINT2: 36 | return NewNullableUint16Vector(rawData, meta) 37 | 38 | case common.MinorType_TINYINT: 39 | return NewNullableInt8Vector(rawData, meta) 40 | 41 | case common.MinorType_UINT1: 42 | return NewNullableUint8Vector(rawData, meta) 43 | 44 | default: 45 | return nil 46 | } 47 | } 48 | 49 | switch meta.GetMajorType().GetMinorType() { 50 | 51 | case common.MinorType_BIGINT: 52 | return NewInt64Vector(rawData, meta) 53 | 54 | case common.MinorType_INT: 55 | return NewInt32Vector(rawData, meta) 56 | 57 | case common.MinorType_FLOAT8: 58 | return NewFloat64Vector(rawData, meta) 59 | 60 | case common.MinorType_UINT8: 61 | return NewUint64Vector(rawData, meta) 62 | 63 | case common.MinorType_UINT4: 64 | return NewUint32Vector(rawData, meta) 65 | 66 | case common.MinorType_FLOAT4: 67 | return NewFloat32Vector(rawData, meta) 68 | 69 | case common.MinorType_SMALLINT: 70 | return NewInt16Vector(rawData, meta) 71 | 72 | case common.MinorType_UINT2: 73 | return NewUint16Vector(rawData, meta) 74 | 75 | case common.MinorType_TINYINT: 76 | return NewInt8Vector(rawData, meta) 77 | 78 | case 
common.MinorType_UINT1: 79 | return NewUint8Vector(rawData, meta) 80 | 81 | default: 82 | return nil 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /driver/connector.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | "fmt" 7 | "strconv" 8 | "strings" 9 | "time" 10 | 11 | "github.com/factset/go-drill" 12 | ) 13 | 14 | type connector struct { 15 | base drill.Conn 16 | } 17 | 18 | func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { 19 | dc, err := c.base.NewConnection(ctx) 20 | if err != nil { 21 | return nil, err 22 | } 23 | return &conn{dc}, nil 24 | } 25 | 26 | func (c *connector) Driver() driver.Driver { 27 | return drillDriver{} 28 | } 29 | 30 | func parseConnectStr(connectStr string) (driver.Connector, error) { 31 | opts := drill.Options{} 32 | 33 | var zknodes []string 34 | var host string 35 | var port int = 31010 // default port is 31010 if connecting directly 36 | args := strings.Split(connectStr, ";") 37 | for _, kv := range args { 38 | parsed := strings.Split(kv, "=") 39 | if len(parsed) != 2 { 40 | return nil, fmt.Errorf("invalid format for connector string") 41 | } 42 | 43 | parsed[1] = strings.TrimSpace(parsed[1]) 44 | 45 | switch strings.TrimSpace(parsed[0]) { 46 | case "zk": 47 | zknodes = strings.Split(parsed[1], ",") 48 | slash := strings.Index(zknodes[len(zknodes)-1], "/") 49 | if slash != -1 { 50 | addr := zknodes[len(zknodes)-1] 51 | zknodes[len(zknodes)-1] = addr[:slash] 52 | opts.ZKPath = addr[slash:] 53 | } 54 | case "auth": 55 | opts.Auth = parsed[1] 56 | case "schema": 57 | opts.Schema = parsed[1] 58 | case "service": 59 | opts.ServiceName = parsed[1] 60 | case "encrypt": 61 | val, err := strconv.ParseBool(parsed[1]) 62 | if err != nil { 63 | return nil, err 64 | } 65 | opts.SaslEncrypt = val 66 | case "user": 67 | opts.User = parsed[1] 68 | case "pass": 69 | opts.Passwd = parsed[1] 70 | case "cluster": 71 | opts.ClusterName = parsed[1] 72 | case "host": 73 | host = parsed[1] 74 | case "port": 75 | var err error 76 | port, err = strconv.Atoi(parsed[1]) 77 | if err != nil { 78 | return nil, fmt.Errorf("drill: invalid port format '%s': %w", parsed[1], err) 79 | } 80 | case "heartbeat": 81 | hbsec, err := strconv.Atoi(parsed[1]) 82 | if err != nil { 83 | return nil, err 84 | } 85 | opts.HeartbeatFreq = new(time.Duration) 86 | *opts.HeartbeatFreq = time.Duration(hbsec) * time.Second 87 | default: 88 | return nil, fmt.Errorf("invalid argument for connection string: %s", parsed[0]) 89 | } 90 | } 91 | 92 | if len(zknodes) == 0 && host != "" { 93 | return &connector{base: drill.NewDirectClient(opts, host, int32(port))}, nil 94 | } 95 | return &connector{base: drill.NewClient(opts, zknodes...)}, nil 96 | } 97 | -------------------------------------------------------------------------------- /internal/data/decimal_utils.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import ( 4 | "encoding/binary" 5 | "math" 6 | "math/big" 7 | ) 8 | 9 | type DecimalTraits interface { 10 | NumDigits() int 11 | ByteWidth() int 12 | IsSparse() bool 13 | MaxPrecision() int 14 | } 15 | 16 | var ( 17 | Decimal28DenseTraits decimal28DenseTraits 18 | Decimal38DenseTraits decimal38DenseTraits 19 | Decimal28SparseTraits decimal28SparseTraits 20 | Decimal38SparseTraits decimal38SparseTraits 21 | ) 22 | 23 | type decimal28DenseTraits struct{} 24 | 25 | func 
(decimal28DenseTraits) NumDigits() int { return 3 } 26 | func (decimal28DenseTraits) ByteWidth() int { return 12 } 27 | func (decimal28DenseTraits) IsSparse() bool { return false } 28 | func (decimal28DenseTraits) MaxPrecision() int { return 28 } 29 | 30 | type decimal38DenseTraits struct{} 31 | 32 | func (decimal38DenseTraits) NumDigits() int { return 4 } 33 | func (decimal38DenseTraits) ByteWidth() int { return 16 } 34 | func (decimal38DenseTraits) IsSparse() bool { return false } 35 | func (decimal38DenseTraits) MaxPrecision() int { return 38 } 36 | 37 | type decimal28SparseTraits struct{} 38 | 39 | func (decimal28SparseTraits) NumDigits() int { return 5 } 40 | func (decimal28SparseTraits) ByteWidth() int { return 20 } 41 | func (decimal28SparseTraits) IsSparse() bool { return true } 42 | func (decimal28SparseTraits) MaxPrecision() int { return 28 } 43 | 44 | type decimal38SparseTraits struct{} 45 | 46 | func (decimal38SparseTraits) NumDigits() int { return 6 } 47 | func (decimal38SparseTraits) ByteWidth() int { return 24 } 48 | func (decimal38SparseTraits) IsSparse() bool { return true } 49 | func (decimal38SparseTraits) MaxPrecision() int { return 38 } 50 | 51 | const ( 52 | maxdigits = 9 53 | digBase = 1000000000 54 | ) 55 | 56 | var base = big.NewFloat(digBase) 57 | 58 | func getFloatFromBytes(valbytes []byte, digits, scale int, truncate bool) *big.Float { 59 | // sparse types (truncate == true) are little endian, otherwise we're big endian 60 | var order binary.ByteOrder 61 | if truncate { 62 | order = binary.LittleEndian 63 | } else { 64 | order = binary.BigEndian 65 | } 66 | 67 | val := big.NewFloat(float64(order.Uint32(valbytes) & 0x7FFFFFFF)) 68 | for i := 1; i < digits; i++ { 69 | tmp := big.NewFloat(float64(order.Uint32(valbytes[i*Uint32SizeBytes:]))) 70 | val.Mul(val, base) 71 | val.Add(val, tmp) 72 | } 73 | 74 | actualDigits := int32(scale % maxdigits) 75 | if truncate && scale > 0 && (actualDigits != 0) { 76 | val.Quo(val, big.NewFloat(math.Pow10(int(maxdigits-actualDigits)))) 77 | } 78 | 79 | if order.Uint32(valbytes)&0x80000000 != 0 { 80 | val.Neg(val) 81 | } 82 | 83 | // scale it and return it 84 | return val.Quo(val, big.NewFloat(math.Pow10(int(scale)))) 85 | } 86 | -------------------------------------------------------------------------------- /zk_handler.go: -------------------------------------------------------------------------------- 1 | package drill 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/factset/go-drill/internal/log" 8 | "github.com/factset/go-drill/internal/rpc/proto/exec" 9 | "github.com/go-zookeeper/zk" 10 | "google.golang.org/protobuf/proto" 11 | ) 12 | 13 | type zkconn interface { 14 | Get(path string) ([]byte, *zk.Stat, error) 15 | Children(path string) ([]string, *zk.Stat, error) 16 | Close() 17 | } 18 | 19 | type zkHandler struct { 20 | conn zkconn 21 | 22 | Nodes []string 23 | Path string 24 | Connecting bool 25 | Err error 26 | } 27 | 28 | // newZKHandler attempts to connect to a zookeeper cluster made up of the provided nodes. 29 | // 30 | // The cluster passed in here would be the Drill cluster name which is used to form the path 31 | // to the drill meta data information. 
32 | func newZKHandler(path string, nodes ...string) (*zkHandler, error) { 33 | hdlr := &zkHandler{Connecting: true, Nodes: zk.FormatServers(nodes), Path: path} 34 | var err error 35 | hdlr.conn, _, err = zk.Connect(hdlr.Nodes, 30*time.Second, zk.WithLogger(&log.Logger), zk.WithEventCallback(func(ev zk.Event) { 36 | switch ev.Type { 37 | case zk.EventSession: 38 | switch ev.State { 39 | case zk.StateAuthFailed: 40 | hdlr.Err = fmt.Errorf("ZK Auth Failed: %w", zk.ErrAuthFailed) 41 | hdlr.conn.Close() 42 | case zk.StateExpired: 43 | hdlr.Err = fmt.Errorf("ZK Session Expired: %w", zk.ErrSessionExpired) 44 | hdlr.conn.Close() 45 | } 46 | } 47 | 48 | hdlr.Connecting = false 49 | if ev.State == zk.StateConnected { 50 | log.Print("Connected to Zookeeper.") 51 | } 52 | })) 53 | 54 | if err != nil { 55 | return nil, err 56 | } 57 | 58 | return hdlr, nil 59 | } 60 | 61 | // GetDrillBits returns the list of drillbit names that can in turn be passed to 62 | // GetEndpoint to get the endpoint information to connect to them. 63 | func (z *zkHandler) GetDrillBits() []string { 64 | children, stat, err := z.conn.Children(z.Path) 65 | if err != nil { 66 | z.Err = err 67 | } 68 | 69 | log.Printf("%+v %+v", children, stat) 70 | return children 71 | } 72 | 73 | // GetEndpoint returns the information necessary to connect to a given drillbit 74 | // from its name. 75 | func (z *zkHandler) GetEndpoint(drillbit string) Drillbit { 76 | data, _, err := z.conn.Get(z.Path + "/" + drillbit) 77 | if err != nil { 78 | z.Err = err 79 | return nil 80 | } 81 | 82 | drillServer := exec.DrillServiceInstance{} 83 | if err = proto.Unmarshal(data, &drillServer); err != nil { 84 | z.Err = err 85 | return nil 86 | } 87 | 88 | log.Printf("%+v", drillServer.String()) 89 | 90 | return drillServer.GetEndpoint() 91 | } 92 | 93 | // Close closes the zookeeper connection and should be called when finished. 
94 | func (z *zkHandler) Close() { 95 | z.conn.Close() 96 | } 97 | -------------------------------------------------------------------------------- /smoketest/storage-plugins-override.conf: -------------------------------------------------------------------------------- 1 | "storage": { 2 | dfs: { 3 | "type" : "file", 4 | "connection" : "file:///", 5 | "config" : null, 6 | "workspaces" : { 7 | "tmp" : { 8 | "location" : "/tmp", 9 | "writable" : true, 10 | "defaultInputFormat" : null, 11 | "allowAccessOutsideWorkspace" : false 12 | }, 13 | "root" : { 14 | "location" : "/", 15 | "writable" : false, 16 | "defaultInputFormat" : null, 17 | "allowAccessOutsideWorkspace" : false 18 | }, 19 | "sample" : { 20 | "location" : "/opt/drill/sample-data", 21 | "writable" : false, 22 | "defaultInputFormat" : null, 23 | "allowAccessOutsideWorkspace" : false 24 | } 25 | }, 26 | "formats" : { 27 | "psv" : { 28 | "type" : "text", 29 | "extensions" : [ "tbl" ], 30 | "delimiter" : "|" 31 | }, 32 | "csv" : { 33 | "type" : "text", 34 | "extensions" : [ "csv" ], 35 | "delimiter" : "," 36 | }, 37 | "tsv" : { 38 | "type" : "text", 39 | "extensions" : [ "tsv" ], 40 | "delimiter" : "\t" 41 | }, 42 | "httpd" : { 43 | "type" : "httpd", 44 | "logFormat" : "%h %l %u %t \"%r\" %>s %b \"%{Referer}i\" \"%{User-agent}i\"" 45 | }, 46 | "parquet" : { 47 | "type" : "parquet" 48 | }, 49 | "json" : { 50 | "type" : "json", 51 | "extensions" : [ "json" ] 52 | }, 53 | "pcap" : { 54 | "type" : "pcap", 55 | "extensions" : [ "pcap" ] 56 | }, 57 | "pcapng" : { 58 | "type" : "pcapng", 59 | "extensions" : [ "pcapng" ] 60 | }, 61 | "avro" : { 62 | "type" : "avro" 63 | }, 64 | "sequencefile" : { 65 | "type" : "sequencefile", 66 | "extensions" : [ "seq" ] 67 | }, 68 | "csvh" : { 69 | "type" : "text", 70 | "extensions" : [ "csvh" ], 71 | "extractHeader" : true, 72 | "delimiter" : "," 73 | }, 74 | "image" : { 75 | "type" : "image", 76 | "extensions" : [ "jpg", "jpeg", "jpe", "tif", "tiff", "dng", "psd", "png", "bmp", "gif", "ico", "pcx", "wav", "wave", "avi", "webp", "mov", "mp4", "m4a", "m4p", "m4b", "m4r", "m4v", "3gp", "3g2", "eps", "epsf", "epsi", "ai", "arw", "crw", "cr2", "nef", "orf", "raf", "rw2", "rwl", "srw", "x3f" ] 77 | }, 78 | "syslog" : { 79 | "type" : "syslog", 80 | "extensions" : [ "syslog" ] 81 | }, 82 | "shp" : { 83 | "type" : "shp" 84 | }, 85 | "excel" : { 86 | "type" : "excel" 87 | }, 88 | "ltsv" : { 89 | "type" : "ltsv", 90 | "extensions" : [ "ltsv" ] 91 | } 92 | }, 93 | "enabled" : true 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /internal/cmd/tmpl/main.go: -------------------------------------------------------------------------------- 1 | // Generator for templated type files 2 | package main 3 | 4 | import ( 5 | "bytes" 6 | "encoding/json" 7 | "flag" 8 | "fmt" 9 | "go/format" 10 | "io/ioutil" 11 | "log" 12 | "os" 13 | "path/filepath" 14 | "strings" 15 | "text/template" 16 | ) 17 | 18 | type data struct { 19 | In interface{} 20 | D listValue 21 | } 22 | 23 | type listValue map[string]string 24 | 25 | func (l listValue) String() string { 26 | res := make([]string, 0, len(l)) 27 | for k, v := range l { 28 | res = append(res, fmt.Sprintf("%s=%s", k, v)) 29 | } 30 | return strings.Join(res, ", ") 31 | } 32 | 33 | func (l listValue) Set(v string) error { 34 | nv := strings.Split(v, "=") 35 | if len(nv) != 2 { 36 | return fmt.Errorf("expected NAME=VALUE, got %s", v) 37 | } 38 | l[nv[0]] = nv[1] 39 | return nil 40 | } 41 | 42 | func parsePath(path string) (string, string) 
{ 43 | p := strings.IndexByte(path, '=') 44 | if p == -1 { 45 | if filepath.Ext(path) != ".tmpl" { 46 | log.Fatalf("template file '%s' must have .tmpl extension", path) 47 | } 48 | return path, path[:len(path)-len(".tmpl")] 49 | } 50 | 51 | return path[:p], path[p+1:] 52 | } 53 | 54 | func main() { 55 | var ( 56 | dataArg = flag.String("data", "", "input JSON data") 57 | in = &data{D: make(listValue)} 58 | ) 59 | 60 | flag.Var(&in.D, "d", "-d NAME=VALUE") 61 | flag.Parse() 62 | if *dataArg == "" { 63 | log.Fatal("data option is required") 64 | } 65 | 66 | paths := flag.Args() 67 | if len(paths) == 0 { 68 | log.Fatal("no tmpl files specified") 69 | } 70 | 71 | in.In = readData(*dataArg) 72 | process(in, paths) 73 | } 74 | 75 | func readData(path string) interface{} { 76 | data, err := ioutil.ReadFile(path) 77 | if err != nil { 78 | log.Fatal("Read Data: ", err) 79 | } 80 | 81 | var v interface{} 82 | if err := json.Unmarshal(data, &v); err != nil { 83 | log.Fatal("Unmarshal: ", err) 84 | } 85 | return v 86 | } 87 | 88 | func process(data interface{}, paths []string) { 89 | for _, p := range paths { 90 | var ( 91 | t *template.Template 92 | err error 93 | ) 94 | 95 | in, out := parsePath(p) 96 | 97 | contents, _ := ioutil.ReadFile(in) 98 | t, err = template.New("gen").Parse(string(contents)) 99 | if err != nil { 100 | log.Fatal("Template Parse: ", err) 101 | } 102 | 103 | var buf bytes.Buffer 104 | fmt.Fprintf(&buf, "// Code generated by %s. DO NOT EDIT.\n\n", p) 105 | 106 | err = t.Execute(&buf, data) 107 | if err != nil { 108 | log.Fatal("Tmpl Execute: ", err) 109 | } 110 | 111 | generated := buf.Bytes() 112 | generated, err = format.Source(generated) 113 | if err != nil { 114 | log.Fatal("Format: ", err) 115 | } 116 | 117 | ioutil.WriteFile(out, generated, os.ModePerm) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /zk_handler_test.go: -------------------------------------------------------------------------------- 1 | package drill 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/factset/go-drill/internal/rpc/proto/exec" 7 | "github.com/go-zookeeper/zk" 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/mock" 10 | "google.golang.org/protobuf/proto" 11 | ) 12 | 13 | type mockzk struct { 14 | mock.Mock 15 | } 16 | 17 | func (m *mockzk) Get(path string) ([]byte, *zk.Stat, error) { 18 | args := m.Called(path) 19 | return args.Get(0).([]byte), args.Get(1).(*zk.Stat), args.Error(2) 20 | } 21 | 22 | func (m *mockzk) Children(path string) ([]string, *zk.Stat, error) { 23 | args := m.Called(path) 24 | return args.Get(0).([]string), args.Get(1).(*zk.Stat), args.Error(2) 25 | } 26 | 27 | func (m *mockzk) Close() { 28 | m.Called() 29 | } 30 | 31 | func TestZKHandlerClose(t *testing.T) { 32 | m := new(mockzk) 33 | m.Test(t) 34 | m.On("Close") 35 | 36 | hdlr := zkHandler{conn: m} 37 | hdlr.Close() 38 | m.AssertExpectations(t) 39 | } 40 | 41 | func TestZKHandlerGetDrillBits(t *testing.T) { 42 | m := new(mockzk) 43 | m.Test(t) 44 | 45 | m.On("Children", "/drill/cluster").Return([]string{"a", "b", "c"}, (*zk.Stat)(nil), nil) 46 | 47 | hdlr := zkHandler{conn: m, Path: "/drill/cluster"} 48 | val := hdlr.GetDrillBits() 49 | assert.Equal(t, []string{"a", "b", "c"}, val) 50 | m.AssertExpectations(t) 51 | } 52 | 53 | func TestZKHandlerGetDrillBitsError(t *testing.T) { 54 | m := new(mockzk) 55 | m.Test(t) 56 | 57 | m.On("Children", "/drill/cluster").Return([]string{"c"}, (*zk.Stat)(nil), assert.AnError) 58 | 59 | hdlr := 
zkHandler{conn: m, Path: "/drill/cluster"} 60 | val := hdlr.GetDrillBits() 61 | assert.Equal(t, []string{"c"}, val) 62 | assert.Same(t, assert.AnError, hdlr.Err) 63 | 64 | m.AssertExpectations(t) 65 | } 66 | 67 | func TestZKHandlerGetEndpoint(t *testing.T) { 68 | m := new(mockzk) 69 | m.Test(t) 70 | 71 | service := &exec.DrillServiceInstance{ 72 | Id: proto.String("bit"), 73 | Endpoint: &exec.DrillbitEndpoint{ 74 | Address: proto.String("foobar"), 75 | UserPort: proto.Int32(2020), 76 | }, 77 | } 78 | 79 | data, _ := proto.Marshal(service) 80 | m.On("Get", "/drill/cluster/bit").Return(data, (*zk.Stat)(nil), nil) 81 | 82 | hdlr := zkHandler{conn: m, Path: "/drill/cluster"} 83 | bit := hdlr.GetEndpoint("bit") 84 | 85 | assert.Equal(t, "foobar", bit.GetAddress()) 86 | assert.Equal(t, int32(2020), bit.GetUserPort()) 87 | m.AssertExpectations(t) 88 | } 89 | 90 | func TestZKHandlerGetEndpointErr(t *testing.T) { 91 | m := new(mockzk) 92 | m.Test(t) 93 | 94 | m.On("Get", "/drill/cluster/bit").Return([]byte{}, (*zk.Stat)(nil), assert.AnError) 95 | hdlr := zkHandler{conn: m, Path: "/drill/cluster"} 96 | bit := hdlr.GetEndpoint("bit") 97 | assert.Nil(t, bit) 98 | assert.Same(t, assert.AnError, hdlr.Err) 99 | } 100 | 101 | func TestZKHandlerGetEndpointProtoErr(t *testing.T) { 102 | m := new(mockzk) 103 | m.Test(t) 104 | 105 | m.On("Get", "/drill/cluster/bit").Return([]byte{0x00}, (*zk.Stat)(nil), nil) 106 | hdlr := zkHandler{conn: m, Path: "/drill/cluster"} 107 | bit := hdlr.GetEndpoint("bit") 108 | assert.Nil(t, bit) 109 | assert.Error(t, hdlr.Err) 110 | } 111 | 112 | func TestNewZKHandlerFailConnect(t *testing.T) { 113 | hdlr, err := newZKHandler("") 114 | assert.Nil(t, hdlr) 115 | assert.Error(t, err) 116 | } 117 | -------------------------------------------------------------------------------- /sasl/gssapi_test.go: -------------------------------------------------------------------------------- 1 | package sasl 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jcmturner/gokrb5/v8/crypto" 7 | "github.com/jcmturner/gokrb5/v8/gssapi" 8 | "github.com/jcmturner/gokrb5/v8/iana/keyusage" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestTokenContext(t *testing.T) { 13 | tk := gssapiKrb5Token{ctx: &authContext{}} 14 | 15 | assert.Same(t, tk.ctx, tk.Context().Value(ctxAuthCtx)) 16 | } 17 | 18 | func TestGssapiMechOID(t *testing.T) { 19 | mech := NewGSSAPIKrb5Mech(nil, "foobar", SecurityProps{}) 20 | assert.Equal(t, gssapi.OIDKRB5.OID(), mech.OID()) 21 | } 22 | 23 | func TestGssapiMIC(t *testing.T) { 24 | // just returns a blank MICToken for now: 25 | g := gssapiKrb5Mech{} 26 | assert.Equal(t, gssapi.MICToken{}, g.MIC()) 27 | } 28 | 29 | func TestGssapiAcceptSecContext(t *testing.T) { 30 | sampleSt := gssapi.Status{Code: gssapi.StatusContinueNeeded} 31 | tk := &mockToken{verify: true, st: sampleSt} 32 | 33 | g := gssapiKrb5Mech{} 34 | valid, _, st := g.AcceptSecContext(tk) 35 | assert.True(t, valid) 36 | assert.Equal(t, sampleSt, st) 37 | } 38 | 39 | func TestGssapiMechCtxFlags(t *testing.T) { 40 | tests := []struct { 41 | name string 42 | props SecurityProps 43 | extraFlags []int 44 | }{ 45 | {"encryption flags", SecurityProps{UseEncryption: true}, []int{gssapi.ContextFlagConf, gssapi.ContextFlagInteg}}, 46 | {"integrity flag", SecurityProps{MaxSsf: 1}, []int{gssapi.ContextFlagInteg}}, 47 | {"conf flag", SecurityProps{MaxSsf: 256}, []int{gssapi.ContextFlagInteg, gssapi.ContextFlagConf}}, 48 | } 49 | 50 | for _, tt := range tests { 51 | t.Run(tt.name, func(t *testing.T) { 52 | g := 
gssapiKrb5Mech{saslProps: tt.props} 53 | assert.ElementsMatch(t, 54 | append([]int{gssapi.ContextFlagMutual /*, gssapi.ContextFlagSequence*/}, 55 | tt.extraFlags...), g.getCtxFlags()) 56 | }) 57 | } 58 | } 59 | 60 | func TestGssapiMechUnwrapNoEncrypt(t *testing.T) { 61 | g := gssapiKrb5Mech{} 62 | assert.Equal(t, deadbeef, g.Unwrap(gssapi.WrapToken{Payload: deadbeef})) 63 | } 64 | 65 | func TestGssapiUnwrapEncrypted(t *testing.T) { 66 | tokenHdrBytes := []byte{5, 4, 2, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} 67 | 68 | etyp, _ := crypto.GetEtype(testkey.KeyType) 69 | _, data, _ := etyp.EncryptMessage(testkey.KeyValue, append(deadbeef, tokenHdrBytes...), keyusage.GSSAPI_ACCEPTOR_SEAL) 70 | 71 | g := gssapiKrb5Mech{ 72 | ctx: authContext{key: testkey}, 73 | } 74 | 75 | assert.Equal(t, deadbeef, g.Unwrap(gssapi.WrapToken{ 76 | Flags: 0x02, 77 | EC: 0, 78 | RRC: 0, 79 | SndSeqNum: 1, 80 | CheckSum: make([]byte, 0), 81 | Payload: data, 82 | })) 83 | } 84 | 85 | func TestGssapiUnwrapPanic(t *testing.T) { 86 | g := gssapiKrb5Mech{ 87 | ctx: authContext{key: testkey}, 88 | } 89 | 90 | assert.Panics(t, func() { 91 | g.Unwrap(gssapi.WrapToken{ 92 | Flags: 0x02, 93 | Payload: deadbeef, 94 | }) 95 | }) 96 | } 97 | 98 | func TestGssapiWrapNoConf(t *testing.T) { 99 | g := gssapiKrb5Mech{ctx: authContext{key: testkey}} 100 | 101 | tok := g.Wrap(deadbeef) 102 | assert.Equal(t, deadbeef, tok.Payload) 103 | } 104 | 105 | func TestGssapiWrapNoConfPanic(t *testing.T) { 106 | g := gssapiKrb5Mech{} 107 | // panics because we have no key 108 | assert.Panics(t, func() { g.Wrap(deadbeef) }) 109 | } 110 | 111 | func TestGssapiWrapEncrypt(t *testing.T) { 112 | g := gssapiKrb5Mech{ctx: authContext{key: testkey, qop: QopConf, localSeqNum: 1}} 113 | 114 | tok := g.Wrap(deadbeef) 115 | 116 | data, err := crypto.DecryptMessage(tok.Payload, testkey, keyusage.GSSAPI_INITIATOR_SEAL) 117 | assert.NoError(t, err) 118 | assert.Equal(t, append(deadbeef, []byte{5, 4, 2, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}...), data) 119 | assert.EqualValues(t, 2, g.ctx.localSeqNum) 120 | } 121 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-drill 2 | 3 | [![PkgGoDev](https://pkg.go.dev/badge/github.com/factset/go-drill)](https://pkg.go.dev/github.com/factset/go-drill) 4 | [![codecov](https://codecov.io/gh/factset/go-drill/branch/master/graph/badge.svg)](https://codecov.io/gh/factset/go-drill) 5 | [![Go Report Card](https://goreportcard.com/badge/github.com/factset/go-drill)](https://goreportcard.com/report/github.com/factset/go-drill) 6 | [![CI Test](https://github.com/factset/go-drill/workflows/Go/badge.svg)](https://github.com/factset/go-drill/actions) 7 | [![Smoke Test](https://github.com/factset/go-drill/workflows/SmokeTest/badge.svg)](https://github.com/factset/go-drill/actions) 8 | [![License](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](http://www.apache.org/licenses/LICENSE-2.0) 9 | 10 | **go-drill** is a highly efficient Pure Go Client and Sql driver for [Apache Drill](https://drill.apache.org) and [Dremio](https://www.dremio.com). 11 | It differs from other clients / drivers by using the native Protobuf API to communicate instead of the REST API. 
The use of Protobuf 12 | enables **zero-copy access to the returned data, resulting in greater efficiency.** 13 | 14 | 15 | At the present time, the driver may be used without authentication or with 16 | authentication via SASL gssapi-krb-5. 17 | 18 | In typical use, the driver is initialized with a list of zookeeper hosts 19 | to enable the driver to locate drillbits. It is also possible to connect 20 | directly to a drillbit via the client. 21 | 22 | ## Install 23 | 24 | #### Client 25 | 26 | ```bash 27 | go get -u github.com/factset/go-drill 28 | ``` 29 | 30 | #### Driver 31 | 32 | ```bash 33 | go get -u github.com/factset/go-drill/driver 34 | ``` 35 | 36 | ## Usage 37 | 38 | The driver can be used like a typical Golang SQL driver: 39 | 40 | ```go 41 | import ( 42 | "strings" 43 | "database/sql" 44 | 45 | _ "github.com/factset/go-drill/driver" 46 | ) 47 | 48 | func main() { 49 | props := []string{ 50 | "zk=zookeeper1,zookeeper2,zookeeper3", 51 | "auth=kerberos", 52 | "service=", 53 | "cluster=", 54 | } 55 | 56 | db, err := sql.Open("drill", strings.Join(props, ";")) 57 | } 58 | ``` 59 | 60 | Alternately, you can just use the client directly: 61 | 62 | ```go 63 | import ( 64 | "context" 65 | 66 | "github.com/factset/go-drill" 67 | ) 68 | 69 | func main() { 70 | // create client, doesn't connect yet 71 | cl := drill.NewClient(drill.Options{/* fill out options */}, "zookeeper1", "zookeeper2", "zookeeper3") 72 | 73 | // connect the client 74 | err := cl.Connect(context.Background()) 75 | // if there was any issue connecting, err will contain the error, otherwise will 76 | // be nil if successfully connected 77 | } 78 | ``` 79 | 80 | ## Developing 81 | 82 | ### Refreshing the Protobuf Definitions 83 | 84 | A command is provided to easily refresh the protobuf definitions, provided you have 85 | `protoc` already on your `PATH`. The source should be in a directory structure like 86 | `.../github.com/factset/go-drill/` for development, allowing usage of `go generate` 87 | which will run the command. 88 | 89 | Alternatively, the provided command `drillProto` can be used manually via 90 | `go run ./internal/cmd/drillProto` from the root of the source directory. 91 | 92 | ```bash 93 | $ go run ./internal/cmd/drillProto -h 94 | Drill Proto. 95 | 96 | Usage: 97 | drillProto -h | --help 98 | drillProto download [-o PATH] 99 | drillProto fixup [-o PATH] 100 | drillProto gen [-o PATH] ROOTPATH 101 | drillProto runall [-o PATH] ROOTPATH 102 | 103 | Arguments: 104 | ROOTPATH location of the root output for the generated .go files 105 | 106 | Options: 107 | -h --help Show this screen. 108 | -o PATH --out PATH .proto destination path [default: protobuf] 109 | ``` 110 | 111 | `drillProto download` will simply download the .proto files to the specified path 112 | from the apache drill github repo. 113 | 114 | `drillProto fixup` adds the `option go_package = "github.com/factset/go-drill/internal/rpc/proto/..."` to each file. 115 | 116 | `drillProto gen` will generate the `.pb.go` files from the protobuf files, using the 117 | provided `ROOTPATH` as the root output where it will write the files in the structure 118 | of `/github.com/factset/go-drill/internal/rpc/proto/...`. 119 | 120 | `drillProto runall` does all of the steps in order as one command. 121 | 122 | ### Regenerate the data vector handling 123 | 124 | Running `go generate ./internal/data` will regenerate the `.gen.go` files from their 125 | templates. 
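## Example: running a query

Once a `*sql.DB` has been opened as shown in the Usage section, queries go through the standard `database/sql` API. The snippet below is a minimal sketch: it assumes the `db` handle from the Usage example and that `fmt` is imported; the `employee.json` sample table from Drill's classpath (`cp`) storage plugin, the column names, and the Go types scanned into are illustrative, and error handling is elided.

```go
rows, err := db.Query("SELECT employee_id, full_name FROM cp.`employee.json` LIMIT 5")
if err != nil {
	// handle the error
}
defer rows.Close()

for rows.Next() {
	var (
		id   int64
		name string
	)
	if err := rows.Scan(&id, &name); err != nil {
		// handle the error
	}
	fmt.Println(id, name)
}
```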
126 | -------------------------------------------------------------------------------- /internal/cmd/drillProto/main.go: -------------------------------------------------------------------------------- 1 | // Mainprog to update the protobuf definitions 2 | package main 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | "io" 8 | "io/ioutil" 9 | "log" 10 | "net/http" 11 | "os" 12 | "os/exec" 13 | "path/filepath" 14 | "regexp" 15 | "strings" 16 | 17 | "github.com/docopt/docopt-go" 18 | "github.com/google/go-github/v32/github" 19 | ) 20 | 21 | const usage = `Drill Proto. 22 | 23 | Usage: 24 | drillProto -h | --help 25 | drillProto download [-o PATH] 26 | drillProto fixup [-o PATH] 27 | drillProto gen [-o PATH] ROOTPATH 28 | drillProto runall [-o PATH] ROOTPATH 29 | 30 | Arguments: 31 | ROOTPATH location of the root output for the generated .go files 32 | 33 | Options: 34 | -h --help Show this screen. 35 | -o PATH --out PATH .proto destination path [default: protobuf]` 36 | 37 | func download(outdir string) { 38 | fmt.Println("Download .proto files from Apache Drill Git Repo") 39 | client := github.NewClient(nil) 40 | 41 | _, dircont, _, err := client.Repositories.GetContents(context.Background(), "apache", "drill", "protocol/src/main/protobuf", &github.RepositoryContentGetOptions{Ref: "master"}) 42 | if err != nil { 43 | log.Fatal(err) 44 | } 45 | 46 | info, err := os.Stat(outdir) 47 | if os.IsNotExist(err) { 48 | if err := os.Mkdir(outdir, os.ModePerm); err != nil { 49 | log.Fatal(err) 50 | } 51 | } else { 52 | if !info.IsDir() { 53 | log.Fatal("Path is a file, can't be used") 54 | } 55 | } 56 | 57 | for _, c := range dircont { 58 | path := filepath.Join(outdir, c.GetName()) 59 | log.Printf("Downloading: %s to %s\n", c.GetName(), path) 60 | 61 | resp, err := http.Get(c.GetDownloadURL()) 62 | if err != nil { 63 | log.Fatal(err) 64 | } 65 | 66 | defer resp.Body.Close() 67 | out, err := os.Create(path) 68 | if err != nil { 69 | log.Fatal(err) 70 | } 71 | defer out.Close() 72 | 73 | _, err = io.Copy(out, resp.Body) 74 | if err != nil { 75 | log.Fatal(err) 76 | } 77 | } 78 | } 79 | 80 | func fixup(dirpath string) { 81 | fmt.Println("Update .proto files with go_package option") 82 | files, err := ioutil.ReadDir(dirpath) 83 | if err != nil { 84 | log.Fatal(err) 85 | } 86 | 87 | reGopkg := regexp.MustCompile(`option go_package = (".*");`) 88 | rePkg := regexp.MustCompile(`package (?P.*);`) 89 | 90 | root := []string{"github.com/factset/go-drill/internal/rpc/proto"} 91 | 92 | for _, f := range files { 93 | contents, err := ioutil.ReadFile(filepath.Join(dirpath, f.Name())) 94 | if err != nil { 95 | log.Fatal(err) 96 | } 97 | 98 | if reGopkg.Match(contents) { 99 | continue 100 | } 101 | 102 | submatches := rePkg.FindSubmatchIndex(contents) 103 | pieces := strings.Split(string(contents[submatches[2]:submatches[3]]), ".") 104 | pkgname := filepath.Join(append(root, pieces...)...) 
105 | 106 | insert := fmt.Sprintf("option go_package = \"%s\";\n", pkgname) 107 | 108 | data := make([]byte, len(contents)+len(insert)) 109 | copy(data, contents[:submatches[1]+2]) 110 | copy(data[submatches[1]+2:], []byte(insert)) 111 | copy(data[submatches[1]+2+len(insert):], contents[submatches[1]+2:]) 112 | 113 | if err = ioutil.WriteFile(filepath.Join(dirpath, f.Name()), data, os.ModePerm); err != nil { 114 | log.Fatal(err) 115 | } 116 | } 117 | } 118 | 119 | func gen(dirpath, outpath string) { 120 | fmt.Println("Generate the .go files from the proto definitions") 121 | dirpath, _ = filepath.Abs(dirpath) 122 | outpath, _ = filepath.Abs(outpath) 123 | 124 | protos, err := filepath.Glob(dirpath + "/*.proto") 125 | if err != nil { 126 | log.Fatal(err) 127 | } 128 | 129 | args := []string{ 130 | "--proto_path=" + dirpath, 131 | "--go_out=" + outpath, 132 | } 133 | args = append(args, protos...) 134 | 135 | cmd := exec.Command("protoc", args...) 136 | out, err := cmd.CombinedOutput() 137 | if err != nil { 138 | fmt.Println(string(out)) 139 | log.Fatal(err) 140 | } 141 | } 142 | 143 | func main() { 144 | opts, err := docopt.ParseDoc(usage) 145 | if err != nil { 146 | panic(err) 147 | } 148 | var config struct { 149 | Help bool `docopt:"-h"` 150 | Dir string `docopt:"--out"` 151 | Download bool 152 | Fixup bool 153 | Gen bool 154 | RunAll bool `docopt:"runall"` 155 | RootPath string `docopt:"ROOTPATH"` 156 | } 157 | 158 | opts.Bind(&config) 159 | 160 | if config.Download || config.RunAll { 161 | download(config.Dir) 162 | } 163 | 164 | if config.Fixup || config.RunAll { 165 | fixup(config.Dir) 166 | } 167 | 168 | if config.Gen || config.RunAll { 169 | gen(config.Dir, config.RootPath) 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /meta_requests.go: -------------------------------------------------------------------------------- 1 | package drill 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/factset/go-drill/internal/rpc/proto/exec/rpc" 7 | "github.com/factset/go-drill/internal/rpc/proto/exec/user" 8 | ) 9 | 10 | func getLikeFilter(pattern string, escape *string) *user.LikeFilter { 11 | return &user.LikeFilter{ 12 | Pattern: &pattern, 13 | Escape: escape, 14 | } 15 | } 16 | 17 | // GetMetadata returns a structure consisting of all of the Drill Server metadata 18 | // including what sql keywords are supported, escape characters, max lengths etc. 19 | func (d *Client) GetMetadata() (*user.ServerMeta, error) { 20 | resp := &user.GetServerMetaResp{} 21 | if err := d.makeReqGetResp(rpc.RpcMode_REQUEST, user.RpcType_GET_SERVER_META, &user.GetServerMetaReq{}, resp); err != nil { 22 | return nil, err 23 | } 24 | 25 | if resp.GetStatus() != user.RequestStatus_OK { 26 | return nil, fmt.Errorf("get_meta error: %s", resp.Error.GetMessage()) 27 | } 28 | 29 | return resp.ServerMeta, nil 30 | } 31 | 32 | // GetCatalogs uses the given pattern to search and return the catalogs available on 33 | // the server. For drill, this is always only "DRILL". The syntax of the pattern is 34 | // equivalent to using a LIKE sql expression. If there is no need to escape characters 35 | // in the search filter, pass nil for the second argument, otherwise it should point 36 | // to a string consisting of the characters used for escaping in the pattern. 
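//
// A minimal illustrative call (assuming d is an already connected *Client; the
// "%" pattern is the SQL LIKE wildcard and error handling is elided):
//
//	catalogs, err := d.GetCatalogs("%", nil)
//	if err != nil {
//		// handle the error
//	}
//	// catalogs is a []*user.CatalogMetadata; for Drill it will only ever
//	// contain the single "DRILL" catalog.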
37 | func (d *Client) GetCatalogs(pattern string, escape *string) ([]*user.CatalogMetadata, error) { 38 | req := &user.GetCatalogsReq{ 39 | CatalogNameFilter: getLikeFilter(pattern, escape), 40 | } 41 | 42 | resp := &user.GetCatalogsResp{} 43 | if err := d.makeReqGetResp(rpc.RpcMode_REQUEST, user.RpcType_GET_CATALOGS, req, resp); err != nil { 44 | return nil, err 45 | } 46 | 47 | if resp.GetStatus() != user.RequestStatus_OK { 48 | return nil, fmt.Errorf("get_catalogs error: %s", resp.Error.GetMessage()) 49 | } 50 | 51 | return resp.Catalogs, nil 52 | } 53 | 54 | // GetSchemas returns all the schemas which fit the filter patterns provided. 55 | // 56 | // The syntax for the filter pattern is the same as for GetCatalogs. 57 | func (d *Client) GetSchemas(catalogPattern, schemaPattern string, escape *string) ([]*user.SchemaMetadata, error) { 58 | req := &user.GetSchemasReq{ 59 | CatalogNameFilter: getLikeFilter(catalogPattern, escape), 60 | SchemaNameFilter: getLikeFilter(schemaPattern, escape), 61 | } 62 | 63 | resp := &user.GetSchemasResp{} 64 | if err := d.makeReqGetResp(rpc.RpcMode_REQUEST, user.RpcType_GET_SCHEMAS, req, resp); err != nil { 65 | return nil, err 66 | } 67 | 68 | if resp.GetStatus() != user.RequestStatus_OK { 69 | return nil, fmt.Errorf("get_schemas error: %s", resp.Error.GetMessage()) 70 | } 71 | 72 | return resp.GetSchemas(), nil 73 | } 74 | 75 | // GetTables returns the metadata for all the tables which fit the filter patterns 76 | // provided and are of the table types passed in. 77 | // 78 | // The syntax for the filter pattern is the same as for GetCatalogs. 79 | func (d *Client) GetTables(catalogPattern, schemaPattern, tablePattern string, escape *string, tableTypes ...string) ([]*user.TableMetadata, error) { 80 | req := &user.GetTablesReq{ 81 | CatalogNameFilter: getLikeFilter(catalogPattern, escape), 82 | SchemaNameFilter: getLikeFilter(schemaPattern, escape), 83 | TableNameFilter: getLikeFilter(tablePattern, escape), 84 | TableTypeFilter: tableTypes, 85 | } 86 | 87 | resp := &user.GetTablesResp{} 88 | if err := d.makeReqGetResp(rpc.RpcMode_REQUEST, user.RpcType_GET_TABLES, req, resp); err != nil { 89 | return nil, err 90 | } 91 | 92 | if resp.GetStatus() != user.RequestStatus_OK { 93 | return nil, fmt.Errorf("get_tables error: %s", resp.Error.GetMessage()) 94 | } 95 | 96 | return resp.GetTables(), nil 97 | } 98 | 99 | // GetColumns returns the metadata for all the columns from all the tables which fit the provided 100 | // filter patterns. 101 | // 102 | // The syntax for the filter pattern is the same as for GetCatalogs. 
103 | func (d *Client) GetColumns(catalogPattern, schemaPattern, tablePattern, columnPattern string, escape *string) ([]*user.ColumnMetadata, error) { 104 | req := &user.GetColumnsReq{ 105 | CatalogNameFilter: getLikeFilter(catalogPattern, escape), 106 | SchemaNameFilter: getLikeFilter(schemaPattern, escape), 107 | TableNameFilter: getLikeFilter(tablePattern, escape), 108 | ColumnNameFilter: getLikeFilter(columnPattern, escape), 109 | } 110 | 111 | resp := &user.GetColumnsResp{} 112 | if err := d.makeReqGetResp(rpc.RpcMode_REQUEST, user.RpcType_GET_COLUMNS, req, resp); err != nil { 113 | return nil, err 114 | } 115 | 116 | if resp.GetStatus() != user.RequestStatus_OK { 117 | return nil, fmt.Errorf("get_columns error: %s", resp.Error.GetMessage()) 118 | } 119 | 120 | return resp.GetColumns(), nil 121 | } 122 | -------------------------------------------------------------------------------- /auth.go: -------------------------------------------------------------------------------- 1 | package drill 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "math" 7 | 8 | "github.com/factset/go-drill/internal/rpc/proto/exec/rpc" 9 | "github.com/factset/go-drill/internal/rpc/proto/exec/shared" 10 | "github.com/factset/go-drill/internal/rpc/proto/exec/user" 11 | "github.com/factset/go-drill/sasl" 12 | "github.com/jcmturner/gokrb5/v8/gssapi" 13 | "google.golang.org/protobuf/proto" 14 | ) 15 | 16 | func (d *Client) doHandshake() error { 17 | u2b := user.UserToBitHandshake{ 18 | Channel: shared.RpcChannel_USER.Enum(), 19 | RpcVersion: proto.Int32(drillRPCVersion), 20 | SupportListening: proto.Bool(true), 21 | SupportTimeout: proto.Bool(true), 22 | SaslSupport: user.SaslSupport_SASL_PRIVACY.Enum(), 23 | SupportComplexTypes: proto.Bool(d.Opts.SupportComplexTypes), 24 | ClientInfos: &user.RpcEndpointInfos{ 25 | Name: proto.String(clientName), 26 | Version: proto.String(drillVersion), 27 | Application: &d.Opts.ApplicationName, 28 | MajorVersion: proto.Uint32(drillMajorVersion), 29 | MinorVersion: proto.Uint32(drillMinorVersion), 30 | PatchVersion: proto.Uint32(drillPatchVersion), 31 | }, 32 | Credentials: &shared.UserCredentials{ 33 | UserName: &d.Opts.User, 34 | }, 35 | Properties: &user.UserProperties{ 36 | Properties: []*user.Property{ 37 | {Key: proto.String("schema"), Value: &d.Opts.Schema}, 38 | {Key: proto.String("userName"), Value: &d.Opts.User}, 39 | }, 40 | }, 41 | } 42 | 43 | if d.Opts.Passwd != "" && (d.Opts.Auth == "PLAIN" || d.Opts.Auth == "") { 44 | u2b.Properties.Properties = append(u2b.Properties.Properties, &user.Property{Key: proto.String("password"), Value: &d.Opts.Passwd}) 45 | } 46 | 47 | _, err := d.dataEncoder.Write(d.conn, rpc.RpcMode_REQUEST, user.RpcType_HANDSHAKE, d.nextCoordID(), &u2b) 48 | if err != nil { 49 | return err 50 | } 51 | 52 | d.serverInfo = &user.BitToUserHandshake{} 53 | _, err = d.dataEncoder.ReadMsg(d.conn, d.serverInfo) 54 | if err != nil { 55 | return err 56 | } 57 | 58 | if d.Opts.SaslEncrypt != d.serverInfo.GetEncrypted() { 59 | return errors.New("invalid security options") 60 | } 61 | 62 | switch d.serverInfo.GetStatus() { 63 | case user.HandshakeStatus_SUCCESS: 64 | if (len(d.Opts.Auth) > 0 && d.Opts.Auth != "plain") || d.Opts.SaslEncrypt { 65 | return errors.New("client wanted auth, but server didn't require it") 66 | } 67 | case user.HandshakeStatus_RPC_VERSION_MISMATCH: 68 | return fmt.Errorf("invalid rpc version, expected: %d, actual: %d", drillRPCVersion, d.serverInfo.GetRpcVersion()) 69 | case user.HandshakeStatus_AUTH_FAILED: 70 | return 
errors.New("authentication failure") 71 | case user.HandshakeStatus_UNKNOWN_FAILURE: 72 | return errors.New("unknown handshake failure") 73 | case user.HandshakeStatus_AUTH_REQUIRED: 74 | return d.handleAuth() 75 | } 76 | 77 | return nil 78 | } 79 | 80 | var createSasl = sasl.NewSaslWrapper 81 | 82 | func (d *Client) handleAuth() error { 83 | if ((len(d.Opts.Auth) > 0 && d.Opts.Auth != "plain") || d.Opts.SaslEncrypt) && !d.serverInfo.GetEncrypted() { 84 | return errors.New("client wants encryption, server doesn't support encryption") 85 | } 86 | 87 | host := d.Opts.ServiceHost 88 | if d.Opts.ServiceHost == "_HOST" || d.Opts.ServiceHost == "" { 89 | host = d.endpoint.GetAddress() 90 | } 91 | 92 | wrapper, err := createSasl(d.Opts.User, d.Opts.ServiceName+"/"+host, sasl.SecurityProps{ 93 | MinSsf: 56, 94 | MaxSsf: math.MaxUint32, 95 | MaxBufSize: d.serverInfo.GetMaxWrappedSize(), 96 | UseEncryption: d.serverInfo.GetEncrypted(), 97 | }) 98 | 99 | if err != nil { 100 | return err 101 | } 102 | 103 | token, err := wrapper.InitAuthPayload() 104 | if err != nil { 105 | return err 106 | } 107 | 108 | d.dataEncoder.Write(d.conn, rpc.RpcMode_REQUEST, user.RpcType_SASL_MESSAGE, d.nextCoordID(), &shared.SaslMessage{ 109 | Mechanism: &d.Opts.Auth, 110 | Data: token, 111 | Status: shared.SaslStatus_SASL_START.Enum(), 112 | }) 113 | 114 | saslResp := &shared.SaslMessage{} 115 | _, err = d.dataEncoder.ReadMsg(d.conn, saslResp) 116 | if err != nil { 117 | return err 118 | } 119 | 120 | for saslResp.GetStatus() == shared.SaslStatus_SASL_IN_PROGRESS { 121 | token, st := wrapper.Step(saslResp.GetData()) 122 | if st.Code != gssapi.StatusContinueNeeded && st.Code != gssapi.StatusComplete { 123 | return errors.New(st.Error()) 124 | } 125 | 126 | encodeStatus := shared.SaslStatus_SASL_IN_PROGRESS.Enum() 127 | if st.Code == gssapi.StatusComplete { 128 | encodeStatus = shared.SaslStatus_SASL_SUCCESS.Enum() 129 | } 130 | 131 | d.dataEncoder.Write(d.conn, rpc.RpcMode_REQUEST, user.RpcType_SASL_MESSAGE, d.nextCoordID(), &shared.SaslMessage{ 132 | Data: token, 133 | Status: encodeStatus, 134 | }) 135 | 136 | _, err = d.dataEncoder.ReadMsg(d.conn, saslResp) 137 | if err != nil { 138 | return err 139 | } 140 | } 141 | 142 | d.conn = wrapper.GetWrappedConn(d.conn) 143 | 144 | return nil 145 | } 146 | -------------------------------------------------------------------------------- /.github/workflows/smoketest.yml: -------------------------------------------------------------------------------- 1 | name: SmokeTest 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | smoketest: 11 | name: ZK Drill Test 12 | runs-on: ubuntu-latest 13 | services: 14 | zoo1: 15 | image: zookeeper 16 | ports: 17 | - 2181:2181 18 | env: 19 | ZOO_MY_ID: 1 20 | ZOO_SERVERS: server.1=0.0.0.0:2888:3888;2181 server.2=zoo2:2888:3888;2181 server.3=zoo3:2888:3888;2181 21 | 22 | zoo2: 23 | image: zookeeper 24 | ports: 25 | - 2182:2181 26 | env: 27 | ZOO_MY_ID: 2 28 | ZOO_SERVERS: server.1=0.0.0.0:2888:3888;2181 server.2=zoo2:2888:3888;2181 server.3=zoo3:2888:3888;2181 29 | 30 | zoo3: 31 | image: zookeeper 32 | ports: 33 | - 2183:2181 34 | env: 35 | ZOO_MY_ID: 3 36 | ZOO_SERVERS: server.1=0.0.0.0:2888:3888;2181 server.2=zoo2:2888:3888;2181 server.3=zoo3:2888:3888;2181 37 | strategy: 38 | matrix: 39 | kerberos: [true, false] 40 | steps: 41 | - name: Checkout code into the directory 42 | uses: actions/checkout@v2 43 | 44 | - name: Setup Kerberos 45 | if: ${{ matrix.kerberos }} 46 | env: 47 | REALM: 
"EXAMPLE.COM" 48 | PASSWORD: "password1234" 49 | run: | 50 | sudo apt-get update 51 | sudo mkdir -p /opt/var/log 52 | sudo tee /etc/krb5.conf << EOF 53 | [libdefaults] 54 | default_realm = $REALM 55 | dns_lookup_realm = false 56 | dns_lookup_kdc = false 57 | [realms] 58 | $REALM = { 59 | kdc = localhost 60 | admin_server = localhost 61 | } 62 | [logging] 63 | default = FILE:/opt/var/log/krb5libs.log 64 | kdc = FILE:/opt/var/log/krb5kdc.log 65 | admin_server = FILE:/opt/var/log/kadmind.log 66 | [domain_realm] 67 | .localhost = $REALM 68 | localhost = $REALM 69 | EOF 70 | 71 | sudo mkdir /etc/krb5kdc 72 | printf '*/*@%s\t*' "$REALM" | sudo tee /etc/krb5kdc/kadm5.acl 73 | 74 | sudo apt-get install -y krb5-user krb5-kdc krb5-admin-server 75 | printf "$PASSWORD\n$PASSWORD" | sudo kdb5_util -r "$REALM" create -s 76 | sudo kadmin.local -q "addprinc -randkey drill/$(hostname -A | cut -d' ' -f1)@$REALM" 77 | sudo kadmin.local -q "addprinc -randkey drill/localhost@$REALM" 78 | sudo kadmin.local -q "ktadd -k /tmp/drill.keytab drill/$(hostname -A | cut -d' ' -f1)@$REALM" 79 | sudo kadmin.local -q "ktadd -k /tmp/drill.keytab drill/localhost@$REALM" 80 | sudo chmod +rx /tmp/drill.keytab 81 | 82 | sudo service krb5-kdc restart 83 | sudo service krb5-admin-server restart 84 | kinit -kt /tmp/drill.keytab "drill/localhost@$REALM" 85 | 86 | - name: Setup and Start Drill 87 | env: 88 | KERBEROS: "${{ matrix.kerberos }}" 89 | run: | 90 | sudo chmod +rwx -R /opt 91 | mkdir /opt/drill 92 | sh -c 'wget -O - http://apache.mirrors.hoobly.com/drill/drill-1.19.0/apache-drill-1.19.0.tar.gz | tar -xz' 93 | mv apache-drill-1.19.0/* /opt/drill/ 94 | cp smoketest/storage-plugins-override.conf /opt/drill/conf/ 95 | if [ $KERBEROS = "true" ]; then 96 | cp smoketest/drill-override-kerberos.conf /opt/drill/conf/drill-override.conf 97 | else 98 | cp smoketest/drill-override.conf /opt/drill/conf/drill-override.conf 99 | fi 100 | sudo adduser drill --gecos "First Last,RoomNumber,WorkPhone,HomePhone" --disabled-password 101 | echo "drill:password" | sudo chpasswd 102 | sudo chown -R drill:drill /opt/drill 103 | sudo -u drill /opt/drill/bin/drillbit.sh --config /opt/drill/conf start 104 | 105 | - name: Set up Go 106 | uses: actions/setup-go@v2 107 | with: 108 | go-version: ^1.13 109 | id: go 110 | 111 | - name: Get dependencies 112 | run: go mod download all 113 | 114 | - name: Run smoke test 115 | if: ${{ !matrix.kerberos }} 116 | run: | 117 | go install github.com/ory/go-acc@latest 118 | go-acc -o coverage.out ./... -- -race -v -tags smoke -run ^Example$ 119 | 120 | - name: Run Smoke Test Kerberos 121 | if: ${{ matrix.kerberos }} 122 | run: | 123 | go install github.com/ory/go-acc@latest 124 | go-acc -o coverage.out ./... -- -race -v -tags smoke,kerberos -run ^Example_kerberos$ 125 | 126 | - name: Codecov 127 | uses: codecov/codecov-action@v1.0.12 128 | with: 129 | file: coverage.out 130 | flags: smoketest 131 | 132 | - name: print drill log 133 | if: ${{ failure() }} 134 | run: cat /opt/drill/log/drillbit.log /opt/drill/log/drillbit.out 135 | 136 | - name: Stop drillbit 137 | run: sudo /opt/drill/bin/drillbit.sh graceful_stop 138 | -------------------------------------------------------------------------------- /driver/conn.go: -------------------------------------------------------------------------------- 1 | // Package driver provides a driver compatible with the golang database/sql/driver 2 | // standard package. 
3 | // 4 | // Basic example 5 | // 6 | // import ( 7 | // "strings" 8 | // "database/sql" 9 | // 10 | // _ "github.com/factset/go-drill/driver" 11 | // ) 12 | // 13 | // func main() { 14 | // props := []string{ 15 | // "zk=zookeeper1,zookeeper2,zookeeper3", 16 | // "auth=kerberos", 17 | // "service=", 18 | // "cluster=", 19 | // } 20 | // db, err := sql.Open("drill", strings.Join(props, ";")) 21 | // ... 22 | // } 23 | // 24 | // Connection String 25 | // 26 | // zk=node1,node2,node3/non/default/path 27 | // Specify the zookeeper nodes to use for discovering endpoints. Will 28 | // default to using port 2181 if not specified with the address for each 29 | // zookeeper node. Will default to using /drill/drillbits unless a non-default 30 | // zookeeper path is specified as shown. 31 | // 32 | // auth= 33 | // If using SASL authentication, this is used to specify the authentication 34 | // mechanism. Currently only supports "kerberos", which will use GSSAPI 35 | // authentication. If using user/password authentication, this is ignored. 36 | // 37 | // schema= 38 | // Default schema/context to run queries in. 39 | // 40 | // service= 41 | // If using Kerberos authentication, this should be the Kerberos service name 42 | // for the ticket utilized by the server for authentication. 43 | // 44 | // encrypt= 45 | // Set to true if using SASL encryption for communication. 46 | // 47 | // user= 48 | // Username to authenticate as, either the principal of the Kerberos ticket used for auth 49 | // or the username to authenticate with the provided password. 50 | // 51 | // pass= 52 | // If using user/pass authentication instead of Kerberos, this is how you provide 53 | // the password to use. 54 | // 55 | // cluster= 56 | // If using a non-default cluster name for Drill, specify this so that the 57 | // zookeeper cluster can be found properly for the drillbit endpoints. 58 | // 59 | // host= 60 | // Hostname to connect to for a direct connection. 61 | // 62 | // port= 63 | // Port number to use if not using the default 31010 port for a direct connection. 64 | // 65 | // heartbeat= 66 | // By default the driver will use a 15 second heartbeat frequency to keep the 67 | // connection alive. If a different frequency is desired it can be specified 68 | // with this parameter. A frequency of 0 disables the heartbeat.
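//
// Example connection strings (illustrative values only; the hosts, service name,
// credentials, and schema below are placeholders, not defaults):
//
//	zk=zk1,zk2,zk3/drill/drillbits;auth=kerberos;service=drill;encrypt=true
//	host=drillbit.example.com;port=31010;user=driller;pass=secret;schema=dfs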
69 | package driver 70 | 71 | import ( 72 | "context" 73 | "database/sql" 74 | "database/sql/driver" 75 | "errors" 76 | "io" 77 | 78 | "github.com/factset/go-drill" 79 | ) 80 | 81 | var errNoPrepSupport = errors.New("drill does not support parameters in prepared statements") 82 | 83 | func init() { 84 | sql.Register("drill", drillDriver{}) 85 | } 86 | 87 | type drillDriver struct{} 88 | 89 | func (d drillDriver) Open(dsn string) (driver.Conn, error) { 90 | cn, err := d.OpenConnector(dsn) 91 | if err != nil { 92 | return nil, err 93 | } 94 | 95 | return cn.Connect(context.Background()) 96 | } 97 | 98 | func (d drillDriver) OpenConnector(name string) (driver.Connector, error) { 99 | return parseConnectStr(name) 100 | } 101 | 102 | func processWithCtx(ctx context.Context, handle drill.DataHandler, f func(h drill.DataHandler) error) error { 103 | done := make(chan struct{}) 104 | defer close(done) 105 | 106 | go func() { 107 | select { 108 | case <-ctx.Done(): 109 | handle.Cancel() 110 | case <-done: 111 | } 112 | }() 113 | 114 | return f(handle) 115 | } 116 | 117 | type conn struct { 118 | drill.Conn 119 | } 120 | 121 | func (c *conn) Begin() (driver.Tx, error) { 122 | return nil, errors.New("not implemented") 123 | } 124 | 125 | func (c *conn) Prepare(query string) (driver.Stmt, error) { 126 | stmt, err := c.Conn.PrepareQuery(query) 127 | if err != nil { 128 | return nil, err 129 | } 130 | return &prepared{stmt: stmt, client: c.Conn}, nil 131 | } 132 | 133 | func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 134 | if len(args) > 0 { 135 | return nil, errNoPrepSupport 136 | } 137 | 138 | handle, err := c.Conn.SubmitQuery(drill.TypeSQL, query) 139 | if err != nil { 140 | return nil, driver.ErrBadConn 141 | } 142 | 143 | var affectedRows int64 = 0 144 | err = processWithCtx(ctx, handle, func(h drill.DataHandler) error { 145 | var err error 146 | var batch drill.RowBatch 147 | for batch, err = h.Next(); err == nil; batch, err = h.Next() { 148 | affectedRows += int64(batch.AffectedRows()) 149 | } 150 | 151 | return err 152 | }) 153 | 154 | if err == io.EOF { 155 | err = nil 156 | } 157 | 158 | return result{rowsAffected: affectedRows, rowsError: err}, nil 159 | } 160 | 161 | func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 162 | if len(args) > 0 { 163 | return nil, errNoPrepSupport 164 | } 165 | 166 | handle, err := c.Conn.SubmitQuery(drill.TypeSQL, query) 167 | if err != nil { 168 | return nil, driver.ErrBadConn 169 | } 170 | 171 | r := &rows{handle: handle} 172 | return r, processWithCtx(ctx, handle, func(h drill.DataHandler) error { 173 | _, err := h.Next() 174 | return err 175 | }) 176 | } 177 | -------------------------------------------------------------------------------- /internal/data/arrow.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | //go:generate go run ../cmd/tmpl -data numeric.tmpldata arrow_numeric.gen_test.go.tmpl 4 | 5 | import ( 6 | "reflect" 7 | 8 | "github.com/apache/arrow/go/v7/arrow" 9 | "github.com/apache/arrow/go/v7/arrow/array" 10 | "github.com/apache/arrow/go/v7/arrow/bitutil" 11 | "github.com/apache/arrow/go/v7/arrow/decimal128" 12 | "github.com/apache/arrow/go/v7/arrow/memory" 13 | "github.com/factset/go-drill/internal/rpc/proto/common" 14 | "github.com/factset/go-drill/internal/rpc/proto/exec/shared" 15 | ) 16 | 17 | // ArrowTypeToReflect will get the reflection type from the arrow 
datatype. 18 | // 19 | // TODO: handle decimal types properly 20 | func ArrowTypeToReflect(typ arrow.DataType) reflect.Type { 21 | switch typ.ID() { 22 | case arrow.BOOL: 23 | return reflect.TypeOf(true) 24 | case arrow.UINT8: 25 | return reflect.TypeOf(uint8(0)) 26 | case arrow.INT8: 27 | return reflect.TypeOf(int8(0)) 28 | case arrow.UINT16: 29 | return reflect.TypeOf(uint16(0)) 30 | case arrow.INT16: 31 | return reflect.TypeOf(int16(0)) 32 | case arrow.UINT32: 33 | return reflect.TypeOf(uint32(0)) 34 | case arrow.INT32: 35 | return reflect.TypeOf(int32(0)) 36 | case arrow.UINT64: 37 | return reflect.TypeOf(uint64(0)) 38 | case arrow.INT64: 39 | return reflect.TypeOf(int64(0)) 40 | case arrow.FLOAT32: 41 | return reflect.TypeOf(float32(0)) 42 | case arrow.FLOAT64: 43 | return reflect.TypeOf(float64(0)) 44 | case arrow.STRING: 45 | return reflect.TypeOf("") 46 | case arrow.BINARY, arrow.FIXED_SIZE_BINARY: 47 | return reflect.TypeOf([]byte{}) 48 | case arrow.TIMESTAMP: 49 | return reflect.TypeOf(arrow.Timestamp(0)) 50 | case arrow.DATE64: 51 | return reflect.TypeOf(arrow.Date64(0)) 52 | case arrow.TIME32: 53 | return reflect.TypeOf(arrow.Time32(0)) 54 | case arrow.INTERVAL_MONTHS: 55 | return reflect.TypeOf(arrow.MonthInterval(0)) 56 | case arrow.INTERVAL_DAY_TIME: 57 | return reflect.TypeOf(arrow.DayTimeInterval{}) 58 | case arrow.INTERVAL: 59 | switch typ.(type) { 60 | case *arrow.DayTimeIntervalType: 61 | return reflect.TypeOf(arrow.DayTimeInterval{}) 62 | case *arrow.MonthIntervalType: 63 | return reflect.TypeOf(arrow.MonthInterval(0)) 64 | } 65 | case arrow.DECIMAL: 66 | return reflect.TypeOf(decimal128.FromI64(0)) 67 | } 68 | return nil 69 | } 70 | 71 | // TypeToArrowType converts the specified type enum to an arrow Data Type 72 | // 73 | // TODO: handle decimal types 74 | func TypeToArrowType(typ common.MinorType) arrow.DataType { 75 | switch typ { 76 | case common.MinorType_BIGINT: 77 | return arrow.PrimitiveTypes.Int64 78 | case common.MinorType_INT: 79 | return arrow.PrimitiveTypes.Int32 80 | case common.MinorType_SMALLINT: 81 | return arrow.PrimitiveTypes.Int16 82 | case common.MinorType_TINYINT: 83 | return arrow.PrimitiveTypes.Int8 84 | case common.MinorType_DATE: 85 | return arrow.FixedWidthTypes.Date64 86 | case common.MinorType_TIME: 87 | return arrow.FixedWidthTypes.Time32ms 88 | case common.MinorType_BIT: 89 | return arrow.FixedWidthTypes.Boolean 90 | case common.MinorType_FLOAT4: 91 | return arrow.PrimitiveTypes.Float32 92 | case common.MinorType_FLOAT8: 93 | return arrow.PrimitiveTypes.Float64 94 | case common.MinorType_UINT1: 95 | return arrow.PrimitiveTypes.Uint8 96 | case common.MinorType_UINT2: 97 | return arrow.PrimitiveTypes.Uint16 98 | case common.MinorType_UINT4: 99 | return arrow.PrimitiveTypes.Uint32 100 | case common.MinorType_UINT8: 101 | return arrow.PrimitiveTypes.Uint64 102 | case common.MinorType_INTERVALDAY: 103 | return arrow.FixedWidthTypes.DayTimeInterval 104 | case common.MinorType_INTERVALYEAR: 105 | return arrow.FixedWidthTypes.MonthInterval 106 | case common.MinorType_VARCHAR: 107 | return arrow.BinaryTypes.String 108 | case common.MinorType_VARBINARY: 109 | return arrow.BinaryTypes.Binary 110 | case common.MinorType_TIMESTAMP: 111 | return arrow.FixedWidthTypes.Timestamp_ms 112 | } 113 | return arrow.Null 114 | } 115 | 116 | func nullBytesToBits(bytemap []byte) []byte { 117 | ret := make([]byte, bitutil.CeilByte(len(bytemap))) 118 | for idx, b := range bytemap { 119 | if b != 0 { 120 | bitutil.SetBit(ret, idx) 121 | } 122 | } 123 | return ret 
124 | } 125 | 126 | // NewArrowArray constructs an arrow.Interface array from the given raw data and serialized 127 | // metadata as a zero-copy array. 128 | // 129 | // TODO: Handle decimal types properly 130 | func NewArrowArray(rawData []byte, meta *shared.SerializedField) (ret array.Interface) { 131 | arrowType := TypeToArrowType(meta.GetMajorType().GetMinorType()) 132 | if arrowType == arrow.Null { 133 | return array.NewNull(int(meta.GetValueCount())) 134 | } 135 | 136 | fieldMeta := meta 137 | remaining := rawData 138 | buffers := make([]*memory.Buffer, 1, 2) 139 | nullCount := array.UnknownNullCount 140 | if meta.GetMajorType().GetMode() == common.DataMode_OPTIONAL { 141 | buffers[0] = memory.NewBufferBytes(nullBytesToBits(rawData[:meta.GetValueCount()])) 142 | remaining = rawData[meta.GetValueCount():] 143 | fieldMeta = meta.Child[1] 144 | } else { 145 | buffers[0] = nil 146 | nullCount = 0 147 | } 148 | 149 | if len(fieldMeta.Child) > 0 && fieldMeta.Child[0].NamePart.GetName() == "$offsets$" { 150 | buffers = append(buffers, memory.NewBufferBytes(remaining[:fieldMeta.Child[0].GetBufferLength()])) 151 | remaining = remaining[fieldMeta.Child[0].GetBufferLength():] 152 | } 153 | 154 | buffers = append(buffers, memory.NewBufferBytes(remaining)) 155 | 156 | data := array.NewData(arrowType, int(meta.GetValueCount()), buffers, nil, nullCount, 0) 157 | switch arrowType.(type) { 158 | case *arrow.DayTimeIntervalType, *arrow.MonthIntervalType: 159 | ret = array.NewIntervalData(data) 160 | default: 161 | ret = array.MakeFromData(data) 162 | } 163 | data.Release() 164 | return 165 | } 166 | -------------------------------------------------------------------------------- /driver/connector_test.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "testing" 7 | "time" 8 | 9 | "github.com/factset/go-drill" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/mock" 12 | ) 13 | 14 | type mockDrillClient struct { 15 | mock.Mock 16 | } 17 | 18 | func (m *mockDrillClient) NewConnection(ctx context.Context) (drill.Conn, error) { 19 | args := m.Called(ctx) 20 | return args.Get(0).(drill.Conn), args.Error(1) 21 | } 22 | 23 | func (m *mockDrillClient) GetEndpoint() drill.Drillbit { return nil } 24 | func (m *mockDrillClient) Connect(context.Context) error { return nil } 25 | func (m *mockDrillClient) ConnectEndpoint(context.Context, drill.Drillbit) error { return nil } 26 | func (m *mockDrillClient) ConnectWithZK(context.Context, ...string) error { return nil } 27 | func (m *mockDrillClient) Ping(context.Context) error { return nil } 28 | func (m *mockDrillClient) Close() error { return nil } 29 | func (m *mockDrillClient) ExecuteStmt(p drill.PreparedHandle) (drill.DataHandler, error) { 30 | args := m.Called(p) 31 | return args.Get(0).(drill.DataHandler), args.Error(1) 32 | } 33 | func (m *mockDrillClient) SubmitQuery(t drill.QueryType, query string) (drill.DataHandler, error) { 34 | args := m.Called(t, query) 35 | return args.Get(0).(drill.DataHandler), args.Error(1) 36 | } 37 | func (m *mockDrillClient) PrepareQuery(query string) (drill.PreparedHandle, error) { 38 | args := m.Called(query) 39 | return args.Get(0).(drill.PreparedHandle), args.Error(1) 40 | } 41 | 42 | func TestParseConnectStrZKDirect(t *testing.T) { 43 | tests := []struct { 44 | name string 45 | testStr string 46 | expected []string 47 | host string 48 | port int32 49 | err error 50 | }{ 51 | {"simple zk", 
"zk=node1,node2,node3", []string{"node1", "node2", "node3"}, "", 0, nil}, 52 | {"simple direct", "host=localhost;port=8080", nil, "localhost", 8080, nil}, 53 | {"invalid port", "host=localhost;port=foobar", nil, "", 0, errors.New("invalid port")}, 54 | {"default port", "host=localhost", nil, "localhost", 31010, nil}, 55 | } 56 | 57 | for _, tt := range tests { 58 | t.Run(tt.name, func(t *testing.T) { 59 | c, err := parseConnectStr(tt.testStr) 60 | if tt.err != nil { 61 | assert.Error(t, err) 62 | return 63 | } 64 | 65 | assert.NoError(t, err) 66 | assert.Equal(t, tt.expected, c.(*connector).base.(*drill.Client).ZkNodes) 67 | 68 | endpoint := c.(*connector).base.(*drill.Client).GetEndpoint() 69 | if len(tt.expected) > 0 { 70 | assert.Nil(t, endpoint) 71 | return 72 | } 73 | 74 | assert.Equal(t, tt.host, endpoint.GetAddress()) 75 | assert.Equal(t, tt.port, endpoint.GetUserPort()) 76 | }) 77 | } 78 | } 79 | 80 | func TestParseConnectStr(t *testing.T) { 81 | durtest := new(time.Duration) 82 | *durtest = 5 * time.Second 83 | 84 | tests := []struct { 85 | name string 86 | testStr string 87 | expected drill.Options 88 | }{ 89 | {"auth", "auth=kerberos", drill.Options{Auth: "kerberos"}}, 90 | {"schema", "schema=foobar", drill.Options{Schema: "foobar"}}, 91 | {"service", "service=nidrill", drill.Options{ServiceName: "nidrill"}}, 92 | {"encrypt true", "encrypt=true", drill.Options{SaslEncrypt: true}}, 93 | {"encrypt false", "encrypt=false", drill.Options{SaslEncrypt: false}}, 94 | {"user", "user=driller", drill.Options{User: "driller"}}, 95 | {"cluster", "cluster=supercluster", drill.Options{ClusterName: "supercluster"}}, 96 | {"heartbeat", "heartbeat=5", drill.Options{HeartbeatFreq: durtest}}, 97 | {"multiple opts", "auth=kerberos;user=foobar;encrypt=true", drill.Options{Auth: "kerberos", User: "foobar", SaslEncrypt: true}}, 98 | {"zkpath", "zk=node1,node2,node3/drillbits", drill.Options{ZKPath: "/drillbits"}}, 99 | {"user passwd", "user=driller;pass=12345", drill.Options{User: "driller", Passwd: "12345"}}, 100 | } 101 | for _, tt := range tests { 102 | t.Run(tt.name, func(t *testing.T) { 103 | conn, err := parseConnectStr(tt.testStr) 104 | assert.NoError(t, err) 105 | 106 | assert.EqualValues(t, tt.expected, conn.(*connector).base.(*drill.Client).Opts) 107 | }) 108 | } 109 | } 110 | 111 | func TestParseConnectStrInvalid(t *testing.T) { 112 | tests := []struct { 113 | name string 114 | testStr string 115 | errMsg string 116 | }{ 117 | {"invalid format", "foo", "invalid format for connector string"}, 118 | {"trailing semicolon doesn't work", "auth=bar;", "invalid format for connector string"}, 119 | {"invalid encrypt val", "encrypt=foo", "strconv.ParseBool: parsing \"foo\": invalid syntax"}, 120 | {"invalid heartbeat freq", "heartbeat=foo", "strconv.Atoi: parsing \"foo\": invalid syntax"}, 121 | {"invalid arg", "foo=bar", "invalid argument for connection string: foo"}, 122 | } 123 | 124 | for _, tt := range tests { 125 | t.Run(tt.name, func(t *testing.T) { 126 | _, err := parseConnectStr(tt.testStr) 127 | assert.Error(t, err) 128 | assert.EqualError(t, err, tt.errMsg) 129 | }) 130 | } 131 | } 132 | 133 | func TestConnectorDriver(t *testing.T) { 134 | c := &connector{} 135 | assert.IsType(t, drillDriver{}, c.Driver()) 136 | } 137 | 138 | func TestConnectorConnect(t *testing.T) { 139 | m := new(mockDrillClient) 140 | m.Test(t) 141 | 142 | ctx := context.Background() 143 | m.On("NewConnection", ctx).Return(m, nil) 144 | 145 | c := &connector{base: m} 146 | cn, err := c.Connect(ctx) 147 | 
assert.NoError(t, err) 148 | assert.Same(t, cn.(*conn).Conn, m) 149 | } 150 | 151 | func TestConnectorConnectErr(t *testing.T) { 152 | m := new(mockDrillClient) 153 | m.Test(t) 154 | 155 | ctx := context.Background() 156 | m.On("NewConnection", ctx).Return(m, assert.AnError) 157 | 158 | c := &connector{base: m} 159 | conn, err := c.Connect(ctx) 160 | assert.Nil(t, conn) 161 | assert.Same(t, assert.AnError, err) 162 | } 163 | -------------------------------------------------------------------------------- /internal/data/type_traits_numeric.gen_test.go: -------------------------------------------------------------------------------- 1 | // Code generated by type_traits_numeric.gen_test.go.tmpl. DO NOT EDIT. 2 | 3 | package data_test 4 | 5 | import ( 6 | "reflect" 7 | "testing" 8 | 9 | "github.com/factset/go-drill/internal/data" 10 | ) 11 | 12 | func TestInt64Traits(t *testing.T) { 13 | const N = 10 14 | b1 := data.Int64Traits.CastToBytes([]int64{ 15 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16 | }) 17 | 18 | v1 := data.Int64Traits.CastFromBytes(b1) 19 | for i, v := range v1 { 20 | if got, want := v, int64(i); got != want { 21 | t.Fatalf("invalid value[%d]. got=%v, want=%v", i, got, want) 22 | } 23 | } 24 | 25 | v2 := make([]int64, N) 26 | data.Int64Traits.Copy(v2, v1) 27 | 28 | if !reflect.DeepEqual(v1, v2) { 29 | t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) 30 | } 31 | } 32 | 33 | func TestInt32Traits(t *testing.T) { 34 | const N = 10 35 | b1 := data.Int32Traits.CastToBytes([]int32{ 36 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 37 | }) 38 | 39 | v1 := data.Int32Traits.CastFromBytes(b1) 40 | for i, v := range v1 { 41 | if got, want := v, int32(i); got != want { 42 | t.Fatalf("invalid value[%d]. got=%v, want=%v", i, got, want) 43 | } 44 | } 45 | 46 | v2 := make([]int32, N) 47 | data.Int32Traits.Copy(v2, v1) 48 | 49 | if !reflect.DeepEqual(v1, v2) { 50 | t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) 51 | } 52 | } 53 | 54 | func TestFloat64Traits(t *testing.T) { 55 | const N = 10 56 | b1 := data.Float64Traits.CastToBytes([]float64{ 57 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 58 | }) 59 | 60 | v1 := data.Float64Traits.CastFromBytes(b1) 61 | for i, v := range v1 { 62 | if got, want := v, float64(i); got != want { 63 | t.Fatalf("invalid value[%d]. got=%v, want=%v", i, got, want) 64 | } 65 | } 66 | 67 | v2 := make([]float64, N) 68 | data.Float64Traits.Copy(v2, v1) 69 | 70 | if !reflect.DeepEqual(v1, v2) { 71 | t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) 72 | } 73 | } 74 | 75 | func TestUint64Traits(t *testing.T) { 76 | const N = 10 77 | b1 := data.Uint64Traits.CastToBytes([]uint64{ 78 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 79 | }) 80 | 81 | v1 := data.Uint64Traits.CastFromBytes(b1) 82 | for i, v := range v1 { 83 | if got, want := v, uint64(i); got != want { 84 | t.Fatalf("invalid value[%d]. got=%v, want=%v", i, got, want) 85 | } 86 | } 87 | 88 | v2 := make([]uint64, N) 89 | data.Uint64Traits.Copy(v2, v1) 90 | 91 | if !reflect.DeepEqual(v1, v2) { 92 | t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) 93 | } 94 | } 95 | 96 | func TestUint32Traits(t *testing.T) { 97 | const N = 10 98 | b1 := data.Uint32Traits.CastToBytes([]uint32{ 99 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100 | }) 101 | 102 | v1 := data.Uint32Traits.CastFromBytes(b1) 103 | for i, v := range v1 { 104 | if got, want := v, uint32(i); got != want { 105 | t.Fatalf("invalid value[%d]. 
got=%v, want=%v", i, got, want) 106 | } 107 | } 108 | 109 | v2 := make([]uint32, N) 110 | data.Uint32Traits.Copy(v2, v1) 111 | 112 | if !reflect.DeepEqual(v1, v2) { 113 | t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) 114 | } 115 | } 116 | 117 | func TestFloat32Traits(t *testing.T) { 118 | const N = 10 119 | b1 := data.Float32Traits.CastToBytes([]float32{ 120 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 121 | }) 122 | 123 | v1 := data.Float32Traits.CastFromBytes(b1) 124 | for i, v := range v1 { 125 | if got, want := v, float32(i); got != want { 126 | t.Fatalf("invalid value[%d]. got=%v, want=%v", i, got, want) 127 | } 128 | } 129 | 130 | v2 := make([]float32, N) 131 | data.Float32Traits.Copy(v2, v1) 132 | 133 | if !reflect.DeepEqual(v1, v2) { 134 | t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) 135 | } 136 | } 137 | 138 | func TestInt16Traits(t *testing.T) { 139 | const N = 10 140 | b1 := data.Int16Traits.CastToBytes([]int16{ 141 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 142 | }) 143 | 144 | v1 := data.Int16Traits.CastFromBytes(b1) 145 | for i, v := range v1 { 146 | if got, want := v, int16(i); got != want { 147 | t.Fatalf("invalid value[%d]. got=%v, want=%v", i, got, want) 148 | } 149 | } 150 | 151 | v2 := make([]int16, N) 152 | data.Int16Traits.Copy(v2, v1) 153 | 154 | if !reflect.DeepEqual(v1, v2) { 155 | t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) 156 | } 157 | } 158 | 159 | func TestUint16Traits(t *testing.T) { 160 | const N = 10 161 | b1 := data.Uint16Traits.CastToBytes([]uint16{ 162 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 163 | }) 164 | 165 | v1 := data.Uint16Traits.CastFromBytes(b1) 166 | for i, v := range v1 { 167 | if got, want := v, uint16(i); got != want { 168 | t.Fatalf("invalid value[%d]. got=%v, want=%v", i, got, want) 169 | } 170 | } 171 | 172 | v2 := make([]uint16, N) 173 | data.Uint16Traits.Copy(v2, v1) 174 | 175 | if !reflect.DeepEqual(v1, v2) { 176 | t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) 177 | } 178 | } 179 | 180 | func TestInt8Traits(t *testing.T) { 181 | const N = 10 182 | b1 := data.Int8Traits.CastToBytes([]int8{ 183 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 184 | }) 185 | 186 | v1 := data.Int8Traits.CastFromBytes(b1) 187 | for i, v := range v1 { 188 | if got, want := v, int8(i); got != want { 189 | t.Fatalf("invalid value[%d]. got=%v, want=%v", i, got, want) 190 | } 191 | } 192 | 193 | v2 := make([]int8, N) 194 | data.Int8Traits.Copy(v2, v1) 195 | 196 | if !reflect.DeepEqual(v1, v2) { 197 | t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) 198 | } 199 | } 200 | 201 | func TestUint8Traits(t *testing.T) { 202 | const N = 10 203 | b1 := data.Uint8Traits.CastToBytes([]uint8{ 204 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 205 | }) 206 | 207 | v1 := data.Uint8Traits.CastFromBytes(b1) 208 | for i, v := range v1 { 209 | if got, want := v, uint8(i); got != want { 210 | t.Fatalf("invalid value[%d]. 
got=%v, want=%v", i, got, want) 211 | } 212 | } 213 | 214 | v2 := make([]uint8, N) 215 | data.Uint8Traits.Copy(v2, v1) 216 | 217 | if !reflect.DeepEqual(v1, v2) { 218 | t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) 219 | } 220 | } 221 | -------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | package drill 2 | 3 | import ( 4 | "encoding/binary" 5 | "errors" 6 | "io" 7 | "net" 8 | 9 | "github.com/apache/arrow/go/v7/arrow" 10 | "github.com/factset/go-drill/internal/rpc/proto/exec/rpc" 11 | "github.com/factset/go-drill/internal/rpc/proto/exec/user" 12 | "google.golang.org/protobuf/proto" 13 | ) 14 | 15 | type encoder interface { 16 | WriteRaw(net.Conn, []byte) (int, error) 17 | Write(net.Conn, rpc.RpcMode, user.RpcType, int32, proto.Message) (int, error) 18 | ReadMsg(net.Conn, proto.Message) (*rpc.RpcHeader, error) 19 | ReadRaw(net.Conn) (*rpc.CompleteRpcMessage, error) 20 | } 21 | 22 | type rpcEncoder struct{} 23 | 24 | func (rpcEncoder) WriteRaw(conn net.Conn, b []byte) (int, error) { 25 | return conn.Write(makePrefixedMessage(b)) 26 | } 27 | 28 | func (rpcEncoder) Write(conn net.Conn, mode rpc.RpcMode, typ user.RpcType, coord int32, msg proto.Message) (int, error) { 29 | encoded, err := encodeRPCMessage(mode, typ, coord, msg) 30 | if err != nil { 31 | return 0, err 32 | } 33 | return conn.Write(makePrefixedMessage(encoded)) 34 | } 35 | 36 | func (rpcEncoder) ReadRaw(conn net.Conn) (*rpc.CompleteRpcMessage, error) { 37 | return readPrefixedRaw(conn) 38 | } 39 | 40 | func (rpcEncoder) ReadMsg(conn net.Conn, msg proto.Message) (*rpc.RpcHeader, error) { 41 | return readPrefixedMessage(conn, msg) 42 | } 43 | 44 | var errInvalidResponse = errors.New("invalid response") 45 | 46 | func makePrefixedMessage(data []byte) []byte { 47 | if data == nil { 48 | return nil 49 | } 50 | 51 | buf := make([]byte, binary.MaxVarintLen32) 52 | nbytes := binary.PutUvarint(buf, uint64(len(data))) 53 | return append(buf[:nbytes], data...) 
54 | } 55 | 56 | func readPrefixed(r io.Reader) ([]byte, error) { 57 | vbytes := make([]byte, binary.MaxVarintLen32) 58 | n, err := io.ReadAtLeast(r, vbytes, binary.MaxVarintLen32) 59 | if err == io.EOF { 60 | return nil, io.ErrUnexpectedEOF 61 | } else if err != nil { 62 | return nil, err 63 | } 64 | 65 | respLength, vlength := binary.Uvarint(vbytes) 66 | 67 | // if we got an empty message and read too many bytes we're screwed 68 | // but this shouldn't happen anyways, just in case 69 | if vlength < 1 || vlength+int(respLength) < n { 70 | return nil, errInvalidResponse 71 | } 72 | 73 | respBytes := make([]byte, respLength) 74 | extraLen := copy(respBytes, vbytes[vlength:]) 75 | _, err = io.ReadFull(r, respBytes[extraLen:]) 76 | if err == io.EOF { 77 | return nil, io.ErrUnexpectedEOF 78 | } else if err != nil { 79 | return nil, err 80 | } 81 | 82 | return respBytes, nil 83 | } 84 | 85 | func readPrefixedRaw(r io.Reader) (*rpc.CompleteRpcMessage, error) { 86 | respBytes, err := readPrefixed(r) 87 | if err != nil { 88 | return nil, err 89 | } 90 | 91 | return getRawRPCMessage(respBytes) 92 | } 93 | 94 | func readPrefixedMessage(r io.Reader, msg proto.Message) (*rpc.RpcHeader, error) { 95 | respBytes, err := readPrefixed(r) 96 | if err != nil { 97 | return nil, err 98 | } 99 | 100 | return decodeRPCMessage(respBytes, msg) 101 | } 102 | 103 | func encodeRPCMessage(mode rpc.RpcMode, msgType user.RpcType, coordID int32, msg proto.Message) ([]byte, error) { 104 | data, err := proto.Marshal(msg) 105 | if err != nil { 106 | return nil, err 107 | } 108 | 109 | rpcMsg := &rpc.CompleteRpcMessage{ 110 | Header: &rpc.RpcHeader{ 111 | Mode: &mode, 112 | CoordinationId: &coordID, 113 | RpcType: proto.Int32(int32(msgType)), 114 | }, 115 | ProtobufBody: data, 116 | } 117 | 118 | return proto.Marshal(rpcMsg) 119 | } 120 | 121 | func getRawRPCMessage(data []byte) (*rpc.CompleteRpcMessage, error) { 122 | rpcMsg := &rpc.CompleteRpcMessage{} 123 | if err := proto.Unmarshal(data, rpcMsg); err != nil { 124 | return nil, err 125 | } 126 | 127 | return rpcMsg, nil 128 | } 129 | 130 | func decodeRPCMessage(data []byte, msg proto.Message) (*rpc.RpcHeader, error) { 131 | rpcMsg, err := getRawRPCMessage(data) 132 | if err != nil { 133 | return nil, err 134 | } 135 | 136 | ret := rpcMsg.GetHeader() 137 | return ret, proto.Unmarshal(rpcMsg.ProtobufBody, msg) 138 | } 139 | 140 | type ColumnMeta interface { 141 | GetColumnName() string 142 | GetIsNullable() bool 143 | GetDataType() string 144 | GetCharMaxLength() int32 145 | GetCharOctetLength() int32 146 | GetNumericPrecision() int32 147 | GetNumericPrecisionRadix() int32 148 | GetNumericScale() int32 149 | GetDateTimePrecision() int32 150 | GetIntervalType() string 151 | GetIntervalPrecision() int32 152 | GetColumnSize() int32 153 | GetDefaultValue() string 154 | } 155 | 156 | func arrowDataTypeFromCol(c ColumnMeta) arrow.DataType { 157 | switch c.GetDataType() { 158 | case "BOOLEAN": 159 | return arrow.FixedWidthTypes.Boolean 160 | case "BINARY VARYING": 161 | return arrow.BinaryTypes.Binary 162 | case "CHARACTER VARYING": 163 | return arrow.BinaryTypes.String 164 | case "INTEGER": 165 | return arrow.PrimitiveTypes.Int32 166 | case "BIGINT": 167 | return arrow.PrimitiveTypes.Int64 168 | case "SMALLINT": 169 | return arrow.PrimitiveTypes.Int16 170 | case "TINYINT": 171 | return arrow.PrimitiveTypes.Int8 172 | case "DATE": 173 | return arrow.FixedWidthTypes.Date64 174 | case "TIME": 175 | return arrow.FixedWidthTypes.Time32ms 176 | case "FLOAT": 177 | return 
arrow.PrimitiveTypes.Float32 178 | case "DOUBLE": 179 | return arrow.PrimitiveTypes.Float64 180 | case "TIMESTAMP": 181 | return arrow.FixedWidthTypes.Timestamp_ms 182 | default: 183 | panic("arrow type conversion not found for: " + c.GetDataType()) 184 | } 185 | } 186 | 187 | // ColMetaToArrowField returns an arrow.Field for the column metadata provided, 188 | // panics if not of type BOOLEAN, VARCHAR, VARBINARY, INTEGER, SMALLINT, BIGINT, 189 | // TINYINT, DATE, TIME, TIMESTAMP, FLOAT, or DOUBLE. 190 | // 191 | // TODO: handle decimal types 192 | // 193 | // Adds the following metadata: 194 | // Default Value: key "default" 195 | // 196 | func ColMetaToArrowField(c ColumnMeta) arrow.Field { 197 | return arrow.Field{ 198 | Name: c.GetColumnName(), 199 | Nullable: c.GetIsNullable(), 200 | Metadata: arrow.NewMetadata([]string{"default"}, []string{c.GetDefaultValue()}), 201 | Type: arrowDataTypeFromCol(c), 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /internal/rpc/proto/exec/SchemaDef.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. DO NOT EDIT. 2 | // versions: 3 | // protoc-gen-go v1.25.0 4 | // protoc v3.12.4 5 | // source: SchemaDef.proto 6 | 7 | // 8 | // Licensed to the Apache Software Foundation (ASF) under one 9 | // or more contributor license agreements. See the NOTICE file 10 | // distributed with this work for additional information 11 | // regarding copyright ownership. The ASF licenses this file 12 | // to you under the Apache License, Version 2.0 (the 13 | // "License"); you may not use this file except in compliance 14 | // with the License. You may obtain a copy of the License at 15 | // 16 | // http://www.apache.org/licenses/LICENSE-2.0 17 | // 18 | // Unless required by applicable law or agreed to in writing, software 19 | // distributed under the License is distributed on an "AS IS" BASIS, 20 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 21 | // See the License for the specific language governing permissions and 22 | // limitations under the License. 23 | // 24 | 25 | package exec 26 | 27 | import ( 28 | proto "github.com/golang/protobuf/proto" 29 | _ "github.com/factset/go-drill/internal/rpc/proto/common" 30 | protoreflect "google.golang.org/protobuf/reflect/protoreflect" 31 | protoimpl "google.golang.org/protobuf/runtime/protoimpl" 32 | reflect "reflect" 33 | sync "sync" 34 | ) 35 | 36 | const ( 37 | // Verify that this generated code is sufficiently up-to-date. 38 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) 39 | // Verify that runtime/protoimpl is sufficiently up-to-date. 40 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) 41 | ) 42 | 43 | // This is a compile-time assertion that a sufficiently up-to-date version 44 | // of the legacy proto package is being used. 45 | const _ = proto.ProtoPackageIsVersion4 46 | 47 | type ValueMode int32 48 | 49 | const ( 50 | ValueMode_VALUE_VECTOR ValueMode = 0 51 | ValueMode_RLE ValueMode = 1 52 | ValueMode_DICT ValueMode = 2 53 | ) 54 | 55 | // Enum value maps for ValueMode. 
56 | var ( 57 | ValueMode_name = map[int32]string{ 58 | 0: "VALUE_VECTOR", 59 | 1: "RLE", 60 | 2: "DICT", 61 | } 62 | ValueMode_value = map[string]int32{ 63 | "VALUE_VECTOR": 0, 64 | "RLE": 1, 65 | "DICT": 2, 66 | } 67 | ) 68 | 69 | func (x ValueMode) Enum() *ValueMode { 70 | p := new(ValueMode) 71 | *p = x 72 | return p 73 | } 74 | 75 | func (x ValueMode) String() string { 76 | return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) 77 | } 78 | 79 | func (ValueMode) Descriptor() protoreflect.EnumDescriptor { 80 | return file_SchemaDef_proto_enumTypes[0].Descriptor() 81 | } 82 | 83 | func (ValueMode) Type() protoreflect.EnumType { 84 | return &file_SchemaDef_proto_enumTypes[0] 85 | } 86 | 87 | func (x ValueMode) Number() protoreflect.EnumNumber { 88 | return protoreflect.EnumNumber(x) 89 | } 90 | 91 | // Deprecated: Do not use. 92 | func (x *ValueMode) UnmarshalJSON(b []byte) error { 93 | num, err := protoimpl.X.UnmarshalJSONEnum(x.Descriptor(), b) 94 | if err != nil { 95 | return err 96 | } 97 | *x = ValueMode(num) 98 | return nil 99 | } 100 | 101 | // Deprecated: Use ValueMode.Descriptor instead. 102 | func (ValueMode) EnumDescriptor() ([]byte, []int) { 103 | return file_SchemaDef_proto_rawDescGZIP(), []int{0} 104 | } 105 | 106 | var File_SchemaDef_proto protoreflect.FileDescriptor 107 | 108 | var file_SchemaDef_proto_rawDesc = []byte{ 109 | 0x0a, 0x0f, 0x53, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x44, 0x65, 0x66, 0x2e, 0x70, 0x72, 0x6f, 0x74, 110 | 0x6f, 0x12, 0x04, 0x65, 0x78, 0x65, 0x63, 0x1a, 0x0b, 0x54, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x70, 111 | 0x72, 0x6f, 0x74, 0x6f, 0x2a, 0x30, 0x0a, 0x09, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x4d, 0x6f, 0x64, 112 | 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x56, 0x41, 0x4c, 0x55, 0x45, 0x5f, 0x56, 0x45, 0x43, 0x54, 0x4f, 113 | 0x52, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x52, 0x4c, 0x45, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 114 | 0x44, 0x49, 0x43, 0x54, 0x10, 0x02, 0x42, 0x67, 0x0a, 0x1b, 0x6f, 0x72, 0x67, 0x2e, 0x61, 0x70, 115 | 0x61, 0x63, 0x68, 0x65, 0x2e, 0x64, 0x72, 0x69, 0x6c, 0x6c, 0x2e, 0x65, 0x78, 0x65, 0x63, 0x2e, 116 | 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x42, 0x0f, 0x53, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x44, 0x65, 0x66, 117 | 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x48, 0x01, 0x5a, 0x35, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 118 | 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x7a, 0x65, 0x72, 0x6f, 0x73, 0x68, 0x61, 0x64, 0x65, 0x2f, 0x67, 119 | 0x6f, 0x2d, 0x64, 0x72, 0x69, 0x6c, 0x6c, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 120 | 0x2f, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x65, 0x78, 0x65, 0x63, 121 | } 122 | 123 | var ( 124 | file_SchemaDef_proto_rawDescOnce sync.Once 125 | file_SchemaDef_proto_rawDescData = file_SchemaDef_proto_rawDesc 126 | ) 127 | 128 | func file_SchemaDef_proto_rawDescGZIP() []byte { 129 | file_SchemaDef_proto_rawDescOnce.Do(func() { 130 | file_SchemaDef_proto_rawDescData = protoimpl.X.CompressGZIP(file_SchemaDef_proto_rawDescData) 131 | }) 132 | return file_SchemaDef_proto_rawDescData 133 | } 134 | 135 | var file_SchemaDef_proto_enumTypes = make([]protoimpl.EnumInfo, 1) 136 | var file_SchemaDef_proto_goTypes = []interface{}{ 137 | (ValueMode)(0), // 0: exec.ValueMode 138 | } 139 | var file_SchemaDef_proto_depIdxs = []int32{ 140 | 0, // [0:0] is the sub-list for method output_type 141 | 0, // [0:0] is the sub-list for method input_type 142 | 0, // [0:0] is the sub-list for extension type_name 143 | 0, // [0:0] is the sub-list for extension extendee 144 | 0, // [0:0] is the sub-list for field type_name 145 | } 
}
146 | 147 | func init() { file_SchemaDef_proto_init() } 148 | func file_SchemaDef_proto_init() { 149 | if File_SchemaDef_proto != nil { 150 | return 151 | } 152 | type x struct{} 153 | out := protoimpl.TypeBuilder{ 154 | File: protoimpl.DescBuilder{ 155 | GoPackagePath: reflect.TypeOf(x{}).PkgPath(), 156 | RawDescriptor: file_SchemaDef_proto_rawDesc, 157 | NumEnums: 1, 158 | NumMessages: 0, 159 | NumExtensions: 0, 160 | NumServices: 0, 161 | }, 162 | GoTypes: file_SchemaDef_proto_goTypes, 163 | DependencyIndexes: file_SchemaDef_proto_depIdxs, 164 | EnumInfos: file_SchemaDef_proto_enumTypes, 165 | }.Build() 166 | File_SchemaDef_proto = out.File 167 | file_SchemaDef_proto_rawDesc = nil 168 | file_SchemaDef_proto_goTypes = nil 169 | file_SchemaDef_proto_depIdxs = nil 170 | } 171 | -------------------------------------------------------------------------------- /driver/prepared_test.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | "io" 7 | "testing" 8 | "time" 9 | 10 | "github.com/factset/go-drill" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/mock" 13 | ) 14 | 15 | func TestPreparedImplements(t *testing.T) { 16 | assert.Implements(t, (*driver.StmtExecContext)(nil), new(prepared)) 17 | assert.Implements(t, (*driver.StmtQueryContext)(nil), new(prepared)) 18 | } 19 | 20 | func TestPreparedClose(t *testing.T) { 21 | prep := &prepared{stmt: drill.PreparedHandle(5), client: new(mockDrillClient)} 22 | assert.NoError(t, prep.Close()) 23 | assert.Nil(t, prep.stmt) 24 | assert.Nil(t, prep.client) 25 | } 26 | 27 | func TestPreparedNumInput(t *testing.T) { 28 | p := &prepared{} 29 | assert.Zero(t, p.NumInput()) 30 | } 31 | 32 | func TestPreparedExec(t *testing.T) { 33 | p := &prepared{} 34 | r, e := p.Exec([]driver.Value{}) 35 | assert.Equal(t, driver.ResultNoRows, r) 36 | assert.Same(t, driver.ErrSkip, e) 37 | } 38 | 39 | func TestPreparedQuery(t *testing.T) { 40 | p := &prepared{} 41 | r, e := p.Query([]driver.Value{}) 42 | assert.Nil(t, r) 43 | assert.Same(t, driver.ErrSkip, e) 44 | } 45 | 46 | func TestPreparedExecContext(t *testing.T) { 47 | m := new(mockDrillClient) 48 | m.Test(t) 49 | defer m.AssertExpectations(t) 50 | 51 | mr := new(mockResHandle) 52 | mr.Test(t) 53 | defer mr.AssertExpectations(t) 54 | 55 | p := drill.PreparedHandle(5) 56 | m.On("ExecuteStmt", p).Return(mr, nil) 57 | 58 | rb := new(mockBatch) 59 | rb.On("AffectedRows").Return(5) 60 | 61 | mr.On("Next").Return(nil, rb).Twice() 62 | mr.On("Next").Return(io.EOF, (drill.RowBatch)(nil)) 63 | 64 | prep := &prepared{stmt: p, client: m} 65 | r, err := prep.ExecContext(context.Background(), []driver.NamedValue{}) 66 | assert.NoError(t, err) 67 | num, err := r.RowsAffected() 68 | assert.NoError(t, err) 69 | assert.EqualValues(t, 10, num) 70 | 71 | _, err = r.LastInsertId() 72 | assert.Error(t, err) 73 | } 74 | 75 | func TestPreparedExecContextWithErr(t *testing.T) { 76 | m := new(mockDrillClient) 77 | m.Test(t) 78 | defer m.AssertExpectations(t) 79 | 80 | mr := new(mockResHandle) 81 | mr.Test(t) 82 | defer mr.AssertExpectations(t) 83 | 84 | p := drill.PreparedHandle(5) 85 | m.On("ExecuteStmt", p).Return(mr, nil) 86 | 87 | rb := new(mockBatch) 88 | rb.On("AffectedRows").Return(5) 89 | 90 | mr.On("Next").Return(nil, rb).Twice() 91 | mr.On("Next").Return(assert.AnError, (drill.RowBatch)(nil)) 92 | 93 | prep := &prepared{stmt: p, client: m} 94 | r, err := prep.ExecContext(context.Background(), 
[]driver.NamedValue{}) 95 | assert.NoError(t, err) 96 | num, err := r.RowsAffected() 97 | assert.Same(t, assert.AnError, err) 98 | assert.EqualValues(t, 10, num) 99 | } 100 | 101 | func TestPreparedExecContextCtxTimeout(t *testing.T) { 102 | m := new(mockDrillClient) 103 | m.Test(t) 104 | defer m.AssertExpectations(t) 105 | 106 | mr := new(mockResHandle) 107 | mr.Test(t) 108 | defer mr.AssertExpectations(t) 109 | 110 | p := drill.PreparedHandle(5) 111 | m.On("ExecuteStmt", p).Return(mr, nil) 112 | 113 | waiter := make(chan time.Time) 114 | mr.On("Cancel").Run(func(mock.Arguments) { 115 | waiter <- time.Now() 116 | }) 117 | mr.On("Next").WaitUntil(waiter).Return(assert.AnError, (drill.RowBatch)(nil)) 118 | 119 | ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 120 | defer cancel() 121 | 122 | prep := &prepared{stmt: p, client: m} 123 | r, err := prep.ExecContext(ctx, []driver.NamedValue{}) 124 | assert.NoError(t, err) 125 | _, err = r.RowsAffected() 126 | assert.Same(t, assert.AnError, err) 127 | } 128 | 129 | func TestPreparedExecContextErr(t *testing.T) { 130 | m := new(mockDrillClient) 131 | m.Test(t) 132 | defer m.AssertExpectations(t) 133 | 134 | mr := new(mockResHandle) 135 | mr.Test(t) 136 | defer mr.AssertExpectations(t) 137 | 138 | p := drill.PreparedHandle(5) 139 | m.On("ExecuteStmt", p).Return((*drill.ResultHandle)(nil), assert.AnError) 140 | 141 | prep := &prepared{stmt: p, client: m} 142 | r, err := prep.ExecContext(context.Background(), []driver.NamedValue{}) 143 | assert.Same(t, driver.ErrBadConn, err) 144 | assert.Nil(t, r) 145 | } 146 | 147 | func TestPreparedExecContextNoPrep(t *testing.T) { 148 | prep := &prepared{} 149 | r, err := prep.ExecContext(context.Background(), []driver.NamedValue{{Name: "foobar"}}) 150 | assert.Nil(t, r) 151 | assert.Same(t, errNoPrepSupport, err) 152 | } 153 | 154 | func TestPreparedQueryContextNoPrep(t *testing.T) { 155 | prep := &prepared{} 156 | r, err := prep.QueryContext(context.Background(), []driver.NamedValue{{Name: "foobar"}}) 157 | assert.Nil(t, r) 158 | assert.Same(t, errNoPrepSupport, err) 159 | } 160 | 161 | func TestPreparedQueryContextErr(t *testing.T) { 162 | m := new(mockDrillClient) 163 | m.Test(t) 164 | defer m.AssertExpectations(t) 165 | 166 | mr := new(mockResHandle) 167 | mr.Test(t) 168 | defer mr.AssertExpectations(t) 169 | 170 | p := drill.PreparedHandle(5) 171 | m.On("ExecuteStmt", p).Return((*drill.ResultHandle)(nil), assert.AnError) 172 | 173 | prep := &prepared{stmt: p, client: m} 174 | r, err := prep.QueryContext(context.Background(), []driver.NamedValue{}) 175 | assert.Same(t, driver.ErrBadConn, err) 176 | assert.Nil(t, r) 177 | } 178 | 179 | func TestPreparedQueryContextCtxTimeout(t *testing.T) { 180 | m := new(mockDrillClient) 181 | m.Test(t) 182 | defer m.AssertExpectations(t) 183 | 184 | mr := new(mockResHandle) 185 | mr.Test(t) 186 | defer mr.AssertExpectations(t) 187 | 188 | p := drill.PreparedHandle(5) 189 | m.On("ExecuteStmt", p).Return(mr, nil) 190 | 191 | waiter := make(chan time.Time) 192 | mr.On("Cancel").Run(func(mock.Arguments) { 193 | waiter <- time.Now() 194 | }) 195 | mr.On("Next").WaitUntil(waiter).Return(assert.AnError, (drill.RowBatch)(nil)) 196 | 197 | ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 198 | defer cancel() 199 | 200 | prep := &prepared{stmt: p, client: m} 201 | _, err := prep.QueryContext(ctx, []driver.NamedValue{}) 202 | assert.Same(t, assert.AnError, err) 203 | } 204 | 205 | func TestPreparedQueryContext(t *testing.T) 
{ 206 | m := new(mockDrillClient) 207 | m.Test(t) 208 | defer m.AssertExpectations(t) 209 | 210 | mr := new(mockResHandle) 211 | mr.Test(t) 212 | defer mr.AssertExpectations(t) 213 | 214 | p := drill.PreparedHandle(5) 215 | m.On("ExecuteStmt", p).Return(mr, nil) 216 | mr.On("Next").Return(nil, (drill.RowBatch)(nil)) 217 | 218 | prep := &prepared{stmt: p, client: m} 219 | r, err := prep.QueryContext(context.Background(), []driver.NamedValue{}) 220 | assert.NoError(t, err) 221 | assert.NotNil(t, r) 222 | assert.IsType(t, &rows{}, r) 223 | } 224 | -------------------------------------------------------------------------------- /driver/conn_test.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | "io" 7 | "testing" 8 | "time" 9 | 10 | "github.com/factset/go-drill" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/mock" 13 | ) 14 | 15 | func TestConnImplements(t *testing.T) { 16 | // verify we implement the interfaces that aren't automatically going to 17 | // be enforced by the compiler so that we don't mess up any of the functions 18 | assert.Implements(t, (*driver.Pinger)(nil), new(conn)) 19 | assert.Implements(t, (*driver.QueryerContext)(nil), new(conn)) 20 | assert.Implements(t, (*driver.ExecerContext)(nil), new(conn)) 21 | } 22 | 23 | func TestDriverOpenErr(t *testing.T) { 24 | c, err := drillDriver{}.Open(";") 25 | assert.Nil(t, c) 26 | assert.Error(t, err) 27 | } 28 | 29 | func TestDriverOpenConnector(t *testing.T) { 30 | c, err := drillDriver{}.OpenConnector("auth=plain") 31 | assert.Equal(t, "plain", c.(*connector).base.(*drill.Client).Opts.Auth) 32 | assert.NoError(t, err) 33 | } 34 | 35 | func TestConnBegin(t *testing.T) { 36 | c := &conn{nil} 37 | tx, err := c.Begin() 38 | assert.Nil(t, tx) 39 | assert.EqualError(t, err, "not implemented") 40 | } 41 | 42 | func TestConnPrepare(t *testing.T) { 43 | m := new(mockDrillClient) 44 | m.Test(t) 45 | defer m.AssertExpectations(t) 46 | 47 | p := drill.PreparedHandle(5) 48 | m.On("PrepareQuery", "foobar").Return(p, nil) 49 | 50 | c := &conn{m} 51 | 52 | stmt, err := c.Prepare("foobar") 53 | assert.Equal(t, &prepared{stmt: p, client: m}, stmt) 54 | assert.NoError(t, err) 55 | } 56 | 57 | func TestConnPrepareErr(t *testing.T) { 58 | m := new(mockDrillClient) 59 | m.Test(t) 60 | defer m.AssertExpectations(t) 61 | 62 | p := drill.PreparedHandle(5) 63 | m.On("PrepareQuery", "foobar").Return(p, assert.AnError) 64 | 65 | c := &conn{m} 66 | stmt, err := c.Prepare("foobar") 67 | assert.Nil(t, stmt) 68 | assert.Same(t, assert.AnError, err) 69 | } 70 | 71 | func TestConnQueryContextNoPrep(t *testing.T) { 72 | c := &conn{} 73 | 74 | rows, err := c.QueryContext(context.Background(), "foo", []driver.NamedValue{{Name: "foo"}}) 75 | assert.Nil(t, rows) 76 | assert.Same(t, errNoPrepSupport, err) 77 | } 78 | 79 | func TestConnExecContextNoPrep(t *testing.T) { 80 | c := &conn{} 81 | 82 | rows, err := c.ExecContext(context.Background(), "foo", []driver.NamedValue{{Name: "foo"}}) 83 | assert.Nil(t, rows) 84 | assert.Same(t, errNoPrepSupport, err) 85 | } 86 | 87 | func TestConnQueryContextErr(t *testing.T) { 88 | m := new(mockDrillClient) 89 | m.Test(t) 90 | defer m.AssertExpectations(t) 91 | 92 | m.On("SubmitQuery", drill.TypeSQL, "foobar").Return((*drill.ResultHandle)(nil), assert.AnError) 93 | 94 | c := &conn{m} 95 | rows, err := c.QueryContext(context.Background(), "foobar", []driver.NamedValue{}) 96 | assert.Nil(t, rows) 
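// a SubmitQuery failure is surfaced as driver.ErrBadConn, signalling database/sql
// to discard this connection instead of reusing it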
97 | assert.Same(t, driver.ErrBadConn, err) 98 | } 99 | 100 | func TestConnExecContextErr(t *testing.T) { 101 | m := new(mockDrillClient) 102 | m.Test(t) 103 | defer m.AssertExpectations(t) 104 | 105 | m.On("SubmitQuery", drill.TypeSQL, "foobar").Return((*drill.ResultHandle)(nil), assert.AnError) 106 | 107 | c := &conn{m} 108 | rows, err := c.ExecContext(context.Background(), "foobar", []driver.NamedValue{}) 109 | assert.Nil(t, rows) 110 | assert.Same(t, driver.ErrBadConn, err) 111 | } 112 | 113 | type mockResHandle struct { 114 | mock.Mock 115 | } 116 | 117 | func (m *mockResHandle) Cancel() { m.Called() } 118 | func (m *mockResHandle) Close() error { return m.Called().Error(0) } 119 | func (m *mockResHandle) GetCols() []string { return m.Called().Get(0).([]string) } 120 | func (m *mockResHandle) GetRecordBatch() drill.RowBatch { 121 | ret := m.Called().Get(0) 122 | if ret == nil { 123 | return nil 124 | } 125 | return ret.(drill.RowBatch) 126 | } 127 | func (m *mockResHandle) Next() (drill.RowBatch, error) { 128 | args := m.Called() 129 | if args.Get(1) == nil { 130 | return nil, args.Error(0) 131 | } 132 | return args.Get(1).(drill.RowBatch), args.Error(0) 133 | } 134 | 135 | func TestConnQueryContextCtxTimeout(t *testing.T) { 136 | m := new(mockDrillClient) 137 | m.Test(t) 138 | defer m.AssertExpectations(t) 139 | 140 | mr := new(mockResHandle) 141 | mr.Test(t) 142 | defer mr.AssertExpectations(t) 143 | 144 | m.On("SubmitQuery", drill.TypeSQL, "foobar").Return(mr, nil) 145 | 146 | waiter := make(chan time.Time) 147 | mr.On("Cancel").Run(func(mock.Arguments) { 148 | waiter <- time.Now() 149 | }) 150 | mr.On("Next").WaitUntil(waiter).Return(assert.AnError, (drill.RowBatch)(nil)) 151 | 152 | c := &conn{m} 153 | 154 | ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 155 | defer cancel() 156 | 157 | _, err := c.QueryContext(ctx, "foobar", []driver.NamedValue{}) 158 | assert.Same(t, assert.AnError, err) 159 | } 160 | 161 | func TestConnExecContextCtxTimeout(t *testing.T) { 162 | m := new(mockDrillClient) 163 | m.Test(t) 164 | defer m.AssertExpectations(t) 165 | 166 | mr := new(mockResHandle) 167 | mr.Test(t) 168 | defer mr.AssertExpectations(t) 169 | 170 | m.On("SubmitQuery", drill.TypeSQL, "foobar").Return(mr, nil) 171 | 172 | waiter := make(chan time.Time) 173 | mr.On("Cancel").Run(func(mock.Arguments) { 174 | waiter <- time.Now() 175 | }) 176 | mr.On("Next").WaitUntil(waiter).Return(assert.AnError, (drill.RowBatch)(nil)) 177 | 178 | c := &conn{m} 179 | 180 | ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 181 | defer cancel() 182 | 183 | r, err := c.ExecContext(ctx, "foobar", []driver.NamedValue{}) 184 | assert.NoError(t, err) 185 | 186 | _, err = r.RowsAffected() 187 | assert.Same(t, assert.AnError, err) 188 | } 189 | 190 | func TestConnQueryContext(t *testing.T) { 191 | m := new(mockDrillClient) 192 | m.Test(t) 193 | defer m.AssertExpectations(t) 194 | 195 | mr := new(mockResHandle) 196 | mr.Test(t) 197 | defer mr.AssertExpectations(t) 198 | 199 | m.On("SubmitQuery", drill.TypeSQL, "foobar").Return(mr, nil) 200 | mr.On("Next").After(100*time.Millisecond).Return(nil, (drill.RowBatch)(nil)) 201 | 202 | c := &conn{m} 203 | r, err := c.QueryContext(context.Background(), "foobar", []driver.NamedValue{}) 204 | assert.NoError(t, err) 205 | assert.NotNil(t, r) 206 | assert.IsType(t, &rows{}, r) 207 | } 208 | 209 | func TestConnExecContext(t *testing.T) { 210 | m := new(mockDrillClient) 211 | m.Test(t) 212 | defer 
m.AssertExpectations(t) 213 | 214 | mr := new(mockResHandle) 215 | mr.Test(t) 216 | defer mr.AssertExpectations(t) 217 | 218 | m.On("SubmitQuery", drill.TypeSQL, "foobar").Return(mr, nil) 219 | 220 | rb := new(mockBatch) 221 | rb.On("AffectedRows").Return(5) 222 | 223 | mr.On("Next").Return(nil, rb).Twice() 224 | mr.On("Next").Return(io.EOF, (drill.RowBatch)(nil)) 225 | 226 | c := &conn{m} 227 | r, err := c.ExecContext(context.Background(), "foobar", []driver.NamedValue{}) 228 | assert.NoError(t, err) 229 | num, err := r.RowsAffected() 230 | assert.NoError(t, err) 231 | assert.EqualValues(t, 10, num) 232 | } 233 | 234 | func TestConnExecContextWithErr(t *testing.T) { 235 | m := new(mockDrillClient) 236 | m.Test(t) 237 | defer m.AssertExpectations(t) 238 | 239 | mr := new(mockResHandle) 240 | mr.Test(t) 241 | defer mr.AssertExpectations(t) 242 | 243 | m.On("SubmitQuery", drill.TypeSQL, "foobar").Return(mr, nil) 244 | 245 | rb := new(mockBatch) 246 | rb.On("AffectedRows").Return(5) 247 | 248 | mr.On("Next").Return(nil, rb).Twice() 249 | mr.On("Next").Return(assert.AnError, (drill.RowBatch)(nil)) 250 | 251 | c := &conn{m} 252 | r, err := c.ExecContext(context.Background(), "foobar", []driver.NamedValue{}) 253 | assert.NoError(t, err) 254 | num, err := r.RowsAffected() 255 | assert.Same(t, assert.AnError, err) 256 | assert.EqualValues(t, 10, num) 257 | } 258 | -------------------------------------------------------------------------------- /sasl/sasl.go: -------------------------------------------------------------------------------- 1 | // Package sasl provides the utilities for SASL authentication via gssapi 2 | package sasl 3 | 4 | import ( 5 | "bytes" 6 | "encoding/binary" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "net" 11 | "time" 12 | 13 | "github.com/jcmturner/gokrb5/v8/gssapi" 14 | ) 15 | 16 | // SecurityProps simply contains settings used for the sasl negotiation. 17 | // 18 | // These are utilized by the gssapi mechanism in order to determine the QOP settings 19 | type SecurityProps struct { 20 | MinSsf uint32 21 | MaxSsf uint32 22 | MaxBufSize int32 23 | UseEncryption bool 24 | } 25 | 26 | // the current state of our authentication 27 | const ( 28 | saslAuthInit = iota 29 | saslAuthNeg 30 | saslAuthSsf 31 | saslAuthComplete 32 | ) 33 | 34 | // Wrapper is the primary interface for sasl-gssapi handling. 35 | // 36 | // A wrapper is returned from NewSaslWrapper which will allow performing authentication 37 | // and then wrapping a desired connection to properly wrap and unwrap messages. 38 | type Wrapper interface { 39 | // InitAuthPayload initializes the local security context and returns a payload 40 | // for sending the initial token for negotiation. 41 | InitAuthPayload() ([]byte, error) 42 | // Step takes the responses from the server (eg. auth challenges) and steps through 43 | // the authentication and negotiation protocols, returning the next payload response 44 | // to send to the server as long as the gssapi.Status is gssapi.StatusContinueNeeded. 45 | // When authentication is complete, the status will be gssapi.StatusComplete. Any other 46 | // status will come associated with an error 47 | Step([]byte) ([]byte, gssapi.Status) 48 | // GetWrappedConn takes the provided connection and wraps it such that anything written 49 | // to or read from the connection will be put through the wrap/unwrap calls of the 50 | // sasl authentication based on the negotiated security context. 
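//
// A typical client flow is (sketch only; sendToServer and conn stand in for the
// caller's transport helper and existing net.Conn, neither is part of this package):
//
//	payload, _ := w.InitAuthPayload()
//	st := gssapi.Status{Code: gssapi.StatusContinueNeeded}
//	for st.Code == gssapi.StatusContinueNeeded {
//	    payload, st = w.Step(sendToServer(payload))
//	}
//	if st.Code == gssapi.StatusComplete {
//	    conn = w.GetWrappedConn(conn)
//	}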
51 | GetWrappedConn(net.Conn) net.Conn 52 | } 53 | 54 | type saslwrapper struct { 55 | Props SecurityProps 56 | mech gssapi.Mechanism 57 | ct gssapi.ContextToken 58 | 59 | state byte 60 | } 61 | 62 | // NewSaslWrapper takes the provided SPNs and SecurityProps to provide a Wrapper 63 | // that will perform GSSAPI authentication via kerberos krb5 64 | func NewSaslWrapper(userSpn, serviceSpn string, props SecurityProps) (Wrapper, error) { 65 | krbClient, err := getKrbClient(userSpn) 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | return &saslwrapper{ 71 | Props: props, 72 | mech: NewGSSAPIKrb5Mech(krbClient, serviceSpn, props), 73 | state: saslAuthInit, 74 | }, nil 75 | } 76 | 77 | func (s *saslwrapper) InitAuthPayload() ([]byte, error) { 78 | var err error 79 | if s.state != saslAuthInit { 80 | return nil, errors.New("invalid sasl auth state") 81 | } 82 | 83 | s.ct, err = s.mech.InitSecContext() 84 | if err != nil { 85 | return nil, err 86 | } 87 | 88 | s.state = saslAuthNeg 89 | return s.ct.Marshal() 90 | } 91 | 92 | // chooseQop returns both the Qop bytes and the chosen ssf value 93 | func (s *saslwrapper) chooseQop(mechSsf uint32, serverBitMask Qop) (uint32, Qop) { 94 | qop := QopNone 95 | if s.Props.MaxSsf > 0 { 96 | qop |= QopIntegrity 97 | } 98 | if s.Props.MaxSsf > 1 { 99 | qop |= QopConf 100 | } 101 | 102 | var allowed uint32 = s.Props.MaxSsf 103 | var need uint32 = s.Props.MinSsf 104 | 105 | if qop&QopConf != 0 && allowed >= mechSsf && need <= mechSsf && serverBitMask&QopConf != 0 { 106 | if serverBitMask&QopIntegrity != 0 { 107 | return mechSsf, QopIntegrity | QopConf 108 | } 109 | return mechSsf, QopConf 110 | } else if qop&QopIntegrity != 0 && allowed >= 1 && need <= 1 && serverBitMask&QopIntegrity != 0 { 111 | return 1, QopIntegrity 112 | } else if qop&QopNone != 0 && need <= 0 && serverBitMask&QopNone != 0 { 113 | return 0, QopNone 114 | } else { 115 | return 0, 0 116 | } 117 | } 118 | 119 | func (s *saslwrapper) Step(b []byte) ([]byte, gssapi.Status) { 120 | switch s.state { 121 | case saslAuthNeg: 122 | // handle response from InitAuthPayload 123 | if err := s.ct.Unmarshal(b); err != nil { 124 | return nil, gssapi.Status{Code: gssapi.StatusDefectiveCredential, Message: err.Error()} 125 | } 126 | 127 | // next step is negotating the ssf value 128 | s.state = saslAuthSsf 129 | 130 | _, st := s.ct.Verify() 131 | return nil, st 132 | case saslAuthSsf: 133 | var nntoken gssapi.WrapToken 134 | // our ssf negotiation will not have the payload encrypted 135 | if err := nntoken.Unmarshal(b, true); err != nil { 136 | return nil, gssapi.Status{Code: gssapi.StatusDefectiveToken, Message: err.Error()} 137 | } 138 | 139 | if err := VerifyWrapToken(s.ct, nntoken); err != nil { 140 | return nil, gssapi.Status{Code: gssapi.StatusBadSig, Message: err.Error()} 141 | } 142 | 143 | unwrapped := s.mech.Unwrap(nntoken) 144 | if len(unwrapped) != 4 { 145 | return nil, gssapi.Status{Code: gssapi.StatusDefectiveToken, Message: fmt.Sprintf("token invalid, should be 4 bytes, not %+v", len(unwrapped))} 146 | } 147 | 148 | mechSsf := GetSsf(s.ct) 149 | 150 | if s.Props.MinSsf > mechSsf { 151 | return nil, gssapi.Status{Code: gssapi.StatusBadMech, Message: "sasl too weak"} 152 | } else if s.Props.MinSsf > s.Props.MaxSsf { 153 | return nil, gssapi.Status{Code: gssapi.StatusBadMech, Message: "sasl bad param"} 154 | } 155 | 156 | // the 4 bytes we got back should be: 157 | // [0] == server Qop Bitmask 158 | // [1-3] == BigEndian encoded uint32 server max buffer size for a single payload 159 | 160 | 
var qop Qop 161 | mechSsf, qop = s.chooseQop(mechSsf, Qop(unwrapped[0])) 162 | 163 | maxOutBuf := CalcMaxOutputSize(mechSsf, binary.BigEndian.Uint32(append([]byte{0x00}, unwrapped[1:]...)), s.ct) 164 | s.Props.MaxBufSize = int32(maxOutBuf) 165 | 166 | // our response should be formatted the same way: 167 | // replacing the first byte with the desired QOP value 168 | out := nntoken.Payload 169 | binary.BigEndian.PutUint32(out, maxOutBuf) 170 | out[0] = byte(qop) 171 | 172 | token := s.mech.Wrap(out) 173 | // update our auth context with the chosen qop value *after* we wrap our response 174 | // since the response should not be encrypted even if our future communications will be 175 | SetQOP(s.ct, qop) 176 | s.state = saslAuthComplete 177 | data, err := token.Marshal() 178 | if err != nil { 179 | return nil, gssapi.Status{Code: gssapi.StatusDefectiveToken, Message: err.Error()} 180 | } 181 | 182 | return data, gssapi.Status{Code: gssapi.StatusComplete} 183 | default: 184 | return nil, gssapi.Status{Code: gssapi.StatusBadStatus} 185 | } 186 | } 187 | 188 | func (s *saslwrapper) GetWrappedConn(conn net.Conn) net.Conn { 189 | return &gssapiWrappedConn{ 190 | conn: conn, 191 | sasl: s, 192 | } 193 | } 194 | 195 | type gssapiWrappedConn struct { 196 | conn net.Conn 197 | sasl *saslwrapper 198 | 199 | readBuf bytes.Buffer 200 | } 201 | 202 | // fulfill the net.Conn interface 203 | 204 | func (g *gssapiWrappedConn) Close() error { 205 | return g.conn.Close() 206 | } 207 | 208 | func (g *gssapiWrappedConn) LocalAddr() net.Addr { 209 | return g.conn.LocalAddr() 210 | } 211 | 212 | func (g *gssapiWrappedConn) RemoteAddr() net.Addr { 213 | return g.conn.RemoteAddr() 214 | } 215 | 216 | func (g *gssapiWrappedConn) SetDeadline(t time.Time) error { 217 | return g.conn.SetDeadline(t) 218 | } 219 | 220 | func (g *gssapiWrappedConn) SetReadDeadline(t time.Time) error { 221 | return g.conn.SetReadDeadline(t) 222 | } 223 | 224 | func (g *gssapiWrappedConn) SetWriteDeadline(t time.Time) error { 225 | return g.conn.SetWriteDeadline(t) 226 | } 227 | 228 | func (g *gssapiWrappedConn) Read(b []byte) (int, error) { 229 | // use an internal buffer here 230 | n, err := g.readBuf.Read(b) 231 | if len(b) == n || (err != nil && err != io.EOF) { 232 | return n, err 233 | } 234 | 235 | var sz uint32 236 | if err := binary.Read(g.conn, binary.BigEndian, &sz); err != nil { 237 | return n, err 238 | } 239 | 240 | g.readBuf.Reset() 241 | g.readBuf.Grow(int(sz)) 242 | _, err = io.CopyN(&g.readBuf, g.conn, int64(sz)) 243 | if err != nil { 244 | return n, err 245 | } 246 | 247 | var token gssapi.WrapToken 248 | if err = token.Unmarshal(g.readBuf.Bytes(), true); err != nil { 249 | return n, err 250 | } 251 | 252 | g.readBuf.Reset() 253 | g.readBuf.Write(g.sasl.mech.Unwrap(token)) 254 | 255 | return g.readBuf.Read(b) 256 | } 257 | 258 | func (g *gssapiWrappedConn) Write(b []byte) (int, error) { 259 | token := g.sasl.mech.Wrap(b) 260 | data, err := token.Marshal() 261 | if err != nil { 262 | return 0, err 263 | } 264 | 265 | hdr := make([]byte, 4) 266 | binary.BigEndian.PutUint32(hdr, uint32(len(data))) 267 | return g.conn.Write(append(hdr, data...)) 268 | } 269 | -------------------------------------------------------------------------------- /utils_test.go: -------------------------------------------------------------------------------- 1 | package drill 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "io" 7 | "testing" 8 | 9 | "github.com/apache/arrow/go/v7/arrow" 10 | 
"github.com/factset/go-drill/internal/rpc/proto/exec/rpc" 11 | "github.com/factset/go-drill/internal/rpc/proto/exec/shared" 12 | "github.com/factset/go-drill/internal/rpc/proto/exec/user" 13 | "github.com/golang/protobuf/proto" 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | var deadbeef = []byte{0xDE, 0xAD, 0xBE, 0xEF} 18 | 19 | func TestRpcEncoderWriteRaw(t *testing.T) { 20 | m := new(mockConn) 21 | m.Test(t) 22 | 23 | m.On("Write", makePrefixedMessage(deadbeef)).Return(len(deadbeef), nil) 24 | 25 | val, err := rpcEncoder{}.WriteRaw(m, deadbeef) 26 | assert.Equal(t, len(deadbeef), val) 27 | assert.NoError(t, err) 28 | m.AssertExpectations(t) 29 | } 30 | 31 | func TestRpcEncoderWriteRawErr(t *testing.T) { 32 | m := new(mockConn) 33 | m.Test(t) 34 | 35 | m.On("Write", makePrefixedMessage(deadbeef)).Return(0, assert.AnError) 36 | val, err := rpcEncoder{}.WriteRaw(m, deadbeef) 37 | assert.Equal(t, 0, val) 38 | assert.Same(t, assert.AnError, err) 39 | m.AssertExpectations(t) 40 | } 41 | 42 | func TestRpcEncoderWriteMsg(t *testing.T) { 43 | m := new(mockConn) 44 | m.Test(t) 45 | 46 | enc, _ := encodeRPCMessage(rpc.RpcMode_PING, user.RpcType_ACK, 1, &shared.QueryId{}) 47 | 48 | m.On("Write", makePrefixedMessage(enc)).Return(len(enc), nil) 49 | val, err := rpcEncoder{}.Write(m, rpc.RpcMode_PING, user.RpcType_ACK, 1, &shared.QueryId{}) 50 | assert.Equal(t, len(enc), val) 51 | assert.NoError(t, err) 52 | m.AssertExpectations(t) 53 | } 54 | 55 | func TestRpcEncoderWriteMsgErr(t *testing.T) { 56 | m := new(mockConn) 57 | m.Test(t) 58 | 59 | enc, _ := encodeRPCMessage(rpc.RpcMode_PING, user.RpcType_ACK, 1, &shared.QueryId{}) 60 | 61 | m.On("Write", makePrefixedMessage(enc)).Return(0, assert.AnError) 62 | val, err := rpcEncoder{}.Write(m, rpc.RpcMode_PING, user.RpcType_ACK, 1, &shared.QueryId{}) 63 | assert.Equal(t, 0, val) 64 | assert.Same(t, assert.AnError, err) 65 | m.AssertExpectations(t) 66 | } 67 | 68 | func TestRpcEncoderReadRaw(t *testing.T) { 69 | m := new(mockConn) 70 | m.Test(t) 71 | 72 | msg := &rpc.CompleteRpcMessage{ 73 | Header: &rpc.RpcHeader{ 74 | Mode: rpc.RpcMode_PING.Enum(), 75 | CoordinationId: proto.Int(5), 76 | RpcType: proto.Int(5), 77 | }, 78 | ProtobufBody: deadbeef, 79 | } 80 | 81 | data, _ := proto.Marshal(msg) 82 | buf := make([]byte, binary.MaxVarintLen32) 83 | nb := binary.PutUvarint(buf, uint64(len(data))) 84 | m.r = bytes.NewReader(append(buf[:nb], data...)) 85 | 86 | m.On("Read") 87 | out, err := rpcEncoder{}.ReadRaw(m) 88 | assert.NoError(t, err) 89 | 90 | assert.Equal(t, msg.Header.GetMode(), out.Header.GetMode()) 91 | assert.Equal(t, msg.Header.GetCoordinationId(), out.Header.GetCoordinationId()) 92 | assert.Equal(t, msg.Header.GetRpcType(), out.Header.GetRpcType()) 93 | assert.Equal(t, msg.GetProtobufBody(), out.GetProtobufBody()) 94 | m.AssertExpectations(t) 95 | } 96 | 97 | func TestRpcEncoderReadMsg(t *testing.T) { 98 | qid := &shared.QueryId{Part1: proto.Int64(12345), Part2: proto.Int64(98765)} 99 | encoded, _ := proto.Marshal(qid) 100 | msg := &rpc.CompleteRpcMessage{ 101 | Header: &rpc.RpcHeader{ 102 | Mode: rpc.RpcMode_PING.Enum(), 103 | CoordinationId: proto.Int(5), 104 | RpcType: proto.Int(5), 105 | }, 106 | ProtobufBody: encoded, 107 | } 108 | 109 | data, _ := proto.Marshal(msg) 110 | buf := make([]byte, binary.MaxVarintLen32) 111 | nb := binary.PutUvarint(buf, uint64(len(data))) 112 | 113 | m := new(mockConn) 114 | m.Test(t) 115 | m.r = bytes.NewReader(append(buf[:nb], data...)) 116 | 117 | m.On("Read") 118 | out := &shared.QueryId{} 119 | 
hdr, err := rpcEncoder{}.ReadMsg(m, out) 120 | assert.NoError(t, err) 121 | 122 | assert.Equal(t, msg.Header.GetMode(), hdr.GetMode()) 123 | assert.Equal(t, msg.Header.GetCoordinationId(), hdr.GetCoordinationId()) 124 | assert.Equal(t, msg.Header.GetRpcType(), hdr.GetRpcType()) 125 | assert.Equal(t, qid.GetPart1(), out.GetPart1()) 126 | assert.Equal(t, qid.GetPart2(), out.GetPart2()) 127 | } 128 | 129 | func TestMakePrefixedMessage(t *testing.T) { 130 | out := makePrefixedMessage(deadbeef) 131 | val, nb := binary.Uvarint(out) 132 | 133 | assert.EqualValues(t, len(deadbeef), val) 134 | assert.Len(t, out, len(deadbeef)+nb) 135 | } 136 | 137 | func TestMakePrefixedNil(t *testing.T) { 138 | assert.Nil(t, makePrefixedMessage(nil)) 139 | } 140 | 141 | func TestReadPrefixedSimple(t *testing.T) { 142 | buf := make([]byte, binary.MaxVarintLen32) 143 | nb := binary.PutUvarint(buf, 4) 144 | 145 | out, err := readPrefixed(bytes.NewReader(append(buf[:nb], deadbeef...))) 146 | assert.NoError(t, err) 147 | assert.EqualValues(t, deadbeef, out) 148 | } 149 | 150 | func TestReadPrefixedEof(t *testing.T) { 151 | buf := &bytes.Reader{} 152 | out, err := readPrefixed(buf) 153 | assert.Nil(t, out) 154 | assert.Same(t, io.ErrUnexpectedEOF, err) 155 | } 156 | 157 | func TestReadPrefixedShortRead(t *testing.T) { 158 | buf := []byte{0x01} 159 | out, err := readPrefixed(bytes.NewReader(buf)) 160 | assert.Nil(t, out) 161 | assert.Same(t, io.ErrUnexpectedEOF, err) 162 | } 163 | 164 | func TestReadPrefixedEmpty(t *testing.T) { 165 | buf := []byte{0, 0, 0, 0, 0} 166 | out, err := readPrefixed(bytes.NewBuffer(buf)) 167 | assert.Nil(t, out) 168 | assert.Same(t, errInvalidResponse, err) 169 | } 170 | 171 | func TestReadPrefixedNotEnough(t *testing.T) { 172 | buf := make([]byte, binary.MaxVarintLen32) 173 | nb := binary.PutUvarint(buf, 6) 174 | 175 | out, err := readPrefixed(bytes.NewReader(append(buf[:nb], deadbeef...))) 176 | assert.Nil(t, out) 177 | assert.Same(t, io.ErrUnexpectedEOF, err) 178 | } 179 | 180 | func TestReadPrefixedRaw(t *testing.T) { 181 | msg := &rpc.CompleteRpcMessage{ 182 | Header: &rpc.RpcHeader{ 183 | Mode: rpc.RpcMode_PING.Enum(), 184 | CoordinationId: proto.Int(5), 185 | RpcType: proto.Int(5), 186 | }, 187 | ProtobufBody: deadbeef, 188 | } 189 | 190 | data, _ := proto.Marshal(msg) 191 | buf := make([]byte, binary.MaxVarintLen32) 192 | nb := binary.PutUvarint(buf, uint64(len(data))) 193 | 194 | out, err := readPrefixedRaw(bytes.NewReader(append(buf[:nb], data...))) 195 | assert.NoError(t, err) 196 | 197 | assert.Equal(t, msg.Header.GetMode(), out.Header.GetMode()) 198 | assert.Equal(t, msg.Header.GetCoordinationId(), out.Header.GetCoordinationId()) 199 | assert.Equal(t, msg.Header.GetRpcType(), out.Header.GetRpcType()) 200 | assert.Equal(t, msg.GetProtobufBody(), out.GetProtobufBody()) 201 | } 202 | 203 | func TestReadPrefixedRawErr(t *testing.T) { 204 | buf := &bytes.Reader{} 205 | out, err := readPrefixedRaw(buf) 206 | assert.Nil(t, out) 207 | assert.Error(t, err) 208 | } 209 | 210 | func TestReadPrefixedMessage(t *testing.T) { 211 | qid := &shared.QueryId{Part1: proto.Int64(12345), Part2: proto.Int64(98765)} 212 | encoded, _ := proto.Marshal(qid) 213 | msg := &rpc.CompleteRpcMessage{ 214 | Header: &rpc.RpcHeader{ 215 | Mode: rpc.RpcMode_PING.Enum(), 216 | CoordinationId: proto.Int(5), 217 | RpcType: proto.Int(5), 218 | }, 219 | ProtobufBody: encoded, 220 | } 221 | 222 | data, _ := proto.Marshal(msg) 223 | buf := make([]byte, binary.MaxVarintLen32) 224 | nb := binary.PutUvarint(buf, 
uint64(len(data))) 225 | 226 | out := &shared.QueryId{} 227 | hdr, err := readPrefixedMessage(bytes.NewReader(append(buf[:nb], data...)), out) 228 | assert.NoError(t, err) 229 | 230 | assert.Equal(t, msg.Header.GetMode(), hdr.GetMode()) 231 | assert.Equal(t, msg.Header.GetCoordinationId(), hdr.GetCoordinationId()) 232 | assert.Equal(t, msg.Header.GetRpcType(), hdr.GetRpcType()) 233 | assert.Equal(t, qid.GetPart1(), out.GetPart1()) 234 | assert.Equal(t, qid.GetPart2(), out.GetPart2()) 235 | } 236 | 237 | func TestReadPrefixedMessageErr(t *testing.T) { 238 | buf := &bytes.Reader{} 239 | out, err := readPrefixedMessage(buf, nil) 240 | assert.Nil(t, out) 241 | assert.Error(t, err) 242 | } 243 | 244 | func TestArrowFromColPanic(t *testing.T) { 245 | c := &user.ColumnMetadata{} 246 | assert.Panics(t, func() { ColMetaToArrowField(c) }) 247 | } 248 | 249 | func TestArrowFromCol(t *testing.T) { 250 | tests := []struct { 251 | name string 252 | typ string 253 | expected arrow.DataType 254 | nullable bool 255 | def string 256 | }{ 257 | {"bool", "BOOLEAN", arrow.FixedWidthTypes.Boolean, true, "false"}, 258 | {"binary", "BINARY VARYING", arrow.BinaryTypes.Binary, false, ""}, 259 | {"varchar", "CHARACTER VARYING", arrow.BinaryTypes.String, true, "foo"}, 260 | {"integer", "INTEGER", arrow.PrimitiveTypes.Int32, false, "1"}, 261 | {"int64", "BIGINT", arrow.PrimitiveTypes.Int64, true, "123456"}, 262 | {"int16", "SMALLINT", arrow.PrimitiveTypes.Int16, false, "65535"}, 263 | {"tinyint", "TINYINT", arrow.PrimitiveTypes.Int8, true, "1"}, 264 | {"date", "DATE", arrow.FixedWidthTypes.Date64, false, "1987-08-04"}, 265 | {"time", "TIME", arrow.FixedWidthTypes.Time32ms, true, "12:30PM"}, 266 | {"float", "FLOAT", arrow.PrimitiveTypes.Float32, false, "1.2"}, 267 | {"double", "DOUBLE", arrow.PrimitiveTypes.Float64, false, "1.2"}, 268 | {"timestamp", "TIMESTAMP", arrow.FixedWidthTypes.Timestamp_ms, true, "123456789"}, 269 | } 270 | 271 | for _, tt := range tests { 272 | t.Run(tt.name, func(t *testing.T) { 273 | c := &user.ColumnMetadata{ 274 | ColumnName: &tt.name, 275 | DataType: &tt.typ, 276 | IsNullable: &tt.nullable, 277 | DefaultValue: &tt.def, 278 | } 279 | 280 | f := ColMetaToArrowField(c) 281 | assert.True(t, f.Equal(arrow.Field{ 282 | Name: tt.name, 283 | Type: tt.expected, 284 | Nullable: tt.nullable, 285 | Metadata: arrow.NewMetadata([]string{"default"}, []string{tt.def}), 286 | })) 287 | }) 288 | } 289 | } 290 | -------------------------------------------------------------------------------- /sasl/gssapi.go: -------------------------------------------------------------------------------- 1 | package sasl 2 | 3 | import ( 4 | "context" 5 | "math" 6 | "math/rand" 7 | 8 | "github.com/jcmturner/gofork/encoding/asn1" 9 | "github.com/jcmturner/gokrb5/v8/client" 10 | "github.com/jcmturner/gokrb5/v8/crypto" 11 | "github.com/jcmturner/gokrb5/v8/gssapi" 12 | "github.com/jcmturner/gokrb5/v8/iana/keyusage" 13 | "github.com/jcmturner/gokrb5/v8/messages" 14 | "github.com/jcmturner/gokrb5/v8/spnego" 15 | "github.com/jcmturner/gokrb5/v8/types" 16 | ) 17 | 18 | // based on krb5_gssapi_encrypt_length 19 | func getEncryptSize(key types.EncryptionKey, len uint32) uint32 { 20 | etyp, _ := crypto.GetEtype(key.KeyType) 21 | return uint32(etyp.GetConfounderByteSize()+etyp.GetHMACBitLength()/8) + len 22 | } 23 | 24 | func genSeqNumber() uint64 { 25 | // Work around implementation incompatibilities by not generating 26 | // initial sequence numbers greater than 2^30. 
Previous MIT 27 | // implementations use signed sequence numbers, so initial 28 | // sequence numbers 2^31 to 2^32-1 inclusive will be rejected. 29 | // Letting the maximum initial sequence number be 2^30-1 allows 30 | // for about 2^30 messages to be sent before wrapping into 31 | // "negative" numbers. 32 | return uint64(rand.Int63n(int64(math.Pow(2, 30)))) 33 | } 34 | 35 | // NewGSSAPIKrb5Mech constructs a mechanism for gssapi processing using Kerberos via krb5 36 | func NewGSSAPIKrb5Mech(cl *client.Client, spn string, saslProps SecurityProps) gssapi.Mechanism { 37 | return &gssapiKrb5Mech{cl: cl, spn: spn, saslProps: saslProps} 38 | } 39 | 40 | type gssapiKrb5Mech struct { 41 | cl *client.Client 42 | spn string 43 | 44 | saslProps SecurityProps 45 | ctx authContext 46 | } 47 | 48 | // Qop is a bitmask representing the current Quality of Protection settings 49 | type Qop byte 50 | 51 | // Qop will be some combination of none / integrity / confidential 52 | const ( 53 | QopNone Qop = 1 << iota 54 | QopIntegrity 55 | QopConf 56 | ) 57 | 58 | // opaque context object for internal handling 59 | type authContext struct { 60 | key types.EncryptionKey 61 | remoteSeqNum int64 62 | subKey types.EncryptionKey 63 | qop Qop 64 | localSeqNum uint64 65 | } 66 | 67 | // opaque token that fulfills the interface defined in gssapi.ContextToken 68 | type gssapiKrb5Token struct { 69 | krb5Tok spnego.KRB5Token 70 | 71 | ctx *authContext 72 | } 73 | 74 | func (g *gssapiKrb5Token) Marshal() ([]byte, error) { 75 | return g.krb5Tok.Marshal() 76 | } 77 | 78 | func (g *gssapiKrb5Token) Unmarshal(b []byte) error { 79 | return g.krb5Tok.Unmarshal(b) 80 | } 81 | 82 | type contextKey int 83 | 84 | const ( 85 | ctxAuthCtx contextKey = iota 86 | ) 87 | 88 | // VerifyWrapToken allows calling Verify on the token without having to expose 89 | // the encryption key that the context token is holding onto. 90 | func VerifyWrapToken(ct gssapi.ContextToken, wt gssapi.WrapToken) error { 91 | key := ct.Context().Value(ctxAuthCtx).(*authContext).key 92 | _, err := wt.Verify(key, keyusage.GSSAPI_ACCEPTOR_SEAL) 93 | return err 94 | } 95 | 96 | // GetSsf uses the opaque context in the token in order to pull the key and return 97 | // the Security Strength Factor (ssf) value for the given key. 98 | func GetSsf(ct gssapi.ContextToken) uint32 { 99 | key := ct.Context().Value(ctxAuthCtx).(*authContext).key 100 | etyp, _ := crypto.GetEtype(key.KeyType) 101 | return uint32(etyp.GetKeySeedBitLength()) 102 | } 103 | 104 | // CalcMaxOutputSize uses the determined SSF value and provided max buffer size 105 | // combined with the encryption key in the token to figure out what the actual 106 | // max size can be such that the resulting size after encryption will still be 107 | // within the provided maxOutBuf. 108 | // 109 | // As per the general SASL definitions, if the SSF is <= 0, then we wouldn't be 110 | // encrypting the buffer, and just return the maxOutBuf that was passed in. If 111 | // mechSsf > 0, then we grab the key and figure out what size will encrypt to 112 | // a size smaller than the passed in maxOutBuf while also giving room for the 16 113 | // byte token header. 
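//
// Step (in sasl.go) calls this with the server-advertised max buffer size decoded from
// the SSF negotiation token; the result is stored in Props.MaxBufSize and written back
// into the client's response, with the first byte then overwritten by the chosen QOP.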
114 | func CalcMaxOutputSize(mechSsf, maxOutBuf uint32, ct gssapi.ContextToken) uint32 { 115 | if mechSsf > 0 { 116 | key := ct.Context().Value(ctxAuthCtx).(*authContext).key 117 | 118 | sz := maxOutBuf 119 | for sz > 0 && getEncryptSize(key, sz+16) > maxOutBuf { 120 | sz-- 121 | } 122 | 123 | if sz > 0 { 124 | sz -= 16 125 | } else { 126 | sz = 0 127 | } 128 | return sz 129 | } 130 | return maxOutBuf 131 | } 132 | 133 | // SetQOP will set the desired Qop value into the opaque token value 134 | func SetQOP(ct gssapi.ContextToken, qop Qop) { 135 | ct.Context().Value(ctxAuthCtx).(*authContext).qop = qop 136 | } 137 | 138 | func (g *gssapiKrb5Token) Verify() (bool, gssapi.Status) { 139 | valid, st := g.krb5Tok.Verify() 140 | if !valid && st.Code == gssapi.StatusFailure { 141 | // gokrb5 doesn't verify APREP yet, but i can! 142 | b, err := crypto.DecryptEncPart(g.krb5Tok.APRep.EncPart, g.ctx.key, keyusage.AP_REP_ENCPART) 143 | if err != nil { 144 | return false, gssapi.Status{Code: gssapi.StatusDefectiveCredential, Message: "Could not decrypt APRep"} 145 | } 146 | 147 | var denc messages.EncAPRepPart 148 | if err = denc.Unmarshal(b); err != nil { 149 | return false, gssapi.Status{Code: gssapi.StatusFailure, Message: err.Error()} 150 | } 151 | 152 | g.ctx.remoteSeqNum = denc.SequenceNumber 153 | g.ctx.subKey = denc.Subkey 154 | // TODO: use denc.CTime and denc.Cusec to verify no clock skew 155 | return true, gssapi.Status{Code: gssapi.StatusContinueNeeded} 156 | } 157 | return valid, st 158 | } 159 | 160 | // Context will return a context.Context that also contains the opaque auth context 161 | // embedded in it so that it can be used and passed around 162 | func (g *gssapiKrb5Token) Context() context.Context { 163 | ctx := g.krb5Tok.Context() 164 | if ctx == nil { 165 | ctx = context.Background() 166 | } 167 | return context.WithValue(ctx, ctxAuthCtx, g.ctx) 168 | } 169 | 170 | // getCtxFlags will provide the list of flags to pass to gssapi creation 171 | // based on the sasl props to determine whether or not we want to use 172 | // integrity checking and/or confidentiality 173 | func (g *gssapiKrb5Mech) getCtxFlags() []int { 174 | // leave out the Sequence flag for now until i figure out why new versions of the krb server 175 | // don't like how I'm constructing my sequence numbers 176 | ret := []int{gssapi.ContextFlagMutual /*, gssapi.ContextFlagSequence*/} 177 | if g.saslProps.UseEncryption { 178 | ret = append(ret, gssapi.ContextFlagConf, gssapi.ContextFlagInteg) 179 | return ret 180 | } 181 | 182 | if g.saslProps.MaxSsf > 0 { 183 | ret = append(ret, gssapi.ContextFlagInteg) 184 | if g.saslProps.MaxSsf > 1 { 185 | ret = append(ret, gssapi.ContextFlagConf) 186 | } 187 | } 188 | return ret 189 | } 190 | 191 | func (gssapiKrb5Mech) OID() asn1.ObjectIdentifier { 192 | return gssapi.OIDKRB5.OID() 193 | } 194 | 195 | func (g *gssapiKrb5Mech) AcquireCred() error { 196 | return g.cl.AffirmLogin() 197 | } 198 | 199 | // AcceptSecContext currently is unimplemented beyond calling Verify on the token 200 | // this does not yet set up the local security context appropriately 201 | func (g *gssapiKrb5Mech) AcceptSecContext(ct gssapi.ContextToken) (bool, context.Context, gssapi.Status) { 202 | valid, st := ct.Verify() 203 | return valid, ct.Context(), st 204 | } 205 | 206 | // InitSecContext uses the spn we initialized with to perform Krb initialization 207 | // by grabbing the service ticket and doing an AP Exchange 208 | func (g *gssapiKrb5Mech) InitSecContext() (gssapi.ContextToken, error) { 209 | 
ticket, key, err := g.cl.GetServiceTicket(g.spn) 210 | if err != nil { 211 | return nil, err 212 | } 213 | 214 | g.ctx.key = key 215 | 216 | tok, err := spnego.NewKRB5TokenAPREQ(g.cl, ticket, key, g.getCtxFlags(), []int{}) 217 | if err != nil { 218 | return nil, err 219 | } 220 | 221 | g.ctx.localSeqNum = genSeqNumber() 222 | 223 | return &gssapiKrb5Token{krb5Tok: tok, ctx: &g.ctx}, nil 224 | } 225 | 226 | // MIC tokens are currently unimplemented 227 | func (g *gssapiKrb5Mech) MIC() gssapi.MICToken { 228 | return gssapi.MICToken{} 229 | } 230 | 231 | func (g *gssapiKrb5Mech) VerifyMIC(mt gssapi.MICToken) (bool, error) { 232 | return mt.Verify(g.ctx.key, keyusage.GSSAPI_ACCEPTOR_SEAL) 233 | } 234 | 235 | // Wrap will use the current QOP settings in order to determine whether or not 236 | // we'll actually encrypt the data and then returns the correct wrapped data 237 | func (g *gssapiKrb5Mech) Wrap(msg []byte) gssapi.WrapToken { 238 | var data []byte 239 | 240 | if (g.ctx.qop & QopConf) != 0 { 241 | tok := gssapi.WrapToken{ 242 | Flags: 0x02, 243 | EC: 0, 244 | RRC: 0, 245 | SndSeqNum: g.ctx.localSeqNum, 246 | CheckSum: make([]byte, 0), 247 | Payload: make([]byte, 0), 248 | } 249 | 250 | hdr, err := tok.Marshal() 251 | if err != nil { 252 | panic(err) 253 | } 254 | 255 | etyp, _ := crypto.GetEtype(g.ctx.key.KeyType) 256 | _, data, err = etyp.EncryptMessage(g.ctx.key.KeyValue, append(msg, hdr...), keyusage.GSSAPI_INITIATOR_SEAL) 257 | if err != nil { 258 | panic(err) 259 | } 260 | 261 | g.ctx.localSeqNum++ 262 | 263 | tok.Payload = data 264 | return tok 265 | } 266 | 267 | token, err := gssapi.NewInitiatorWrapToken(msg, g.ctx.key) 268 | if err != nil { 269 | panic(err) 270 | } 271 | return *token 272 | } 273 | 274 | // Unwrap will check the flags to determine whether or not we should decrypt the data 275 | // or just return the unwrapped payload 276 | func (g *gssapiKrb5Mech) Unwrap(wt gssapi.WrapToken) []byte { 277 | if wt.Flags&0x02 == 0 { 278 | return wt.Payload 279 | } 280 | 281 | decoded, err := crypto.DecryptMessage(wt.Payload, g.ctx.key, keyusage.GSSAPI_ACCEPTOR_SEAL) 282 | if err != nil { 283 | panic(err) 284 | } 285 | 286 | return decoded[:len(decoded)-16] 287 | } 288 | -------------------------------------------------------------------------------- /internal/data/date_time_vectors.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "math" 7 | "reflect" 8 | "time" 9 | 10 | "github.com/factset/go-drill/internal/rpc/proto/exec/shared" 11 | ) 12 | 13 | type TimestampVector struct { 14 | *Int64Vector 15 | } 16 | 17 | func (TimestampVector) Type() reflect.Type { 18 | return reflect.TypeOf(time.Time{}) 19 | } 20 | 21 | func NewTimestampVector(data []byte, meta *shared.SerializedField) *TimestampVector { 22 | return &TimestampVector{ 23 | NewInt64Vector(data, meta), 24 | } 25 | } 26 | 27 | func (v *TimestampVector) Get(index uint) time.Time { 28 | ts := v.Int64Vector.Get(index) 29 | return time.Unix(ts/1000, ts%1000) 30 | } 31 | 32 | func (v *TimestampVector) Value(index uint) interface{} { 33 | return v.Get(index) 34 | } 35 | 36 | type DateVector struct { 37 | *TimestampVector 38 | } 39 | 40 | func (dv *DateVector) Get(index uint) time.Time { 41 | return dv.TimestampVector.Get(index).UTC() 42 | } 43 | 44 | func (dv *DateVector) Value(index uint) interface{} { 45 | return dv.Get(index) 46 | } 47 | 48 | func NewDateVector(data []byte, meta *shared.SerializedField) *DateVector { 49 | 
return &DateVector{NewTimestampVector(data, meta)} 50 | } 51 | 52 | type TimeVector struct { 53 | *Int32Vector 54 | } 55 | 56 | func (TimeVector) Type() reflect.Type { 57 | return reflect.TypeOf(time.Time{}) 58 | } 59 | 60 | func (t *TimeVector) Get(index uint) time.Time { 61 | ts := t.Int32Vector.Get(index) 62 | h, m, s := time.Unix(int64(ts/1000), int64(ts%1000)).UTC().Clock() 63 | return time.Date(0, 1, 1, h, m, s, 0, time.UTC) 64 | } 65 | 66 | func (t *TimeVector) Value(index uint) interface{} { 67 | return t.Get(index) 68 | } 69 | 70 | func NewTimeVector(date []byte, meta *shared.SerializedField) *TimeVector { 71 | return &TimeVector{NewInt32Vector(date, meta)} 72 | } 73 | 74 | type NullableTimestampVector struct { 75 | *NullableInt64Vector 76 | } 77 | 78 | func (NullableTimestampVector) Type() reflect.Type { 79 | return reflect.TypeOf(time.Time{}) 80 | } 81 | 82 | func (v *NullableTimestampVector) Get(index uint) *time.Time { 83 | ts := v.NullableInt64Vector.Get(index) 84 | if ts == nil { 85 | return nil 86 | } 87 | 88 | ret := time.Unix(*ts/1000, *ts%1000) 89 | return &ret 90 | } 91 | 92 | func (v *NullableTimestampVector) Value(index uint) interface{} { 93 | val := v.Get(index) 94 | if val != nil { 95 | return *val 96 | } 97 | 98 | return val 99 | } 100 | 101 | func NewNullableTimestampVector(data []byte, meta *shared.SerializedField) *NullableTimestampVector { 102 | return &NullableTimestampVector{ 103 | NewNullableInt64Vector(data, meta), 104 | } 105 | } 106 | 107 | type NullableDateVector struct { 108 | *NullableTimestampVector 109 | } 110 | 111 | func (nv *NullableDateVector) Get(index uint) *time.Time { 112 | ret := nv.NullableTimestampVector.Get(index) 113 | if ret != nil { 114 | *ret = ret.UTC() 115 | } 116 | 117 | return ret 118 | } 119 | 120 | func (nv *NullableDateVector) Value(index uint) interface{} { 121 | ret := nv.Get(index) 122 | if ret != nil { 123 | return *ret 124 | } 125 | 126 | return nil 127 | } 128 | 129 | func NewNullableDateVector(data []byte, meta *shared.SerializedField) *NullableDateVector { 130 | return &NullableDateVector{NewNullableTimestampVector(data, meta)} 131 | } 132 | 133 | type NullableTimeVector struct { 134 | *NullableInt32Vector 135 | } 136 | 137 | func (NullableTimeVector) Type() reflect.Type { 138 | return reflect.TypeOf(time.Time{}) 139 | } 140 | 141 | func (v *NullableTimeVector) Get(index uint) *time.Time { 142 | ts := v.NullableInt32Vector.Get(index) 143 | if ts == nil { 144 | return nil 145 | } 146 | 147 | h, m, s := time.Unix(int64(*ts/1000), int64(*ts%1000)).UTC().Clock() 148 | ret := time.Date(0, 1, 1, h, m, s, 0, time.UTC) 149 | return &ret 150 | } 151 | 152 | func (v *NullableTimeVector) Value(index uint) interface{} { 153 | val := v.Get(index) 154 | if val != nil { 155 | return *val 156 | } 157 | 158 | return val 159 | } 160 | 161 | func NewNullableTimeVector(data []byte, meta *shared.SerializedField) *NullableTimeVector { 162 | return &NullableTimeVector{NewNullableInt32Vector(data, meta)} 163 | } 164 | 165 | type intervalBase interface { 166 | Type() reflect.Type 167 | TypeLen() (int64, bool) 168 | Len() int 169 | GetRawBytes() []byte 170 | getval(index int) []byte 171 | } 172 | 173 | type fixedWidthVec struct { 174 | data []byte 175 | valsz int 176 | 177 | meta *shared.SerializedField 178 | } 179 | 180 | func (fixedWidthVec) Type() reflect.Type { 181 | return reflect.TypeOf(string("")) 182 | } 183 | 184 | func (fixedWidthVec) TypeLen() (int64, bool) { 185 | return 0, false 186 | } 187 | 188 | func (v *fixedWidthVec) 
GetRawBytes() []byte { 189 | return v.data 190 | } 191 | 192 | func (v *fixedWidthVec) Len() int { 193 | return int(v.meta.GetValueCount()) 194 | } 195 | 196 | func (v *fixedWidthVec) getval(index int) []byte { 197 | start := index * v.valsz 198 | return v.data[start : start+v.valsz] 199 | } 200 | 201 | type nullableIntervalBase interface { 202 | intervalBase 203 | IsNull(index uint) bool 204 | GetNullBytemap() []byte 205 | } 206 | 207 | type nullableFixedWidthVec struct { 208 | *fixedWidthVec 209 | nullByteMap 210 | } 211 | 212 | func (nv *nullableFixedWidthVec) GetNullBytemap() []byte { 213 | return nv.byteMap 214 | } 215 | 216 | func (nv *nullableFixedWidthVec) getval(index int) []byte { 217 | if nv.IsNull(uint(index)) { 218 | return nil 219 | } 220 | return nv.fixedWidthVec.getval(index) 221 | } 222 | 223 | func newNullableFixedWidth(data []byte, meta *shared.SerializedField, valsz int) *nullableFixedWidthVec { 224 | byteMap := data[:meta.GetValueCount()] 225 | remaining := data[meta.GetValueCount():] 226 | 227 | return &nullableFixedWidthVec{ 228 | &fixedWidthVec{remaining, valsz, meta}, 229 | nullByteMap{byteMap}, 230 | } 231 | } 232 | 233 | type intervalVector struct { 234 | intervalBase 235 | process func([]byte) string 236 | } 237 | 238 | func (iv *intervalVector) Get(index uint) string { 239 | return iv.process(iv.getval(int(index))) 240 | } 241 | 242 | func (iv *intervalVector) Value(index uint) interface{} { 243 | return iv.Get(index) 244 | } 245 | 246 | type nullableIntervalVector struct { 247 | nullableIntervalBase 248 | process func([]byte) string 249 | } 250 | 251 | func (iv *nullableIntervalVector) Get(index uint) *string { 252 | data := iv.getval(int(index)) 253 | if data == nil { 254 | return nil 255 | } 256 | 257 | ret := iv.process(data) 258 | return &ret 259 | } 260 | 261 | func (iv *nullableIntervalVector) Value(index uint) interface{} { 262 | val := iv.Get(index) 263 | if val != nil { 264 | return *val 265 | } 266 | return val 267 | } 268 | 269 | func processYear(val []byte) string { 270 | m := int32(binary.LittleEndian.Uint32(val)) 271 | 272 | var prefix string 273 | if m < 0 { 274 | m = -m 275 | prefix = "-" 276 | } 277 | 278 | years := m / 12 279 | months := m % 12 280 | 281 | return fmt.Sprintf("%s%d-%d", prefix, years, months) 282 | } 283 | 284 | const daysToMillis = 24 * 60 * 60 * 1000 285 | 286 | func processDay(val []byte) string { 287 | days := int32(binary.LittleEndian.Uint32(val)) 288 | millis := int32(binary.LittleEndian.Uint32(val[4:])) 289 | 290 | isneg := (days < 0) || (days == 0 && millis < 0) 291 | if days < 0 { 292 | days = -days 293 | } 294 | if millis < 0 { 295 | millis = -millis 296 | } 297 | 298 | days += millis / daysToMillis 299 | millis = millis % daysToMillis 300 | 301 | dur := time.Duration(millis) * time.Millisecond 302 | var prefix string 303 | if isneg { 304 | prefix = "-" 305 | } 306 | 307 | return fmt.Sprintf("%s%d days %s", prefix, days, dur.String()) 308 | } 309 | 310 | func processInterval(val []byte) string { 311 | m := int32(binary.LittleEndian.Uint32(val)) 312 | days := int32(binary.LittleEndian.Uint32(val[4:])) 313 | millis := int32(binary.LittleEndian.Uint32(val[8:])) 314 | 315 | isneg := (m < 0) || (m == 0 && days < 0) || (m == 0 && days == 0 && millis < 0) 316 | m = int32(math.Abs(float64(m))) 317 | days = int32(math.Abs(float64(days))) 318 | millis = int32(math.Abs(float64(millis))) 319 | 320 | years := m / 12 321 | months := m % 12 322 | 323 | days += millis / daysToMillis 324 | millis = millis % daysToMillis 325 | 326 | 
dur := time.Duration(millis) * time.Millisecond 327 | 328 | var prefix string 329 | if isneg { 330 | prefix = "-" 331 | } 332 | 333 | return fmt.Sprintf("%s%d-%d-%d %s", prefix, years, months, days, dur.String()) 334 | } 335 | 336 | func NewIntervalYearVector(data []byte, meta *shared.SerializedField) *intervalVector { 337 | return &intervalVector{ 338 | intervalBase: &fixedWidthVec{data, 4, meta}, 339 | process: processYear, 340 | } 341 | } 342 | 343 | func NewNullableIntervalYearVector(data []byte, meta *shared.SerializedField) *nullableIntervalVector { 344 | return &nullableIntervalVector{ 345 | newNullableFixedWidth(data, meta, 4), 346 | processYear, 347 | } 348 | } 349 | 350 | func NewIntervalDayVector(data []byte, meta *shared.SerializedField) *intervalVector { 351 | return &intervalVector{ 352 | intervalBase: &fixedWidthVec{data, 8, meta}, 353 | process: processDay, 354 | } 355 | } 356 | 357 | func NewNullableIntervalDayVector(data []byte, meta *shared.SerializedField) *nullableIntervalVector { 358 | return &nullableIntervalVector{ 359 | newNullableFixedWidth(data, meta, 8), 360 | processDay, 361 | } 362 | } 363 | 364 | func NewIntervalVector(data []byte, meta *shared.SerializedField) *intervalVector { 365 | return &intervalVector{ 366 | intervalBase: &fixedWidthVec{data, 12, meta}, 367 | process: processInterval, 368 | } 369 | } 370 | 371 | func NewNullableIntervalVector(data []byte, meta *shared.SerializedField) *nullableIntervalVector { 372 | return &nullableIntervalVector{ 373 | newNullableFixedWidth(data, meta, 12), 374 | processInterval, 375 | } 376 | } 377 | -------------------------------------------------------------------------------- /internal/data/data_vector.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import ( 4 | "encoding/binary" 5 | "math" 6 | "math/big" 7 | "reflect" 8 | "unsafe" 9 | 10 | "github.com/factset/go-drill/internal/rpc/proto/common" 11 | "github.com/factset/go-drill/internal/rpc/proto/exec/shared" 12 | "google.golang.org/protobuf/proto" 13 | ) 14 | 15 | //go:generate go run ../cmd/tmpl -data numeric.tmpldata vector_numeric.gen.go.tmpl type_traits_numeric.gen.go.tmpl numeric_vec_typemap.gen.go.tmpl 16 | //go:generate go run ../cmd/tmpl -data numeric.tmpldata type_traits_numeric.gen_test.go.tmpl vector_numeric.gen_test.go.tmpl numeric_vec_typemap.gen_test.go.tmpl 17 | 18 | type DataVector interface { 19 | Len() int 20 | Value(index uint) interface{} 21 | Type() reflect.Type 22 | TypeLen() (int64, bool) 23 | GetRawBytes() []byte 24 | } 25 | 26 | type NullableDataVector interface { 27 | DataVector 28 | IsNull(index uint) bool 29 | GetNullBytemap() []byte 30 | } 31 | 32 | type BitVector struct { 33 | vector 34 | values []byte 35 | meta *shared.SerializedField 36 | } 37 | 38 | func (BitVector) Type() reflect.Type { 39 | return reflect.TypeOf(bool(false)) 40 | } 41 | 42 | func (BitVector) TypeLen() (int64, bool) { 43 | return 0, false 44 | } 45 | 46 | func (b *BitVector) Len() int { 47 | return int(b.meta.GetValueCount()) 48 | } 49 | 50 | func (b *BitVector) Get(index uint) bool { 51 | bt := b.values[index/8] 52 | return bt&(1<<(index%8)) != 0 53 | } 54 | 55 | func (b *BitVector) Value(index uint) interface{} { 56 | return b.Get(index) 57 | } 58 | 59 | func NewBitVector(data []byte, meta *shared.SerializedField) *BitVector { 60 | return &BitVector{ 61 | vector: vector{data}, 62 | values: data, 63 | meta: meta, 64 | } 65 | } 66 | 67 | type NullableBitVector struct { 68 | *BitVector 69 | 70 | 
nullByteMap 71 | } 72 | 73 | func (nb *NullableBitVector) Get(index uint) *bool { 74 | if nb.IsNull(index) { 75 | return nil 76 | } 77 | 78 | return proto.Bool(nb.BitVector.Get(index)) 79 | } 80 | 81 | func (nb *NullableBitVector) Value(index uint) interface{} { 82 | val := nb.Get(index) 83 | if val != nil { 84 | return *val 85 | } 86 | return val 87 | } 88 | 89 | func NewNullableBitVector(data []byte, meta *shared.SerializedField) *NullableBitVector { 90 | bytemap := data[:meta.GetValueCount()] 91 | remaining := data[meta.GetValueCount():] 92 | 93 | return &NullableBitVector{ 94 | NewBitVector(remaining, meta), 95 | nullByteMap{bytemap}, 96 | } 97 | } 98 | 99 | type VarbinaryVector struct { 100 | vector 101 | offsets []uint32 102 | data []byte 103 | 104 | meta *shared.SerializedField 105 | } 106 | 107 | func (VarbinaryVector) Type() reflect.Type { 108 | return reflect.TypeOf([]byte{}) 109 | } 110 | 111 | func (VarbinaryVector) TypeLen() (int64, bool) { 112 | return int64(math.MaxUint16), true 113 | } 114 | 115 | func (v *VarbinaryVector) Len() int { 116 | return int(v.meta.GetValueCount()) 117 | } 118 | 119 | func (v *VarbinaryVector) Get(index uint) []byte { 120 | return v.data[v.offsets[index]:v.offsets[index+1]] 121 | } 122 | 123 | func (v *VarbinaryVector) Value(index uint) interface{} { 124 | return v.Get(index) 125 | } 126 | 127 | func NewVarbinaryVector(data []byte, meta *shared.SerializedField) *VarbinaryVector { 128 | if data == nil { 129 | return &VarbinaryVector{ 130 | vector: vector{data}, 131 | offsets: []uint32{}, 132 | data: []byte{}, 133 | meta: meta, 134 | } 135 | } 136 | 137 | var offsetField *shared.SerializedField 138 | if meta.MajorType.GetMode() == common.DataMode_REQUIRED { 139 | offsetField = meta.Child[0] 140 | } else { 141 | offsetField = meta.Child[1].Child[0] 142 | } 143 | 144 | offsetBytesSize := offsetField.GetBufferLength() 145 | offsetBytes := data[:offsetBytesSize] 146 | remaining := data[offsetBytesSize:] 147 | 148 | offsetList := make([]uint32, meta.GetValueCount()+1) 149 | for i := 0; i < len(offsetList); i++ { 150 | offsetList[i] = binary.LittleEndian.Uint32(offsetBytes[i*4:]) 151 | } 152 | 153 | return &VarbinaryVector{ 154 | vector: vector{data}, 155 | offsets: offsetList, 156 | data: remaining, 157 | meta: meta, 158 | } 159 | } 160 | 161 | type VarcharVector struct { 162 | *VarbinaryVector 163 | } 164 | 165 | func (VarcharVector) Type() reflect.Type { 166 | return reflect.TypeOf(string("")) 167 | } 168 | 169 | func (v *VarcharVector) Get(index uint) string { 170 | b := v.VarbinaryVector.Get(index) 171 | return *(*string)(unsafe.Pointer(&b)) 172 | } 173 | 174 | func NewVarcharVector(data []byte, meta *shared.SerializedField) *VarcharVector { 175 | return &VarcharVector{NewVarbinaryVector(data, meta)} 176 | } 177 | 178 | type NullableVarcharVector struct { 179 | *VarcharVector 180 | 181 | nullByteMap 182 | } 183 | 184 | func (nv *NullableVarcharVector) Get(index uint) *string { 185 | if nv.IsNull(index) { 186 | return nil 187 | } 188 | 189 | b := nv.VarbinaryVector.Get(index) 190 | return (*string)(unsafe.Pointer(&b)) 191 | } 192 | 193 | func (nv *NullableVarcharVector) Value(index uint) interface{} { 194 | val := nv.Get(index) 195 | if val == nil { 196 | return nil 197 | } 198 | 199 | return *val 200 | } 201 | 202 | func NewNullableVarcharVector(data []byte, meta *shared.SerializedField) *NullableVarcharVector { 203 | byteMap := data[:meta.GetValueCount()] 204 | remaining := data[meta.GetValueCount():] 205 | 206 | return &NullableVarcharVector{ 207 | 
NewVarcharVector(remaining, meta), 208 | nullByteMap{byteMap}, 209 | } 210 | } 211 | 212 | type DecimalVector struct { 213 | *fixedWidthVec 214 | 215 | traits DecimalTraits 216 | scale int 217 | prec int32 218 | } 219 | 220 | func NewDecimalVector(data []byte, meta *shared.SerializedField, traits DecimalTraits) *DecimalVector { 221 | return &DecimalVector{ 222 | fixedWidthVec: &fixedWidthVec{data: data, valsz: traits.ByteWidth(), meta: meta}, 223 | scale: int(meta.MajorType.GetScale()), 224 | prec: meta.MajorType.GetPrecision(), 225 | traits: traits, 226 | } 227 | } 228 | 229 | func (dv *DecimalVector) Get(index uint) *big.Float { 230 | valbytes := dv.getval(int(index)) 231 | if !dv.traits.IsSparse() { 232 | panic("go-drill: currently only supports decimal sparse vectors, not dense") 233 | } 234 | 235 | return getFloatFromBytes(valbytes, dv.traits.NumDigits(), dv.scale, dv.traits.IsSparse()) 236 | } 237 | 238 | func (dv *DecimalVector) Value(index uint) interface{} { 239 | return dv.Get(index) 240 | } 241 | 242 | type NullableDecimalVector struct { 243 | *nullableFixedWidthVec 244 | 245 | traits DecimalTraits 246 | scale int 247 | prec int32 248 | } 249 | 250 | func (dv *NullableDecimalVector) Get(index uint) *big.Float { 251 | valbytes := dv.getval(int(index)) 252 | if valbytes == nil { 253 | return nil 254 | } 255 | 256 | if !dv.traits.IsSparse() { 257 | panic("go-drill: currently only supports decimal sparse vectors, not dense") 258 | } 259 | return getFloatFromBytes(valbytes, dv.traits.NumDigits(), dv.scale, dv.traits.IsSparse()) 260 | } 261 | 262 | func (dv *NullableDecimalVector) Value(index uint) interface{} { 263 | return dv.Get(index) 264 | } 265 | 266 | func NewNullableDecimalVector(data []byte, meta *shared.SerializedField, traits DecimalTraits) *NullableDecimalVector { 267 | return &NullableDecimalVector{ 268 | nullableFixedWidthVec: newNullableFixedWidth(data, meta, traits.ByteWidth()), 269 | scale: int(meta.MajorType.GetScale()), 270 | prec: meta.MajorType.GetPrecision(), 271 | traits: traits, 272 | } 273 | } 274 | 275 | func NewValueVec(rawData []byte, meta *shared.SerializedField) DataVector { 276 | ret := NewNumericValueVec(rawData, meta) 277 | if ret != nil { 278 | return ret 279 | } 280 | 281 | if meta.GetMajorType().GetMode() == common.DataMode_OPTIONAL { 282 | switch meta.GetMajorType().GetMinorType() { 283 | case common.MinorType_BIT: 284 | return NewNullableBitVector(rawData, meta) 285 | case common.MinorType_VARCHAR: 286 | return NewNullableVarcharVector(rawData, meta) 287 | case common.MinorType_TIMESTAMP: 288 | return NewNullableTimestampVector(rawData, meta) 289 | case common.MinorType_DATE: 290 | return NewNullableDateVector(rawData, meta) 291 | case common.MinorType_TIME: 292 | return NewNullableTimeVector(rawData, meta) 293 | case common.MinorType_INTERVAL: 294 | return NewNullableIntervalVector(rawData, meta) 295 | case common.MinorType_INTERVALDAY: 296 | return NewNullableIntervalDayVector(rawData, meta) 297 | case common.MinorType_INTERVALYEAR: 298 | return NewNullableIntervalYearVector(rawData, meta) 299 | case common.MinorType_DECIMAL28SPARSE: 300 | return NewNullableDecimalVector(rawData, meta, &Decimal28SparseTraits) 301 | case common.MinorType_DECIMAL38SPARSE: 302 | return NewNullableDecimalVector(rawData, meta, &Decimal38SparseTraits) 303 | } 304 | } else { 305 | switch meta.GetMajorType().GetMinorType() { 306 | case common.MinorType_VARBINARY: 307 | return NewVarbinaryVector(rawData, meta) 308 | case common.MinorType_VARCHAR: 309 | return 
NewVarcharVector(rawData, meta) 310 | case common.MinorType_BIT: 311 | return NewBitVector(rawData, meta) 312 | case common.MinorType_TIMESTAMP: 313 | return NewTimestampVector(rawData, meta) 314 | case common.MinorType_DATE: 315 | return NewDateVector(rawData, meta) 316 | case common.MinorType_TIME: 317 | return NewTimeVector(rawData, meta) 318 | case common.MinorType_INTERVAL: 319 | return NewIntervalVector(rawData, meta) 320 | case common.MinorType_INTERVALDAY: 321 | return NewIntervalDayVector(rawData, meta) 322 | case common.MinorType_INTERVALYEAR: 323 | return NewIntervalYearVector(rawData, meta) 324 | case common.MinorType_DECIMAL28SPARSE: 325 | return NewDecimalVector(rawData, meta, &Decimal28SparseTraits) 326 | case common.MinorType_DECIMAL38SPARSE: 327 | return NewDecimalVector(rawData, meta, &Decimal38SparseTraits) 328 | } 329 | } 330 | 331 | return nil 332 | } 333 | -------------------------------------------------------------------------------- /driver/rows_test.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "bytes" 5 | "compress/zlib" 6 | "database/sql/driver" 7 | "encoding/hex" 8 | "io" 9 | "io/ioutil" 10 | "math" 11 | "reflect" 12 | "testing" 13 | 14 | "github.com/factset/go-drill" 15 | "github.com/factset/go-drill/internal/data" 16 | "github.com/factset/go-drill/internal/rpc/proto/common" 17 | "github.com/factset/go-drill/internal/rpc/proto/exec/shared" 18 | "github.com/stretchr/testify/assert" 19 | "github.com/stretchr/testify/mock" 20 | "google.golang.org/protobuf/proto" 21 | ) 22 | 23 | func TestImplements(t *testing.T) { 24 | assert.Implements(t, (*driver.RowsColumnTypeDatabaseTypeName)(nil), new(rows)) 25 | assert.Implements(t, (*driver.RowsColumnTypeLength)(nil), new(rows)) 26 | assert.Implements(t, (*driver.RowsColumnTypeNullable)(nil), new(rows)) 27 | assert.Implements(t, (*driver.RowsColumnTypePrecisionScale)(nil), new(rows)) 28 | assert.Implements(t, (*driver.RowsColumnTypeScanType)(nil), new(rows)) 29 | } 30 | 31 | func TestRowsClose(t *testing.T) { 32 | m := new(mockResHandle) 33 | m.Test(t) 34 | defer m.AssertExpectations(t) 35 | 36 | m.On("Close").Return(assert.AnError) 37 | 38 | r := &rows{handle: m} 39 | assert.Same(t, assert.AnError, r.Close()) 40 | } 41 | 42 | func TestRowsGetCols(t *testing.T) { 43 | m := new(mockResHandle) 44 | m.Test(t) 45 | defer m.AssertExpectations(t) 46 | 47 | cols := []string{"a", "b", "c"} 48 | m.On("GetCols").Return(cols) 49 | 50 | r := &rows{handle: m} 51 | assert.Exactly(t, cols, r.Columns()) 52 | } 53 | 54 | type mockBatch struct { 55 | mock.Mock 56 | } 57 | 58 | func (mb *mockBatch) ColumnName(_ int) string { return "" } 59 | func (mb *mockBatch) NumCols() int { return 0 } 60 | func (mb *mockBatch) AffectedRows() int32 { return int32(mb.Called().Int(0)) } 61 | func (mb *mockBatch) NumRows() int32 { 62 | return int32(mb.Called().Int(0)) 63 | } 64 | func (mb *mockBatch) IsNullable(index int) bool { return false } 65 | func (mb *mockBatch) TypeName(index int) string { 66 | return mb.Called(index).String(0) 67 | } 68 | func (mb *mockBatch) PrecisionScale(index int) (precision, scale int64, ok bool) { 69 | args := mb.Called(index) 70 | return int64(args.Int(0)), int64(args.Int(1)), args.Bool(2) 71 | } 72 | func (mb *mockBatch) GetVectors() []drill.DataVector { 73 | return mb.Called().Get(0).([]drill.DataVector) 74 | } 75 | 76 | var compraw = 
"789c5c92cf6ed34010c6bfa4e5cf9113270ec311a90421240e705a5a132c15b78a7a29b7ad3d76962ebbeefe09f18977e24d780d5e802bcac68e93ecc19fb433dfcc6f678cfe4c7a9df67ad2eb69af8f7a7ddceb935e9f623c9bbb67009e037801e025805700de00780fe023007139cf16b9108b7956dce485f8b410dff2cb7351880b91cd6faf6fb29b2ff9d5752e3e2f44719ecdb3c55751dce6c5452e8e518ff51487e7e448a77bb177007e01f80be06a02fc9e007f26c0bf09f0760ab829404bd9349a67544ac775d4baa35a19a9a9e2d67a153c551cb80ce475a73b928d5452536dd7eca975f687f23c84cad2ba4a998682a5b06472dc442d5d0a4413fc8ceeacaec8f143641f3c496d4d97be8d571593ad93ad6593aa0c00fb68bee572d3bf95e5bd6cd893744cf2cec690acca59a34aaaad6bd8cf7aaec1c3d2d3529aa6cf3a23af349bb0ab35e40fd8073deae8948d5e776457ec52b3a0c272e3aa75ac6ba5bbf4ba6e43b3e294306247137d947a736b2bab5548beed98636363d8367988aabc4f9374d6fb5463ec3b60552b3623db4f79bff3cd52cee1168761cf06ffd9f6f99a5a2d036bde0bed1645ebd74e761f06ecb38304ef89d765f4eca994dfad1ed7bf831e27bafb895aa74c20e96c34ff030000ffff84ef0203" 77 | 78 | var sampleDef = shared.RecordBatchDef{ 79 | RecordCount: proto.Int32(9), 80 | Field: []*shared.SerializedField{ 81 | { 82 | MajorType: &common.MajorType{MinorType: common.MinorType_BIGINT.Enum(), Mode: common.DataMode_REQUIRED.Enum()}, 83 | NamePart: &shared.NamePart{Name: proto.String("N_NATIONKEY")}, 84 | ValueCount: proto.Int32(9), 85 | BufferLength: proto.Int32(72), 86 | }, 87 | { 88 | MajorType: &common.MajorType{MinorType: common.MinorType_VARBINARY.Enum(), Mode: common.DataMode_REQUIRED.Enum()}, 89 | NamePart: &shared.NamePart{Name: proto.String("N_NAME")}, 90 | Child: []*shared.SerializedField{ 91 | { 92 | MajorType: &common.MajorType{MinorType: common.MinorType_UINT4.Enum(), Mode: common.DataMode_REQUIRED.Enum()}, 93 | NamePart: &shared.NamePart{Name: proto.String("$offsets$")}, 94 | ValueCount: proto.Int32(10), 95 | BufferLength: proto.Int32(40), 96 | }, 97 | }, 98 | ValueCount: proto.Int32(9), 99 | BufferLength: proto.Int32(99), 100 | }, 101 | { 102 | MajorType: &common.MajorType{MinorType: common.MinorType_BIGINT.Enum(), Mode: common.DataMode_REQUIRED.Enum()}, 103 | NamePart: &shared.NamePart{Name: proto.String("N_REGIONKEY")}, 104 | ValueCount: proto.Int32(9), 105 | BufferLength: proto.Int32(72), 106 | }, 107 | { 108 | MajorType: &common.MajorType{MinorType: common.MinorType_VARBINARY.Enum(), Mode: common.DataMode_REQUIRED.Enum()}, 109 | NamePart: &shared.NamePart{Name: proto.String("N_COMMENT")}, 110 | Child: []*shared.SerializedField{ 111 | { 112 | MajorType: &common.MajorType{MinorType: common.MinorType_UINT4.Enum(), Mode: common.DataMode_REQUIRED.Enum()}, 113 | NamePart: &shared.NamePart{Name: proto.String("$offsets$")}, 114 | ValueCount: proto.Int32(10), 115 | BufferLength: proto.Int32(40), 116 | }, 117 | }, 118 | ValueCount: proto.Int32(9), 119 | BufferLength: proto.Int32(666), 120 | }, 121 | }, 122 | } 123 | 124 | func getSampleRecordBatch() drill.RowBatch { 125 | b, _ := hex.DecodeString(compraw) 126 | zr, _ := zlib.NewReader(bytes.NewReader(b)) 127 | defer zr.Close() 128 | 129 | rawblock, _ := ioutil.ReadAll(zr) 130 | 131 | vecs := make([]drill.DataVector, 0, 4) 132 | var offset int32 = 0 133 | for _, f := range sampleDef.GetField() { 134 | vecs = append(vecs, data.NewValueVec(rawblock[offset:offset+f.GetBufferLength()], f)) 135 | offset += f.GetBufferLength() 136 | } 137 | 138 | mb := new(mockBatch) 139 | mb.On("NumRows").Return(9) 140 | mb.On("GetVectors").Return(vecs) 141 | 142 | mb.On("TypeName", 0).Return("BIGINT") 143 | mb.On("TypeName", 1).Return("VARBINARY") 144 | mb.On("TypeName", 2).Return("BIGINT") 145 | 
mb.On("TypeName", 3).Return("VARBINARY") 146 | 147 | mb.On("IsNullable", mock.Anything).Return(false) 148 | 149 | mb.On("PrecisionScale", 0).Return(0, 0, false) 150 | mb.On("PrecisionScale", 1).Return(0, 0, false) 151 | return mb 152 | } 153 | 154 | func TestRowsNext(t *testing.T) { 155 | mr := new(mockResHandle) 156 | mr.Test(t) 157 | defer mr.AssertExpectations(t) 158 | 159 | mr.On("GetRecordBatch").Return(getSampleRecordBatch()) 160 | 161 | r := &rows{handle: mr, curRow: 1} 162 | dest := make([]driver.Value, 4) 163 | 164 | assert.NoError(t, r.Next(dest)) 165 | assert.Exactly(t, int64(1), dest[0]) 166 | assert.Exactly(t, []byte("ARGENTINA"), dest[1]) 167 | assert.Exactly(t, int64(1), dest[2]) 168 | assert.Exactly(t, []byte("al foxes promise slyly according to the regular accounts. bold requests alon"), dest[3]) 169 | assert.Equal(t, 2, r.curRow) 170 | } 171 | 172 | func TestRowsNextEnd(t *testing.T) { 173 | mr := new(mockResHandle) 174 | mr.Test(t) 175 | defer mr.AssertExpectations(t) 176 | 177 | mr.On("GetRecordBatch").Return((drill.RowBatch)(nil)) 178 | 179 | r := &rows{handle: mr, curRow: 1} 180 | dest := make([]driver.Value, 4) 181 | assert.Same(t, io.EOF, r.Next(dest)) 182 | } 183 | 184 | func TestRowsNextCallNext(t *testing.T) { 185 | mr := new(mockResHandle) 186 | mr.Test(t) 187 | defer mr.AssertExpectations(t) 188 | 189 | mr.On("GetRecordBatch").Return(getSampleRecordBatch()) 190 | mr.On("Next").Return(nil, getSampleRecordBatch()) 191 | 192 | r := &rows{handle: mr, curRow: 10} 193 | dest := make([]driver.Value, 4) 194 | assert.NoError(t, r.Next(dest)) 195 | 196 | assert.Exactly(t, int64(0), dest[0]) 197 | assert.Exactly(t, []byte("ALGERIA"), dest[1]) 198 | assert.Exactly(t, int64(0), dest[2]) 199 | assert.Exactly(t, []byte(" haggle. carefully final deposits detect slyly agai"), dest[3]) 200 | assert.Equal(t, 1, r.curRow) 201 | } 202 | 203 | func TestRowsNextCallNextErr(t *testing.T) { 204 | mr := new(mockResHandle) 205 | mr.Test(t) 206 | defer mr.AssertExpectations(t) 207 | 208 | mr.On("GetRecordBatch").Return(getSampleRecordBatch()) 209 | mr.On("Next").Return(assert.AnError, (drill.RowBatch)(nil)) 210 | 211 | r := &rows{handle: mr, curRow: 10} 212 | dest := make([]driver.Value, 4) 213 | assert.Same(t, assert.AnError, r.Next(dest)) 214 | } 215 | 216 | func TestRowsColumnTypeHelpers(t *testing.T) { 217 | mr := new(mockResHandle) 218 | mr.Test(t) 219 | defer mr.AssertExpectations(t) 220 | 221 | mr.On("GetRecordBatch").Return(getSampleRecordBatch()) 222 | 223 | r := &rows{handle: mr, curRow: 0} 224 | 225 | tests := []struct { 226 | name string 227 | f func() interface{} 228 | val interface{} 229 | }{ 230 | {"column type scan type", func() interface{} { return r.ColumnTypeScanType(0) }, reflect.TypeOf(int64(0))}, 231 | {"column type scan type", func() interface{} { return r.ColumnTypeScanType(1) }, reflect.TypeOf([]byte{})}, 232 | {"column type scan type", func() interface{} { return r.ColumnTypeScanType(2) }, reflect.TypeOf(int64(0))}, 233 | {"column type scan type", func() interface{} { return r.ColumnTypeScanType(3) }, reflect.TypeOf([]byte{})}, 234 | {"column database type name", func() interface{} { return r.ColumnTypeDatabaseTypeName(0) }, "BIGINT"}, 235 | {"column database type name", func() interface{} { return r.ColumnTypeDatabaseTypeName(1) }, "VARBINARY"}, 236 | {"column database type name", func() interface{} { return r.ColumnTypeDatabaseTypeName(2) }, "BIGINT"}, 237 | {"column database type name", func() interface{} { return r.ColumnTypeDatabaseTypeName(3) }, 
"VARBINARY"}, 238 | {"column type nullable", func() interface{} { 239 | a, b := r.ColumnTypeNullable(0) 240 | return []bool{a, b} 241 | }, []bool{false, true}}, 242 | {"column type length", func() interface{} { 243 | a, b := r.ColumnTypeLength(0) 244 | return []interface{}{a, b} 245 | }, []interface{}{int64(0), false}}, 246 | {"column type length", func() interface{} { 247 | a, b := r.ColumnTypeLength(1) 248 | return []interface{}{a, b} 249 | }, []interface{}{int64(math.MaxUint16), true}}, 250 | } 251 | 252 | for _, tt := range tests { 253 | t.Run(tt.name, func(t *testing.T) { 254 | assert.Equal(t, tt.val, tt.f()) 255 | }) 256 | } 257 | } 258 | 259 | func TestColumnPrecisionScaleNoVal(t *testing.T) { 260 | mr := new(mockResHandle) 261 | mr.Test(t) 262 | defer mr.AssertExpectations(t) 263 | 264 | mr.On("GetRecordBatch").Return(getSampleRecordBatch()) 265 | 266 | r := &rows{handle: mr, curRow: 0} 267 | p, s, ok := r.ColumnTypePrecisionScale(0) 268 | assert.False(t, ok) 269 | assert.Zero(t, p) 270 | assert.Zero(t, s) 271 | 272 | p, s, ok = r.ColumnTypePrecisionScale(1) 273 | assert.False(t, ok) 274 | assert.Zero(t, p) 275 | assert.Zero(t, s) 276 | } 277 | 278 | func TestColumnPrecisionScale(t *testing.T) { 279 | mb := new(mockBatch) 280 | mb.On("PrecisionScale", 0).Return(4, 25, true) 281 | 282 | mr := new(mockResHandle) 283 | mr.Test(t) 284 | defer mr.AssertExpectations(t) 285 | 286 | mr.On("GetRecordBatch").Return(mb) 287 | 288 | r := &rows{handle: mr, curRow: 0} 289 | p, s, ok := r.ColumnTypePrecisionScale(0) 290 | assert.True(t, ok) 291 | assert.EqualValues(t, 4, p) 292 | assert.EqualValues(t, 25, s) 293 | } 294 | -------------------------------------------------------------------------------- /internal/data/numeric_vec_typemap.gen_test.go: -------------------------------------------------------------------------------- 1 | // Code generated by numeric_vec_typemap.gen_test.go.tmpl. DO NOT EDIT. 
2 | 3 | package data_test 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/factset/go-drill/internal/data" 9 | "github.com/factset/go-drill/internal/rpc/proto/common" 10 | "github.com/factset/go-drill/internal/rpc/proto/exec/shared" 11 | "github.com/stretchr/testify/assert" 12 | "google.golang.org/protobuf/proto" 13 | ) 14 | 15 | func TestNewNumericVecRequiredInt64(t *testing.T) { 16 | const N = 10 17 | b := data.Int64Traits.CastToBytes([]int64{ 18 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 19 | }) 20 | 21 | meta := &shared.SerializedField{ 22 | MajorType: &common.MajorType{ 23 | MinorType: common.MinorType_BIGINT.Enum(), 24 | Mode: common.DataMode_REQUIRED.Enum(), 25 | }, 26 | } 27 | 28 | dv := data.NewValueVec(b, meta) 29 | assert.IsType(t, (*data.Int64Vector)(nil), dv) 30 | } 31 | 32 | func TestNewNumericVecOptionalInt64(t *testing.T) { 33 | const N = 10 34 | b := data.Int64Traits.CastToBytes([]int64{ 35 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 36 | }) 37 | 38 | bytemap := []byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1} 39 | 40 | meta := &shared.SerializedField{ 41 | ValueCount: proto.Int32(0), 42 | MajorType: &common.MajorType{ 43 | MinorType: common.MinorType_BIGINT.Enum(), 44 | Mode: common.DataMode_OPTIONAL.Enum(), 45 | }, 46 | } 47 | 48 | dv := data.NewValueVec(append(bytemap, b...), meta) 49 | assert.IsType(t, (*data.NullableInt64Vector)(nil), dv) 50 | } 51 | 52 | func TestNewNumericVecRequiredInt32(t *testing.T) { 53 | const N = 10 54 | b := data.Int32Traits.CastToBytes([]int32{ 55 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 56 | }) 57 | 58 | meta := &shared.SerializedField{ 59 | MajorType: &common.MajorType{ 60 | MinorType: common.MinorType_INT.Enum(), 61 | Mode: common.DataMode_REQUIRED.Enum(), 62 | }, 63 | } 64 | 65 | dv := data.NewValueVec(b, meta) 66 | assert.IsType(t, (*data.Int32Vector)(nil), dv) 67 | } 68 | 69 | func TestNewNumericVecOptionalInt32(t *testing.T) { 70 | const N = 10 71 | b := data.Int32Traits.CastToBytes([]int32{ 72 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 73 | }) 74 | 75 | bytemap := []byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1} 76 | 77 | meta := &shared.SerializedField{ 78 | ValueCount: proto.Int32(0), 79 | MajorType: &common.MajorType{ 80 | MinorType: common.MinorType_INT.Enum(), 81 | Mode: common.DataMode_OPTIONAL.Enum(), 82 | }, 83 | } 84 | 85 | dv := data.NewValueVec(append(bytemap, b...), meta) 86 | assert.IsType(t, (*data.NullableInt32Vector)(nil), dv) 87 | } 88 | 89 | func TestNewNumericVecRequiredFloat64(t *testing.T) { 90 | const N = 10 91 | b := data.Float64Traits.CastToBytes([]float64{ 92 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 93 | }) 94 | 95 | meta := &shared.SerializedField{ 96 | MajorType: &common.MajorType{ 97 | MinorType: common.MinorType_FLOAT8.Enum(), 98 | Mode: common.DataMode_REQUIRED.Enum(), 99 | }, 100 | } 101 | 102 | dv := data.NewValueVec(b, meta) 103 | assert.IsType(t, (*data.Float64Vector)(nil), dv) 104 | } 105 | 106 | func TestNewNumericVecOptionalFloat64(t *testing.T) { 107 | const N = 10 108 | b := data.Float64Traits.CastToBytes([]float64{ 109 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 110 | }) 111 | 112 | bytemap := []byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1} 113 | 114 | meta := &shared.SerializedField{ 115 | ValueCount: proto.Int32(0), 116 | MajorType: &common.MajorType{ 117 | MinorType: common.MinorType_FLOAT8.Enum(), 118 | Mode: common.DataMode_OPTIONAL.Enum(), 119 | }, 120 | } 121 | 122 | dv := data.NewValueVec(append(bytemap, b...), meta) 123 | assert.IsType(t, (*data.NullableFloat64Vector)(nil), dv) 124 | } 125 | 126 | func TestNewNumericVecRequiredUint64(t *testing.T) { 127 | const N = 10 128 | b := 
data.Uint64Traits.CastToBytes([]uint64{ 129 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 130 | }) 131 | 132 | meta := &shared.SerializedField{ 133 | MajorType: &common.MajorType{ 134 | MinorType: common.MinorType_UINT8.Enum(), 135 | Mode: common.DataMode_REQUIRED.Enum(), 136 | }, 137 | } 138 | 139 | dv := data.NewValueVec(b, meta) 140 | assert.IsType(t, (*data.Uint64Vector)(nil), dv) 141 | } 142 | 143 | func TestNewNumericVecOptionalUint64(t *testing.T) { 144 | const N = 10 145 | b := data.Uint64Traits.CastToBytes([]uint64{ 146 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 147 | }) 148 | 149 | bytemap := []byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1} 150 | 151 | meta := &shared.SerializedField{ 152 | ValueCount: proto.Int32(0), 153 | MajorType: &common.MajorType{ 154 | MinorType: common.MinorType_UINT8.Enum(), 155 | Mode: common.DataMode_OPTIONAL.Enum(), 156 | }, 157 | } 158 | 159 | dv := data.NewValueVec(append(bytemap, b...), meta) 160 | assert.IsType(t, (*data.NullableUint64Vector)(nil), dv) 161 | } 162 | 163 | func TestNewNumericVecRequiredUint32(t *testing.T) { 164 | const N = 10 165 | b := data.Uint32Traits.CastToBytes([]uint32{ 166 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 167 | }) 168 | 169 | meta := &shared.SerializedField{ 170 | MajorType: &common.MajorType{ 171 | MinorType: common.MinorType_UINT4.Enum(), 172 | Mode: common.DataMode_REQUIRED.Enum(), 173 | }, 174 | } 175 | 176 | dv := data.NewValueVec(b, meta) 177 | assert.IsType(t, (*data.Uint32Vector)(nil), dv) 178 | } 179 | 180 | func TestNewNumericVecOptionalUint32(t *testing.T) { 181 | const N = 10 182 | b := data.Uint32Traits.CastToBytes([]uint32{ 183 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 184 | }) 185 | 186 | bytemap := []byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1} 187 | 188 | meta := &shared.SerializedField{ 189 | ValueCount: proto.Int32(0), 190 | MajorType: &common.MajorType{ 191 | MinorType: common.MinorType_UINT4.Enum(), 192 | Mode: common.DataMode_OPTIONAL.Enum(), 193 | }, 194 | } 195 | 196 | dv := data.NewValueVec(append(bytemap, b...), meta) 197 | assert.IsType(t, (*data.NullableUint32Vector)(nil), dv) 198 | } 199 | 200 | func TestNewNumericVecRequiredFloat32(t *testing.T) { 201 | const N = 10 202 | b := data.Float32Traits.CastToBytes([]float32{ 203 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 204 | }) 205 | 206 | meta := &shared.SerializedField{ 207 | MajorType: &common.MajorType{ 208 | MinorType: common.MinorType_FLOAT4.Enum(), 209 | Mode: common.DataMode_REQUIRED.Enum(), 210 | }, 211 | } 212 | 213 | dv := data.NewValueVec(b, meta) 214 | assert.IsType(t, (*data.Float32Vector)(nil), dv) 215 | } 216 | 217 | func TestNewNumericVecOptionalFloat32(t *testing.T) { 218 | const N = 10 219 | b := data.Float32Traits.CastToBytes([]float32{ 220 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 221 | }) 222 | 223 | bytemap := []byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1} 224 | 225 | meta := &shared.SerializedField{ 226 | ValueCount: proto.Int32(0), 227 | MajorType: &common.MajorType{ 228 | MinorType: common.MinorType_FLOAT4.Enum(), 229 | Mode: common.DataMode_OPTIONAL.Enum(), 230 | }, 231 | } 232 | 233 | dv := data.NewValueVec(append(bytemap, b...), meta) 234 | assert.IsType(t, (*data.NullableFloat32Vector)(nil), dv) 235 | } 236 | 237 | func TestNewNumericVecRequiredInt16(t *testing.T) { 238 | const N = 10 239 | b := data.Int16Traits.CastToBytes([]int16{ 240 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 241 | }) 242 | 243 | meta := &shared.SerializedField{ 244 | MajorType: &common.MajorType{ 245 | MinorType: common.MinorType_SMALLINT.Enum(), 246 | Mode: common.DataMode_REQUIRED.Enum(), 247 | }, 248 | } 249 | 250 | dv := data.NewValueVec(b, 
meta) 251 | assert.IsType(t, (*data.Int16Vector)(nil), dv) 252 | } 253 | 254 | func TestNewNumericVecOptionalInt16(t *testing.T) { 255 | const N = 10 256 | b := data.Int16Traits.CastToBytes([]int16{ 257 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 258 | }) 259 | 260 | bytemap := []byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1} 261 | 262 | meta := &shared.SerializedField{ 263 | ValueCount: proto.Int32(0), 264 | MajorType: &common.MajorType{ 265 | MinorType: common.MinorType_SMALLINT.Enum(), 266 | Mode: common.DataMode_OPTIONAL.Enum(), 267 | }, 268 | } 269 | 270 | dv := data.NewValueVec(append(bytemap, b...), meta) 271 | assert.IsType(t, (*data.NullableInt16Vector)(nil), dv) 272 | } 273 | 274 | func TestNewNumericVecRequiredUint16(t *testing.T) { 275 | const N = 10 276 | b := data.Uint16Traits.CastToBytes([]uint16{ 277 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 278 | }) 279 | 280 | meta := &shared.SerializedField{ 281 | MajorType: &common.MajorType{ 282 | MinorType: common.MinorType_UINT2.Enum(), 283 | Mode: common.DataMode_REQUIRED.Enum(), 284 | }, 285 | } 286 | 287 | dv := data.NewValueVec(b, meta) 288 | assert.IsType(t, (*data.Uint16Vector)(nil), dv) 289 | } 290 | 291 | func TestNewNumericVecOptionalUint16(t *testing.T) { 292 | const N = 10 293 | b := data.Uint16Traits.CastToBytes([]uint16{ 294 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 295 | }) 296 | 297 | bytemap := []byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1} 298 | 299 | meta := &shared.SerializedField{ 300 | ValueCount: proto.Int32(0), 301 | MajorType: &common.MajorType{ 302 | MinorType: common.MinorType_UINT2.Enum(), 303 | Mode: common.DataMode_OPTIONAL.Enum(), 304 | }, 305 | } 306 | 307 | dv := data.NewValueVec(append(bytemap, b...), meta) 308 | assert.IsType(t, (*data.NullableUint16Vector)(nil), dv) 309 | } 310 | 311 | func TestNewNumericVecRequiredInt8(t *testing.T) { 312 | const N = 10 313 | b := data.Int8Traits.CastToBytes([]int8{ 314 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 315 | }) 316 | 317 | meta := &shared.SerializedField{ 318 | MajorType: &common.MajorType{ 319 | MinorType: common.MinorType_TINYINT.Enum(), 320 | Mode: common.DataMode_REQUIRED.Enum(), 321 | }, 322 | } 323 | 324 | dv := data.NewValueVec(b, meta) 325 | assert.IsType(t, (*data.Int8Vector)(nil), dv) 326 | } 327 | 328 | func TestNewNumericVecOptionalInt8(t *testing.T) { 329 | const N = 10 330 | b := data.Int8Traits.CastToBytes([]int8{ 331 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 332 | }) 333 | 334 | bytemap := []byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1} 335 | 336 | meta := &shared.SerializedField{ 337 | ValueCount: proto.Int32(0), 338 | MajorType: &common.MajorType{ 339 | MinorType: common.MinorType_TINYINT.Enum(), 340 | Mode: common.DataMode_OPTIONAL.Enum(), 341 | }, 342 | } 343 | 344 | dv := data.NewValueVec(append(bytemap, b...), meta) 345 | assert.IsType(t, (*data.NullableInt8Vector)(nil), dv) 346 | } 347 | 348 | func TestNewNumericVecRequiredUint8(t *testing.T) { 349 | const N = 10 350 | b := data.Uint8Traits.CastToBytes([]uint8{ 351 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 352 | }) 353 | 354 | meta := &shared.SerializedField{ 355 | MajorType: &common.MajorType{ 356 | MinorType: common.MinorType_UINT1.Enum(), 357 | Mode: common.DataMode_REQUIRED.Enum(), 358 | }, 359 | } 360 | 361 | dv := data.NewValueVec(b, meta) 362 | assert.IsType(t, (*data.Uint8Vector)(nil), dv) 363 | } 364 | 365 | func TestNewNumericVecOptionalUint8(t *testing.T) { 366 | const N = 10 367 | b := data.Uint8Traits.CastToBytes([]uint8{ 368 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 369 | }) 370 | 371 | bytemap := []byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1} 372 | 373 | meta := 
&shared.SerializedField{ 374 | ValueCount: proto.Int32(0), 375 | MajorType: &common.MajorType{ 376 | MinorType: common.MinorType_UINT1.Enum(), 377 | Mode: common.DataMode_OPTIONAL.Enum(), 378 | }, 379 | } 380 | 381 | dv := data.NewValueVec(append(bytemap, b...), meta) 382 | assert.IsType(t, (*data.NullableUint8Vector)(nil), dv) 383 | } 384 | -------------------------------------------------------------------------------- /internal/rpc/proto/exec/bit/ExecutionProtos.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. DO NOT EDIT. 2 | // versions: 3 | // protoc-gen-go v1.25.0 4 | // protoc v3.12.4 5 | // source: ExecutionProtos.proto 6 | 7 | // 8 | // Licensed to the Apache Software Foundation (ASF) under one 9 | // or more contributor license agreements. See the NOTICE file 10 | // distributed with this work for additional information 11 | // regarding copyright ownership. The ASF licenses this file 12 | // to you under the Apache License, Version 2.0 (the 13 | // "License"); you may not use this file except in compliance 14 | // with the License. You may obtain a copy of the License at 15 | // 16 | // http://www.apache.org/licenses/LICENSE-2.0 17 | // 18 | // Unless required by applicable law or agreed to in writing, software 19 | // distributed under the License is distributed on an "AS IS" BASIS, 20 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 21 | // See the License for the specific language governing permissions and 22 | // limitations under the License. 23 | // 24 | 25 | package bit 26 | 27 | import ( 28 | proto "github.com/golang/protobuf/proto" 29 | _ "github.com/factset/go-drill/internal/rpc/proto/exec" 30 | shared "github.com/factset/go-drill/internal/rpc/proto/exec/shared" 31 | protoreflect "google.golang.org/protobuf/reflect/protoreflect" 32 | protoimpl "google.golang.org/protobuf/runtime/protoimpl" 33 | reflect "reflect" 34 | sync "sync" 35 | ) 36 | 37 | const ( 38 | // Verify that this generated code is sufficiently up-to-date. 39 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) 40 | // Verify that runtime/protoimpl is sufficiently up-to-date. 41 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) 42 | ) 43 | 44 | // This is a compile-time assertion that a sufficiently up-to-date version 45 | // of the legacy proto package is being used. 
46 | const _ = proto.ProtoPackageIsVersion4 47 | 48 | type FragmentHandle struct { 49 | state protoimpl.MessageState 50 | sizeCache protoimpl.SizeCache 51 | unknownFields protoimpl.UnknownFields 52 | 53 | QueryId *shared.QueryId `protobuf:"bytes,1,opt,name=query_id,json=queryId" json:"query_id,omitempty"` 54 | MajorFragmentId *int32 `protobuf:"varint,2,opt,name=major_fragment_id,json=majorFragmentId" json:"major_fragment_id,omitempty"` 55 | MinorFragmentId *int32 `protobuf:"varint,3,opt,name=minor_fragment_id,json=minorFragmentId" json:"minor_fragment_id,omitempty"` 56 | ParentQueryId *shared.QueryId `protobuf:"bytes,4,opt,name=parent_query_id,json=parentQueryId" json:"parent_query_id,omitempty"` 57 | } 58 | 59 | func (x *FragmentHandle) Reset() { 60 | *x = FragmentHandle{} 61 | if protoimpl.UnsafeEnabled { 62 | mi := &file_ExecutionProtos_proto_msgTypes[0] 63 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 64 | ms.StoreMessageInfo(mi) 65 | } 66 | } 67 | 68 | func (x *FragmentHandle) String() string { 69 | return protoimpl.X.MessageStringOf(x) 70 | } 71 | 72 | func (*FragmentHandle) ProtoMessage() {} 73 | 74 | func (x *FragmentHandle) ProtoReflect() protoreflect.Message { 75 | mi := &file_ExecutionProtos_proto_msgTypes[0] 76 | if protoimpl.UnsafeEnabled && x != nil { 77 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 78 | if ms.LoadMessageInfo() == nil { 79 | ms.StoreMessageInfo(mi) 80 | } 81 | return ms 82 | } 83 | return mi.MessageOf(x) 84 | } 85 | 86 | // Deprecated: Use FragmentHandle.ProtoReflect.Descriptor instead. 87 | func (*FragmentHandle) Descriptor() ([]byte, []int) { 88 | return file_ExecutionProtos_proto_rawDescGZIP(), []int{0} 89 | } 90 | 91 | func (x *FragmentHandle) GetQueryId() *shared.QueryId { 92 | if x != nil { 93 | return x.QueryId 94 | } 95 | return nil 96 | } 97 | 98 | func (x *FragmentHandle) GetMajorFragmentId() int32 { 99 | if x != nil && x.MajorFragmentId != nil { 100 | return *x.MajorFragmentId 101 | } 102 | return 0 103 | } 104 | 105 | func (x *FragmentHandle) GetMinorFragmentId() int32 { 106 | if x != nil && x.MinorFragmentId != nil { 107 | return *x.MinorFragmentId 108 | } 109 | return 0 110 | } 111 | 112 | func (x *FragmentHandle) GetParentQueryId() *shared.QueryId { 113 | if x != nil { 114 | return x.ParentQueryId 115 | } 116 | return nil 117 | } 118 | 119 | // 120 | // Prepared statement state on server side. Clients do not 121 | // need to know the contents. They just need to submit it back to 122 | // server when executing the prepared statement. 
123 | type ServerPreparedStatementState struct { 124 | state protoimpl.MessageState 125 | sizeCache protoimpl.SizeCache 126 | unknownFields protoimpl.UnknownFields 127 | 128 | SqlQuery *string `protobuf:"bytes,1,opt,name=sql_query,json=sqlQuery" json:"sql_query,omitempty"` 129 | } 130 | 131 | func (x *ServerPreparedStatementState) Reset() { 132 | *x = ServerPreparedStatementState{} 133 | if protoimpl.UnsafeEnabled { 134 | mi := &file_ExecutionProtos_proto_msgTypes[1] 135 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 136 | ms.StoreMessageInfo(mi) 137 | } 138 | } 139 | 140 | func (x *ServerPreparedStatementState) String() string { 141 | return protoimpl.X.MessageStringOf(x) 142 | } 143 | 144 | func (*ServerPreparedStatementState) ProtoMessage() {} 145 | 146 | func (x *ServerPreparedStatementState) ProtoReflect() protoreflect.Message { 147 | mi := &file_ExecutionProtos_proto_msgTypes[1] 148 | if protoimpl.UnsafeEnabled && x != nil { 149 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 150 | if ms.LoadMessageInfo() == nil { 151 | ms.StoreMessageInfo(mi) 152 | } 153 | return ms 154 | } 155 | return mi.MessageOf(x) 156 | } 157 | 158 | // Deprecated: Use ServerPreparedStatementState.ProtoReflect.Descriptor instead. 159 | func (*ServerPreparedStatementState) Descriptor() ([]byte, []int) { 160 | return file_ExecutionProtos_proto_rawDescGZIP(), []int{1} 161 | } 162 | 163 | func (x *ServerPreparedStatementState) GetSqlQuery() string { 164 | if x != nil && x.SqlQuery != nil { 165 | return *x.SqlQuery 166 | } 167 | return "" 168 | } 169 | 170 | var File_ExecutionProtos_proto protoreflect.FileDescriptor 171 | 172 | var file_ExecutionProtos_proto_rawDesc = []byte{ 173 | 0x0a, 0x15, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 174 | 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x65, 0x78, 0x65, 0x63, 0x2e, 0x62, 0x69, 175 | 0x74, 0x1a, 0x12, 0x43, 0x6f, 0x6f, 0x72, 0x64, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 176 | 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x13, 0x55, 0x73, 0x65, 0x72, 0x42, 0x69, 0x74, 0x53, 0x68, 177 | 0x61, 0x72, 0x65, 0x64, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xd7, 0x01, 0x0a, 0x0e, 0x46, 178 | 0x72, 0x61, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x48, 0x61, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x2f, 0x0a, 179 | 0x08, 0x71, 0x75, 0x65, 0x72, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 180 | 0x14, 0x2e, 0x65, 0x78, 0x65, 0x63, 0x2e, 0x73, 0x68, 0x61, 0x72, 0x65, 0x64, 0x2e, 0x51, 0x75, 181 | 0x65, 0x72, 0x79, 0x49, 0x64, 0x52, 0x07, 0x71, 0x75, 0x65, 0x72, 0x79, 0x49, 0x64, 0x12, 0x2a, 182 | 0x0a, 0x11, 0x6d, 0x61, 0x6a, 0x6f, 0x72, 0x5f, 0x66, 0x72, 0x61, 0x67, 0x6d, 0x65, 0x6e, 0x74, 183 | 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0f, 0x6d, 0x61, 0x6a, 0x6f, 0x72, 184 | 0x46, 0x72, 0x61, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x2a, 0x0a, 0x11, 0x6d, 0x69, 185 | 0x6e, 0x6f, 0x72, 0x5f, 0x66, 0x72, 0x61, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 186 | 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0f, 0x6d, 0x69, 0x6e, 0x6f, 0x72, 0x46, 0x72, 0x61, 0x67, 187 | 0x6d, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x3c, 0x0a, 0x0f, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 188 | 0x5f, 0x71, 0x75, 0x65, 0x72, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 189 | 0x14, 0x2e, 0x65, 0x78, 0x65, 0x63, 0x2e, 0x73, 0x68, 0x61, 0x72, 0x65, 0x64, 0x2e, 0x51, 0x75, 190 | 0x65, 0x72, 0x79, 0x49, 0x64, 0x52, 0x0d, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x51, 0x75, 0x65, 191 | 0x72, 0x79, 
0x49, 0x64, 0x22, 0x3b, 0x0a, 0x1c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x50, 0x72, 192 | 0x65, 0x70, 0x61, 0x72, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 193 | 0x74, 0x61, 0x74, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x71, 0x6c, 0x5f, 0x71, 0x75, 0x65, 0x72, 194 | 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x71, 0x6c, 0x51, 0x75, 0x65, 0x72, 195 | 0x79, 0x42, 0x66, 0x0a, 0x1b, 0x6f, 0x72, 0x67, 0x2e, 0x61, 0x70, 0x61, 0x63, 0x68, 0x65, 0x2e, 196 | 0x64, 0x72, 0x69, 0x6c, 0x6c, 0x2e, 0x65, 0x78, 0x65, 0x63, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 197 | 0x42, 0x0a, 0x45, 0x78, 0x65, 0x63, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x48, 0x01, 0x5a, 0x39, 198 | 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x7a, 0x65, 0x72, 0x6f, 0x73, 199 | 0x68, 0x61, 0x64, 0x65, 0x2f, 0x67, 0x6f, 0x2d, 0x64, 0x72, 0x69, 0x6c, 0x6c, 0x2f, 0x69, 0x6e, 200 | 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 201 | 0x2f, 0x65, 0x78, 0x65, 0x63, 0x2f, 0x62, 0x69, 0x74, 202 | } 203 | 204 | var ( 205 | file_ExecutionProtos_proto_rawDescOnce sync.Once 206 | file_ExecutionProtos_proto_rawDescData = file_ExecutionProtos_proto_rawDesc 207 | ) 208 | 209 | func file_ExecutionProtos_proto_rawDescGZIP() []byte { 210 | file_ExecutionProtos_proto_rawDescOnce.Do(func() { 211 | file_ExecutionProtos_proto_rawDescData = protoimpl.X.CompressGZIP(file_ExecutionProtos_proto_rawDescData) 212 | }) 213 | return file_ExecutionProtos_proto_rawDescData 214 | } 215 | 216 | var file_ExecutionProtos_proto_msgTypes = make([]protoimpl.MessageInfo, 2) 217 | var file_ExecutionProtos_proto_goTypes = []interface{}{ 218 | (*FragmentHandle)(nil), // 0: exec.bit.FragmentHandle 219 | (*ServerPreparedStatementState)(nil), // 1: exec.bit.ServerPreparedStatementState 220 | (*shared.QueryId)(nil), // 2: exec.shared.QueryId 221 | } 222 | var file_ExecutionProtos_proto_depIdxs = []int32{ 223 | 2, // 0: exec.bit.FragmentHandle.query_id:type_name -> exec.shared.QueryId 224 | 2, // 1: exec.bit.FragmentHandle.parent_query_id:type_name -> exec.shared.QueryId 225 | 2, // [2:2] is the sub-list for method output_type 226 | 2, // [2:2] is the sub-list for method input_type 227 | 2, // [2:2] is the sub-list for extension type_name 228 | 2, // [2:2] is the sub-list for extension extendee 229 | 0, // [0:2] is the sub-list for field type_name 230 | } 231 | 232 | func init() { file_ExecutionProtos_proto_init() } 233 | func file_ExecutionProtos_proto_init() { 234 | if File_ExecutionProtos_proto != nil { 235 | return 236 | } 237 | if !protoimpl.UnsafeEnabled { 238 | file_ExecutionProtos_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { 239 | switch v := v.(*FragmentHandle); i { 240 | case 0: 241 | return &v.state 242 | case 1: 243 | return &v.sizeCache 244 | case 2: 245 | return &v.unknownFields 246 | default: 247 | return nil 248 | } 249 | } 250 | file_ExecutionProtos_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { 251 | switch v := v.(*ServerPreparedStatementState); i { 252 | case 0: 253 | return &v.state 254 | case 1: 255 | return &v.sizeCache 256 | case 2: 257 | return &v.unknownFields 258 | default: 259 | return nil 260 | } 261 | } 262 | } 263 | type x struct{} 264 | out := protoimpl.TypeBuilder{ 265 | File: protoimpl.DescBuilder{ 266 | GoPackagePath: reflect.TypeOf(x{}).PkgPath(), 267 | RawDescriptor: file_ExecutionProtos_proto_rawDesc, 268 | NumEnums: 0, 269 | NumMessages: 2, 270 | NumExtensions: 0, 271 | NumServices: 0, 
272 | }, 273 | GoTypes: file_ExecutionProtos_proto_goTypes, 274 | DependencyIndexes: file_ExecutionProtos_proto_depIdxs, 275 | MessageInfos: file_ExecutionProtos_proto_msgTypes, 276 | }.Build() 277 | File_ExecutionProtos_proto = out.File 278 | file_ExecutionProtos_proto_rawDesc = nil 279 | file_ExecutionProtos_proto_goTypes = nil 280 | file_ExecutionProtos_proto_depIdxs = nil 281 | } 282 | -------------------------------------------------------------------------------- /auth_test.go: -------------------------------------------------------------------------------- 1 | package drill 2 | 3 | import ( 4 | "io" 5 | "math" 6 | "net" 7 | "testing" 8 | "time" 9 | 10 | "github.com/factset/go-drill/internal/rpc/proto/exec" 11 | "github.com/factset/go-drill/internal/rpc/proto/exec/rpc" 12 | "github.com/factset/go-drill/internal/rpc/proto/exec/shared" 13 | "github.com/factset/go-drill/internal/rpc/proto/exec/user" 14 | "github.com/factset/go-drill/sasl" 15 | "github.com/jcmturner/gokrb5/v8/gssapi" 16 | "github.com/stretchr/testify/assert" 17 | "github.com/stretchr/testify/mock" 18 | "google.golang.org/protobuf/proto" 19 | ) 20 | 21 | type mockConn struct { 22 | r io.Reader 23 | mock.Mock 24 | } 25 | 26 | func (m *mockConn) Close() error { 27 | return m.Called().Error(0) 28 | } 29 | 30 | func (m *mockConn) LocalAddr() net.Addr { return nil } 31 | func (m *mockConn) RemoteAddr() net.Addr { return nil } 32 | func (m *mockConn) SetDeadline(_ time.Time) error { return nil } 33 | func (m *mockConn) SetReadDeadline(_ time.Time) error { return nil } 34 | func (m *mockConn) SetWriteDeadline(_ time.Time) error { return nil } 35 | 36 | func (m *mockConn) Read(b []byte) (int, error) { 37 | m.Called() 38 | return m.r.Read(b) 39 | } 40 | 41 | func (m *mockConn) Write(b []byte) (int, error) { 42 | args := m.Called(b) 43 | return args.Int(0), args.Error(1) 44 | } 45 | 46 | type mockEncoder struct { 47 | mock.Mock 48 | } 49 | 50 | func (m *mockEncoder) WriteRaw(_ net.Conn, b []byte) (int, error) { 51 | args := m.Called(b) 52 | return args.Int(0), args.Error(1) 53 | } 54 | 55 | func (m *mockEncoder) Write(_ net.Conn, mode rpc.RpcMode, typ user.RpcType, coord int32, msg proto.Message) (int, error) { 56 | val, _ := proto.Marshal(msg) 57 | args := m.Called(mode, typ, coord, val) 58 | return args.Int(0), args.Error(1) 59 | } 60 | 61 | func (m *mockEncoder) ReadMsg(_ net.Conn, msg proto.Message) (*rpc.RpcHeader, error) { 62 | args := m.Called(msg) 63 | return args.Get(0).(*rpc.RpcHeader), args.Error(1) 64 | } 65 | 66 | func (m *mockEncoder) ReadRaw(net.Conn) (*rpc.CompleteRpcMessage, error) { 67 | args := m.Called() 68 | return args.Get(0).(*rpc.CompleteRpcMessage), args.Error(1) 69 | } 70 | 71 | var initialUserToBit = []byte{0x8, 0x2, 0x10, 0x1, 0x18, 0x5, 0x22, 0x2, 0xa, 0x0, 0x2a, 0x1a, 0xa, 0xa, 0xa, 0x6, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x12, 0x0, 0xa, 0xc, 0xa, 0x8, 0x75, 0x73, 0x65, 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x0, 0x30, 0x0, 0x38, 0x1, 0x42, 0x2c, 0xa, 0x1a, 0x41, 0x70, 0x61, 0x63, 0x68, 0x65, 0x20, 0x44, 0x72, 0x69, 0x6c, 0x6c, 0x20, 0x47, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x20, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x12, 0x6, 0x31, 0x2e, 0x31, 0x37, 0x2e, 0x30, 0x18, 0x1, 0x20, 0x11, 0x28, 0x0, 0x32, 0x0, 0x48, 0x2} 72 | 73 | func TestClientDoHandshake(t *testing.T) { 74 | tests := []struct { 75 | name string 76 | opts Options 77 | status *user.HandshakeStatus 78 | err bool 79 | errmsg string 80 | }{ 81 | {"successful", Options{}, user.HandshakeStatus_SUCCESS.Enum(), false, ""}, 82 | {"rpc mismatch", 
Options{}, user.HandshakeStatus_RPC_VERSION_MISMATCH.Enum(), true, "invalid rpc version, expected: 5, actual: 10"}, 83 | {"auth fail", Options{}, user.HandshakeStatus_AUTH_FAILED.Enum(), true, "authentication failure"}, 84 | {"unknown failure", Options{}, user.HandshakeStatus_UNKNOWN_FAILURE.Enum(), true, "unknown handshake failure"}, 85 | {"invalid security", Options{SaslEncrypt: true}, user.HandshakeStatus_SUCCESS.Enum(), true, "invalid security options"}, 86 | {"client auth, not server", Options{Auth: "booya"}, user.HandshakeStatus_SUCCESS.Enum(), true, "client wanted auth, but server didn't require it"}, 87 | {"calls handle auth", Options{Auth: "booya"}, user.HandshakeStatus_AUTH_REQUIRED.Enum(), true, "client wants encryption, server doesn't support encryption"}, 88 | } 89 | 90 | for _, tt := range tests { 91 | t.Run(tt.name, func(t *testing.T) { 92 | m := new(mockEncoder) 93 | m.Test(t) 94 | defer m.AssertExpectations(t) 95 | 96 | m.On("Write", rpc.RpcMode_REQUEST, user.RpcType_HANDSHAKE, int32(1), initialUserToBit).Return(0, nil) 97 | m.On("ReadMsg", mock.AnythingOfType("*user.BitToUserHandshake")).Return((*rpc.RpcHeader)(nil), nil).Run(func(args mock.Arguments) { 98 | info := args.Get(0).(*user.BitToUserHandshake) 99 | info.Status = tt.status 100 | info.RpcVersion = proto.Int32(10) 101 | }) 102 | 103 | cl := Client{dataEncoder: m, coordID: 1, Opts: tt.opts} 104 | err := cl.doHandshake() 105 | if !tt.err { 106 | assert.NoError(t, err) 107 | return 108 | } 109 | 110 | assert.EqualError(t, err, tt.errmsg) 111 | }) 112 | } 113 | } 114 | 115 | func TestClientHandshakeWriteFailure(t *testing.T) { 116 | m := new(mockEncoder) 117 | m.Test(t) 118 | defer m.AssertExpectations(t) 119 | 120 | m.On("Write", rpc.RpcMode_REQUEST, user.RpcType_HANDSHAKE, int32(1), initialUserToBit).Return(0, assert.AnError) 121 | cl := Client{dataEncoder: m, coordID: 1} 122 | assert.Same(t, assert.AnError, cl.doHandshake()) 123 | } 124 | 125 | func TestClientHandshakeReadFailure(t *testing.T) { 126 | m := new(mockEncoder) 127 | m.Test(t) 128 | defer m.AssertExpectations(t) 129 | 130 | m.On("Write", rpc.RpcMode_REQUEST, user.RpcType_HANDSHAKE, int32(1), initialUserToBit).Return(0, nil) 131 | m.On("ReadMsg", mock.AnythingOfType("*user.BitToUserHandshake")).Return((*rpc.RpcHeader)(nil), assert.AnError) 132 | 133 | cl := Client{dataEncoder: m, coordID: 1} 134 | assert.Same(t, assert.AnError, cl.doHandshake()) 135 | } 136 | 137 | type mockWrapper struct { 138 | mock.Mock 139 | } 140 | 141 | func (m *mockWrapper) InitAuthPayload() ([]byte, error) { 142 | args := m.Called() 143 | return args.Get(0).([]byte), args.Error(1) 144 | } 145 | 146 | func (m *mockWrapper) Step(b []byte) ([]byte, gssapi.Status) { 147 | args := m.Called(b) 148 | return args.Get(0).([]byte), args.Get(1).(gssapi.Status) 149 | } 150 | 151 | func (m *mockWrapper) GetWrappedConn(c net.Conn) net.Conn { 152 | return m.Called(c).Get(0).(net.Conn) 153 | } 154 | 155 | func TestClientHandleAuth(t *testing.T) { 156 | defer func(orig func(string, string, sasl.SecurityProps) (sasl.Wrapper, error)) { 157 | createSasl = orig 158 | }(createSasl) 159 | 160 | opts := Options{ 161 | ServiceHost: "hoster", 162 | User: "edelgard", 163 | ServiceName: "fire emblem", 164 | Auth: "kerberos", 165 | SaslEncrypt: true, 166 | } 167 | 168 | hostopts := Options{ 169 | ServiceHost: "_HOST", 170 | Auth: "kerberos", 171 | User: "kirby", 172 | ServiceName: "superstar", 173 | } 174 | 175 | serverInfo := &user.BitToUserHandshake{ 176 | MaxWrappedSize: proto.Int32(6555), 177 | 
Encrypted: proto.Bool(true), 178 | } 179 | 180 | tests := []struct { 181 | name string 182 | opts Options 183 | saslHost string 184 | sinfo *user.BitToUserHandshake 185 | errWhere string 186 | }{ 187 | {"successful test", opts, "fire emblem/hoster", serverInfo, ""}, 188 | {"check _HOST", hostopts, "superstar/adder.com", serverInfo, ""}, 189 | {"createSasl errors", opts, "fire emblem/hoster", serverInfo, "createSasl"}, 190 | {"InitAuthPayload error", opts, "fire emblem/hoster", serverInfo, "initauth"}, 191 | {"read start fail", opts, "fire emblem/hoster", serverInfo, "saslStart"}, 192 | {"step status error", opts, "fire emblem/hoster", serverInfo, "stepStatus"}, 193 | {"sasl read error", opts, "fire emblem/hoster", serverInfo, "saslRead"}, 194 | } 195 | 196 | for _, tt := range tests { 197 | t.Run(tt.name, func(t *testing.T) { 198 | wrapper := new(mockWrapper) 199 | wrapper.Test(t) 200 | 201 | enc := new(mockEncoder) 202 | enc.Test(t) 203 | if tt.errWhere == "" { 204 | defer wrapper.AssertExpectations(t) 205 | defer enc.AssertExpectations(t) 206 | } 207 | 208 | createSasl = func(user, service string, props sasl.SecurityProps) (sasl.Wrapper, error) { 209 | assert.Equal(t, tt.opts.User, user) 210 | assert.Equal(t, tt.saslHost, service) 211 | assert.Equal(t, sasl.SecurityProps{ 212 | MinSsf: 56, 213 | MaxSsf: math.MaxUint32, 214 | MaxBufSize: *tt.sinfo.MaxWrappedSize, 215 | UseEncryption: *tt.sinfo.Encrypted, 216 | }, props) 217 | 218 | if tt.errWhere == "createSasl" { 219 | return nil, assert.AnError 220 | } 221 | 222 | return wrapper, nil 223 | } 224 | 225 | cl := Client{ 226 | dataEncoder: enc, 227 | Opts: tt.opts, 228 | serverInfo: tt.sinfo, 229 | endpoint: &exec.DrillbitEndpoint{Address: proto.String("adder.com")}, 230 | } 231 | 232 | if tt.errWhere == "initauth" { 233 | wrapper.On("InitAuthPayload").Return([]byte{}, assert.AnError) 234 | } else { 235 | // first we'll get the initialization auth payload 236 | wrapper.On("InitAuthPayload").Return(deadbeef, nil).Once() 237 | } 238 | 239 | // then we write that same payload to the socket, wrapped with a SASL_START message 240 | enc.On("Write", rpc.RpcMode_REQUEST, user.RpcType_SASL_MESSAGE, int32(0), 241 | []byte{0xa, 0x8, 0x6b, 0x65, 0x72, 0x62, 0x65, 0x72, 0x6f, 0x73, 0x12, 0x4, 0xde, 0xad, 0xbe, 0xef, 0x18, 0x1}).Return(0, nil).Once() 242 | 243 | // then we receive a response which we're gonna pass the Data value to Step 244 | call := enc.On("ReadMsg", mock.AnythingOfType("*shared.SaslMessage")).Run(func(args mock.Arguments) { 245 | msg := args.Get(0).(*shared.SaslMessage) 246 | msg.Status = shared.SaslStatus_SASL_IN_PROGRESS.Enum() 247 | msg.Data = deadbeef 248 | }).Once() 249 | if tt.errWhere == "saslStart" { 250 | call.Return((*rpc.RpcHeader)(nil), assert.AnError) 251 | } else { 252 | call.Return((*rpc.RpcHeader)(nil), nil) 253 | } 254 | 255 | if tt.errWhere == "stepStatus" { 256 | wrapper.On("Step", deadbeef).Return(append(deadbeef, deadbeef...), gssapi.Status{Code: gssapi.StatusFailure}).Once() 257 | } else { 258 | // Step the wrapper to get the next token to return 259 | wrapper.On("Step", deadbeef).Return(append(deadbeef, deadbeef...), gssapi.Status{Code: gssapi.StatusContinueNeeded}).Once() 260 | } 261 | // Write that request wrapped correctly again 262 | enc.On("Write", rpc.RpcMode_REQUEST, user.RpcType_SASL_MESSAGE, int32(1), 263 | []byte{0x12, 0x8, 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, 0x18, 0x2}).Return(10, nil).Once() 264 | 265 | // this response has a different payload to ensure we use the new payload in the next 
step 266 | enc.On("ReadMsg", mock.AnythingOfType("*shared.SaslMessage")).Return((*rpc.RpcHeader)(nil), nil).Run(func(args mock.Arguments) { 267 | msg := args.Get(0).(*shared.SaslMessage) 268 | msg.Status = shared.SaslStatus_SASL_IN_PROGRESS.Enum() 269 | msg.Data = append(deadbeef, deadbeef...) 270 | }).Once() 271 | 272 | // the step gets the new payload and returns that we've completed our auth 273 | wrapper.On("Step", append(deadbeef, deadbeef...)).Return([]byte{}, gssapi.Status{Code: gssapi.StatusComplete}).Once() 274 | // write the auth complete message with SASL_SUCCESS 275 | enc.On("Write", rpc.RpcMode_REQUEST, user.RpcType_SASL_MESSAGE, int32(2), []byte{0x12, 0x0, 0x018, 0x3}).Return(0, nil).Once() 276 | 277 | if tt.errWhere == "saslRead" { 278 | enc.On("ReadMsg", mock.AnythingOfType("*shared.SaslMessage")).Return((*rpc.RpcHeader)(nil), assert.AnError) 279 | } else { 280 | // read the confirmation from the service which has SASL_SUCCESS which breaks the loop 281 | enc.On("ReadMsg", mock.AnythingOfType("*shared.SaslMessage")).Return((*rpc.RpcHeader)(nil), nil).Run(func(args mock.Arguments) { 282 | msg := args.Get(0).(*shared.SaslMessage) 283 | msg.Status = shared.SaslStatus_SASL_SUCCESS.Enum() 284 | }).Once() 285 | } 286 | 287 | m := new(mockConn) 288 | // make sure we are wrapping the connection in the client 289 | wrapper.On("GetWrappedConn", nil).Return(m) 290 | 291 | err := cl.handleAuth() 292 | if tt.errWhere == "" { 293 | assert.NoError(t, err) 294 | assert.Same(t, m, cl.conn) 295 | } else { 296 | if tt.errWhere == "stepStatus" { 297 | assert.EqualError(t, err, gssapi.Status{Code: gssapi.StatusFailure}.Error()) 298 | } else { 299 | assert.Same(t, assert.AnError, err) 300 | } 301 | assert.Nil(t, cl.conn) 302 | } 303 | }) 304 | } 305 | } 306 | --------------------------------------------------------------------------------
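The generated typemap tests above only exercise the numeric vectors. As a quick illustration of the same `NewValueVec` factory path for the date/time types in `date_time_vectors.go`, the sketch below builds a required TIME vector from one little-endian int32 milliseconds-since-midnight value and reads it back. This is not part of the repository: it is a hypothetical example-style test that would have to live alongside the package's own tests (the `data` package sits under `internal/`, so it cannot be imported from outside the module), and it assumes `NewValueVec` falls through its numeric typemap for TIME and that the generated `Int32Vector.Get` indexes the raw buffer directly, as the decode logic shown above implies.

```go
package data_test

import (
	"encoding/binary"
	"fmt"

	"github.com/factset/go-drill/internal/data"
	"github.com/factset/go-drill/internal/rpc/proto/common"
	"github.com/factset/go-drill/internal/rpc/proto/exec/shared"
	"google.golang.org/protobuf/proto"
)

// Hypothetical example: decode one required TIME value (13:45:30) from the
// little-endian millis-since-midnight layout that TimeVector.Get expects.
func ExampleNewValueVec_time() {
	millis := uint32((13*3600 + 45*60 + 30) * 1000)

	raw := make([]byte, 4)
	binary.LittleEndian.PutUint32(raw, millis)

	meta := &shared.SerializedField{
		ValueCount: proto.Int32(1),
		MajorType: &common.MajorType{
			MinorType: common.MinorType_TIME.Enum(),
			Mode:      common.DataMode_REQUIRED.Enum(),
		},
	}

	// For the TIME minor type, NewValueVec should dispatch to NewTimeVector.
	vec := data.NewValueVec(raw, meta)
	tv, ok := vec.(*data.TimeVector)
	if !ok {
		fmt.Println("unexpected vector type")
		return
	}

	fmt.Println(tv.Get(0).Format("15:04:05"))
	// Output: 13:45:30
}
```

Because `TimeVector.Get` keeps only the clock portion (year 0, January 1), the example formats just the time-of-day; the same construction pattern with `MinorType_TIMESTAMP` or `MinorType_DATE` and eight-byte values would exercise the timestamp and date paths instead.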