├── .gitattributes ├── LICENSE ├── Makefile ├── README.md ├── api.go ├── conn.go ├── db_test.go ├── driver.go ├── go.mod ├── go.sum ├── rows.go ├── rows_test.go └── value.go /.gitattributes: -------------------------------------------------------------------------------- 1 | go.sum linguist-generated=true 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Segment.io, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test 2 | test: 3 | go test ./... 4 | 5 | .PHONY: fmt 6 | fmt: 7 | go fmt ./... 8 | 9 | .PHONY: vet 10 | vet: 11 | go vet ./... 
12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Go Reference](https://pkg.go.dev/badge/github.com/segmentio/go-athena.svg)](https://pkg.go.dev/github.com/segmentio/go-athena) 2 | # go-athena 3 | 4 | go-athena is a simple Golang [database/sql] driver for [Amazon Athena](https://aws.amazon.com/athena/). 5 | 6 | ```go 7 | import ( 8 | "database/sql" 9 | _ "github.com/segmentio/go-athena" 10 | ) 11 | 12 | func main() { 13 | db, _ := sql.Open("athena", "db=default&output_location=s3://results") 14 | rows, _ := db.Query("SELECT url, code from cloudfront") 15 | 16 | for rows.Next() { 17 | var url string 18 | var code int 19 | rows.Scan(&url, &code) 20 | } 21 | } 22 | 23 | ``` 24 | 25 | It provides a higher-level, idiomatic wrapper over the 26 | [AWS SDK for Go v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/athena), 27 | comparable to the [Athena JDBC driver](http://docs.aws.amazon.com/athena/latest/ug/athena-jdbc-driver.html) 28 | AWS provides for Java users. 29 | 30 | For example, 31 | 32 | - Instead of manually parsing types from strings, you can use [database/sql.Rows.Scan()](https://golang.org/pkg/database/sql/#Rows.Scan) 33 | - Instead of reaching for semaphores, you can use [database/sql.DB.SetMaxOpenConns](https://golang.org/pkg/database/sql/#DB.SetMaxOpenConns) 34 | - And so on. 35 | 36 | 37 | ## Caveats 38 | 39 | [database/sql] exposes lots of methods that aren't supported in Athena. 40 | For example, Athena doesn't support transactions, so `Begin()` is irrelevant. 41 | If a method must be supplied to satisfy a standard library interface but is unsupported, 42 | the driver will **panic** with a message saying so. If there are new offerings in Athena and/or 43 | helpful additions, feel free to PR.
44 | 45 | 46 | ## Testing 47 | 48 | Athena doesn't have a local version and revolves around S3 so our tests are 49 | integration tests against AWS itself. Thus, our tests require AWS credentials. 50 | The simplest way to provide them is via `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` 51 | environment variables, but you can use anything supported by the 52 | [Default Credential Provider Chain]. 53 | 54 | The tests support a few environment variables: 55 | - `ATHENA_DATABASE` can be used to override the default database "go_athena_tests" 56 | - `S3_BUCKET` can be used to override the default S3 bucket of "go-athena-tests" 57 | 58 | 59 | [database/sql]: https://golang.org/pkg/database/sql/ 60 | [Default Credential Provider Chain]: http://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/credentials.html#credentials-default 61 | -------------------------------------------------------------------------------- /api.go: -------------------------------------------------------------------------------- 1 | package athena 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/aws/aws-sdk-go-v2/service/athena" 7 | ) 8 | 9 | // athenaAPI is the set of Athena operations the driver performs. conn and rows 10 | // reach the service only through it: Driver.Open supplies the SDK's *athena.Client, 11 | // and any other implementation with the same methods (e.g. a test fake) can stand in. 12 | type athenaAPI interface { 13 | GetQueryExecution(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) 14 | GetQueryResults(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) 15 | StartQueryExecution(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) 16 | StopQueryExecution(ctx context.Context, params *athena.StopQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StopQueryExecutionOutput, error) 17 | } 18 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | package athena 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | "errors" 7 |
"time" 8 | 9 | "github.com/aws/aws-sdk-go-v2/aws" 10 | "github.com/aws/aws-sdk-go-v2/service/athena" 11 | "github.com/aws/aws-sdk-go-v2/service/athena/types" 12 | ) 13 | 14 | type conn struct { 15 | athena athenaAPI 16 | db string 17 | OutputLocation string 18 | 19 | pollFrequency time.Duration 20 | } 21 | 22 | func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 23 | if len(args) > 0 { 24 | panic("The go-athena driver doesn't support prepared statements yet. Format your own arguments.") 25 | } 26 | 27 | rows, err := c.runQuery(ctx, query) 28 | return rows, err 29 | } 30 | 31 | func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 32 | if len(args) > 0 { 33 | panic("The go-athena driver doesn't support prepared statements yet. Format your own arguments.") 34 | } 35 | 36 | _, err := c.runQuery(ctx, query) 37 | return nil, err 38 | } 39 | 40 | func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error) { 41 | queryID, err := c.startQuery(ctx, query) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | if err := c.waitOnQuery(ctx, queryID); err != nil { 47 | return nil, err 48 | } 49 | 50 | return newRows(ctx, rowsConfig{ 51 | Athena: c.athena, 52 | QueryID: queryID, 53 | // todo add check for ddl queries to not skip header(#10) 54 | SkipHeader: true, 55 | }) 56 | } 57 | 58 | // startQuery starts an Athena query and returns its ID. 
59 | func (c *conn) startQuery(ctx context.Context, query string) (string, error) { 60 | resp, err := c.athena.StartQueryExecution(ctx, &athena.StartQueryExecutionInput{ 61 | QueryString: aws.String(query), 62 | QueryExecutionContext: &types.QueryExecutionContext{ 63 | Database: aws.String(c.db), 64 | }, 65 | ResultConfiguration: &types.ResultConfiguration{ 66 | OutputLocation: aws.String(c.OutputLocation), 67 | }, 68 | }) 69 | if err != nil { 70 | return "", err 71 | } 72 | 73 | return *resp.QueryExecutionId, nil 74 | } 75 | 76 | // waitOnQuery blocks until a query finishes, returning an error if it failed. 77 | func (c *conn) waitOnQuery(ctx context.Context, queryID string) error { 78 | for { 79 | statusResp, err := c.athena.GetQueryExecution(ctx, &athena.GetQueryExecutionInput{ 80 | QueryExecutionId: aws.String(queryID), 81 | }) 82 | if err != nil { 83 | return err 84 | } 85 | 86 | switch statusResp.QueryExecution.Status.State { 87 | case types.QueryExecutionStateCancelled: 88 | return context.Canceled 89 | case types.QueryExecutionStateFailed: 90 | reason := *statusResp.QueryExecution.Status.StateChangeReason 91 | return errors.New(reason) 92 | case types.QueryExecutionStateSucceeded: 93 | return nil 94 | case types.QueryExecutionStateQueued: 95 | case types.QueryExecutionStateRunning: 96 | } 97 | 98 | select { 99 | case <-ctx.Done(): 100 | c.athena.StopQueryExecution(ctx, &athena.StopQueryExecutionInput{ 101 | QueryExecutionId: aws.String(queryID), 102 | }) 103 | 104 | return ctx.Err() 105 | case <-time.After(c.pollFrequency): 106 | continue 107 | } 108 | } 109 | } 110 | 111 | func (c *conn) Prepare(query string) (driver.Stmt, error) { 112 | panic("The go-athena driver doesn't support prepared statements yet") 113 | } 114 | 115 | func (c *conn) Begin() (driver.Tx, error) { 116 | panic("Athena doesn't support transactions") 117 | } 118 | 119 | func (c *conn) Close() error { 120 | return nil 121 | } 122 | 123 | var _ driver.QueryerContext = (*conn)(nil) 124 | 
var _ driver.ExecerContext = (*conn)(nil) 125 | -------------------------------------------------------------------------------- /db_test.go: -------------------------------------------------------------------------------- 1 | package athena 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "database/sql" 7 | "encoding/json" 8 | "fmt" 9 | "os" 10 | "strings" 11 | "testing" 12 | "time" 13 | 14 | "github.com/aws/aws-sdk-go-v2/aws" 15 | "github.com/aws/aws-sdk-go-v2/config" 16 | "github.com/aws/aws-sdk-go-v2/service/s3" 17 | uuid "github.com/satori/go.uuid" 18 | "github.com/stretchr/testify/assert" 19 | "github.com/stretchr/testify/require" 20 | ) 21 | 22 | var ( 23 | AthenaDatabase = "go_athena_tests" 24 | S3Bucket = "go-athena-tests" 25 | ) 26 | 27 | func init() { 28 | if v := os.Getenv("ATHENA_DATABASE"); v != "" { 29 | AthenaDatabase = v 30 | } 31 | 32 | if v := os.Getenv("S3_BUCKET"); v != "" { 33 | S3Bucket = v 34 | } 35 | } 36 | 37 | func TestQuery(t *testing.T) { 38 | ctx := context.Background() 39 | harness := setup(ctx, t) 40 | defer harness.teardown(ctx) // drop the per-run table so repeated runs don't accumulate tables 41 | 42 | expected := []dummyRow{ 43 | { 44 | SmallintType: 1, 45 | IntType: 2, 46 | BigintType: 3, 47 | BooleanType: true, 48 | FloatType: 3.14159, 49 | DoubleType: 1.32112345, 50 | StringType: "some string", 51 | TimestampType: athenaTimestamp(time.Date(2006, 1, 2, 3, 4, 11, 0, time.UTC)), 52 | DateType: athenaDate(time.Date(2006, 1, 2, 0, 0, 0, 0, time.UTC)), 53 | DecimalType: 1001, 54 | }, 55 | { 56 | SmallintType: 9, 57 | IntType: 8, 58 | BigintType: 0, 59 | BooleanType: false, 60 | FloatType: 3.14159, 61 | DoubleType: 1.235, 62 | StringType: "another string", 63 | TimestampType: athenaTimestamp(time.Date(2017, 12, 3, 1, 11, 12, 0, time.UTC)), 64 | DateType: athenaDate(time.Date(2017, 12, 3, 0, 0, 0, 0, time.UTC)), 65 | DecimalType: 0, 66 | }, 67 | { 68 | SmallintType: 9, 69 | IntType: 8, 70 | BigintType: 0, 71 | BooleanType: false, 72 | DoubleType: 1.235, 73 | FloatType: 3.14159, 74 | StringType:
"another string", 75 | TimestampType: athenaTimestamp(time.Date(2017, 12, 3, 20, 11, 12, 0, time.UTC)), 76 | DateType: athenaDate(time.Date(2017, 12, 3, 0, 0, 0, 0, time.UTC)), 77 | DecimalType: 0.48, 78 | }, 79 | } 80 | expectedTypeNames := []string{"varchar", "smallint", "integer", "bigint", "boolean", "float", "double", "varchar", "timestamp", "date", "decimal"} 81 | harness.uploadData(ctx, expected) 82 | 83 | rows := harness.mustQuery(ctx, "select * from %s", harness.table) 84 | defer rows.Close() 85 | index := -1 86 | for rows.Next() { 87 | index++ 88 | var row dummyRow 89 | require.NoError(t, rows.Scan( 90 | &row.NullValue, 91 | 92 | &row.SmallintType, 93 | &row.IntType, 94 | &row.BigintType, 95 | &row.BooleanType, 96 | &row.FloatType, 97 | &row.DoubleType, 98 | &row.StringType, 99 | &row.TimestampType, 100 | &row.DateType, 101 | &row.DecimalType, 102 | )) 103 | 104 | assert.Equal(t, expected[index], row, fmt.Sprintf("index: %d", index)) 105 | 106 | types, err := rows.ColumnTypes() 107 | assert.NoError(t, err, fmt.Sprintf("index: %d", index)) 108 | for i, colType := range types { 109 | typeName := colType.DatabaseTypeName() 110 | assert.Equal(t, expectedTypeNames[i], typeName, fmt.Sprintf("index: %d", index)) 111 | } 112 | } 113 | 114 | require.NoError(t, rows.Err(), "rows.Err()") 115 | require.Equal(t, 3, index+1, "row count") 116 | } 117 | 118 | func TestOpen(t *testing.T) { 119 | awsConfig, err := config.LoadDefaultConfig(context.Background()) 120 | require.NoError(t, err, "LoadDefaultConfig") 121 | db, err := Open(DriverConfig{ 122 | Config: &awsConfig, 123 | Database: AthenaDatabase, 124 | OutputLocation: fmt.Sprintf("s3://%s/noop", S3Bucket), 125 | }) 126 | require.NoError(t, err, "Open") 127 | rows, err := db.Query("SELECT 1") 128 | require.NoError(t, err, "Query") 129 | require.NoError(t, rows.Close(), "rows.Close") // release the connection the open result set holds 130 | } 131 | 132 | type dummyRow struct { 133 | NullValue *struct{} `json:"nullValue"` 134 | SmallintType int `json:"smallintType"` 135 | IntType int `json:"intType"` 136 | BigintType int
`json:"bigintType"` 137 | BooleanType bool `json:"booleanType"` 138 | FloatType float32 `json:"floatType"` 139 | DoubleType float64 `json:"doubleType"` 140 | StringType string `json:"stringType"` 141 | TimestampType athenaTimestamp `json:"timestampType"` 142 | DateType athenaDate `json:"dateType"` 143 | DecimalType float64 `json:"decimalType"` 144 | } 145 | 146 | type athenaHarness struct { 147 | t *testing.T 148 | db *sql.DB 149 | s3 *s3.Client 150 | 151 | table string 152 | } 153 | 154 | func setup(ctx context.Context, t *testing.T) *athenaHarness { 155 | awsConfig, err := config.LoadDefaultConfig(ctx) 156 | require.NoError(t, err) 157 | harness := athenaHarness{t: t, s3: s3.NewFromConfig(awsConfig)} 158 | 159 | harness.db, err = sql.Open("athena", fmt.Sprintf("db=%s&output_location=s3://%s/output", AthenaDatabase, S3Bucket)) 160 | require.NoError(t, err) 161 | 162 | harness.setupTable(ctx) 163 | 164 | return &harness 165 | } 166 | 167 | func (a *athenaHarness) setupTable(ctx context.Context) { 168 | // tables cannot start with numbers or contain dashes 169 | id := uuid.NewV4() 170 | a.table = "t_" + strings.ReplaceAll(id.String(), "-", "_") 171 | a.mustExec(ctx, `CREATE EXTERNAL TABLE %[1]s ( 172 | nullValue string, 173 | smallintType smallint, 174 | intType int, 175 | bigintType bigint, 176 | booleanType boolean, 177 | floatType float, 178 | doubleType double, 179 | stringType string, 180 | timestampType timestamp, 181 | dateType date, 182 | decimalType decimal(11, 5) 183 | ) 184 | ROW FORMAT SERDE 'org.openx.data.jsonserde.JsonSerDe' 185 | WITH SERDEPROPERTIES ( 186 | 'serialization.format' = '1' 187 | ) LOCATION 's3://%[2]s/%[1]s/';`, a.table, S3Bucket) 188 | a.t.Logf("created table: %s", a.table) 189 | } 190 | 191 | func (a *athenaHarness) teardown(ctx context.Context) { 192 | a.mustExec(ctx, "drop table %s", a.table) 193 | } 194 | 195 | func (a *athenaHarness) mustExec(ctx context.Context, sql string, args ...interface{}) { 196 | query :=
fmt.Sprintf(sql, args...) 197 | _, err := a.db.ExecContext(ctx, query) 198 | require.NoError(a.t, err, query) 199 | } 200 | 201 | func (a *athenaHarness) mustQuery(ctx context.Context, sql string, args ...interface{}) *sql.Rows { 202 | query := fmt.Sprintf(sql, args...) 203 | rows, err := a.db.QueryContext(ctx, query) 204 | require.NoError(a.t, err, query) 205 | return rows 206 | } 207 | 208 | func (a *athenaHarness) uploadData(ctx context.Context, rows []dummyRow) { 209 | var buf bytes.Buffer 210 | enc := json.NewEncoder(&buf) 211 | for _, row := range rows { 212 | err := enc.Encode(row) 213 | require.NoError(a.t, err) 214 | } 215 | 216 | _, err := a.s3.PutObject(ctx, &s3.PutObjectInput{ 217 | Bucket: aws.String(S3Bucket), 218 | Key: aws.String(fmt.Sprintf("%s/fixture.json", a.table)), 219 | Body: bytes.NewReader(buf.Bytes()), 220 | }) 221 | require.NoError(a.t, err) 222 | } 223 | 224 | type athenaTimestamp time.Time 225 | 226 | func (t athenaTimestamp) MarshalJSON() ([]byte, error) { 227 | return json.Marshal(t.String()) 228 | } 229 | 230 | func (t athenaTimestamp) String() string { 231 | return time.Time(t).Format(TimestampLayout) 232 | } 233 | 234 | func (t athenaTimestamp) Equal(t2 athenaTimestamp) bool { 235 | return time.Time(t).Equal(time.Time(t2)) 236 | } 237 | 238 | type athenaDate time.Time 239 | 240 | func (t athenaDate) MarshalJSON() ([]byte, error) { 241 | return json.Marshal(t.String()) 242 | } 243 | 244 | func (t athenaDate) String() string { 245 | return time.Time(t).Format(DateLayout) 246 | } 247 | 248 | func (t athenaDate) Equal(t2 athenaDate) bool { 249 | return time.Time(t).Equal(time.Time(t2)) 250 | } 251 | -------------------------------------------------------------------------------- /driver.go: -------------------------------------------------------------------------------- 1 | package athena 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "database/sql/driver" 7 | "errors" 8 | "fmt" 9 | "net/url" 10 | "sync" 11 | "time" 12 | 13 |
"github.com/aws/aws-sdk-go-v2/aws" 14 | "github.com/aws/aws-sdk-go-v2/config" 15 | "github.com/aws/aws-sdk-go-v2/service/athena" 16 | ) 17 | 18 | var ( 19 | openFromSessionMutex sync.Mutex 20 | openFromSessionCount int 21 | ) 22 | 23 | // Driver implements database/sql/driver.Driver. It's intended for db/sql.Open(). 24 | type Driver struct { 25 | cfg *DriverConfig 26 | } 27 | 28 | // NewDriver allows you to register your own driver with `sql.Register`. 29 | // It's useful for more complex use cases. Read more in PR #3. 30 | // https://github.com/segmentio/go-athena/pull/3 31 | // 32 | // Generally, sql.Open() or athena.Open() should suffice. 33 | func NewDriver(cfg *DriverConfig) *Driver { 34 | return &Driver{cfg} 35 | } 36 | 37 | func init() { 38 | var drv driver.Driver = &Driver{} 39 | sql.Register("athena", drv) 40 | } 41 | 42 | // Open should be used via `db/sql.Open("athena", "")`. 43 | // The following parameters are supported in URI query format (k=v&k2=v2&...) 44 | // 45 | // - `db` (required) 46 | // This is the Athena database name. In the UI, this defaults to "default", 47 | // but the driver requires it regardless. 48 | // 49 | // - `output_location` (required) 50 | // This is the S3 location Athena will dump query results in the format 51 | // "s3://bucket/and/so/forth". In the AWS UI, this defaults to 52 | // "s3://aws-athena-query-results-ACCOUNT_ID-REGION", but the driver requires it. 53 | // 54 | // - `poll_frequency` (optional) 55 | // Athena's API requires polling to retrieve query results. This is the frequency at 56 | // which the driver will poll for results. It must be accepted by time.ParseDuration 57 | // (e.g. "500ms", "5s"). A completely arbitrary default of "5s" was chosen; it is also 58 | // used when the value is zero or negative. 59 | // 60 | // - `region` (optional) 61 | // Override AWS region. Useful if it is not set with environment variable. 62 | // 63 | // Credentials must be accessible via the SDK's Default Credential Provider Chain.
64 | // For more advanced AWS credentials/config management, please supply 65 | // a custom aws.Config directly via `athena.Open()`. 66 | func (d *Driver) Open(connStr string) (driver.Conn, error) { 67 | var cfg DriverConfig 68 | if d.cfg != nil { 69 | // Work on a copy: d.cfg is shared by every connection this Driver opens and 70 | // database/sql may call Open from several goroutines, so it must not be mutated. 71 | cfg = *d.cfg 72 | } else { 73 | // TODO: Implement DriverContext to get proper access to context 74 | parsed, err := configFromConnectionString(context.TODO(), connStr) 75 | if err != nil { 76 | return nil, err 77 | } 78 | cfg = *parsed 79 | } 80 | 81 | // Reject a missing AWS config up front instead of panicking on the dereference below. 82 | if cfg.Config == nil { 83 | return nil, errors.New("AWS config is required") 84 | } 85 | 86 | // A non-positive interval would make waitOnQuery poll without pausing. 87 | if cfg.PollFrequency <= 0 { 88 | cfg.PollFrequency = 5 * time.Second 89 | } 90 | 91 | return &conn{ 92 | athena: athena.NewFromConfig(*cfg.Config), 93 | db: cfg.Database, 94 | OutputLocation: cfg.OutputLocation, 95 | pollFrequency: cfg.PollFrequency, 96 | }, nil 97 | } 98 | 99 | // Open is a more robust version of `db.Open`, as it accepts an aws.Config directly. 100 | // This is useful if you have a complex AWS configuration since the driver doesn't 101 | // currently attempt to serialize all options into a string. 102 | func Open(cfg DriverConfig) (*sql.DB, error) { 103 | if cfg.Database == "" { 104 | return nil, errors.New("db is required") 105 | } 106 | 107 | if cfg.OutputLocation == "" { 108 | return nil, errors.New("output_location is required") 109 | } 110 | 111 | if cfg.Config == nil { 112 | return nil, errors.New("AWS config is required") 113 | } 114 | 115 | // This hack was copied from jackc/pgx. Sorry :( 116 | // https://github.com/jackc/pgx/blob/70a284f4f33a9cc28fd1223f6b83fb00deecfe33/stdlib/sql.go#L130-L136 117 | // sql.Register panics on a duplicate name, so every call registers a uniquely numbered driver. 118 | openFromSessionMutex.Lock() 119 | openFromSessionCount++ 120 | name := fmt.Sprintf("athena-%d", openFromSessionCount) 121 | openFromSessionMutex.Unlock() 122 | 123 | sql.Register(name, &Driver{&cfg}) 124 | return sql.Open(name, "") 125 | } 126 | 127 | // DriverConfig is the input to Open() and NewDriver().
128 | type DriverConfig struct { 129 | // Config is the AWS SDK configuration used to build the Athena client. Required. 130 | Config *aws.Config 131 | // Database is the Athena database that queries run in. 132 | Database string 133 | // OutputLocation is the S3 URI (e.g. "s3://bucket/prefix") Athena writes query results to. 134 | OutputLocation string 135 | 136 | // PollFrequency is how often query status is polled; zero or negative selects the 5s default. 137 | PollFrequency time.Duration 138 | } 139 | 140 | // configFromConnectionString parses a URI-query connection string (see Driver.Open for the keys) into a DriverConfig, loading AWS settings with config.LoadDefaultConfig. 141 | func configFromConnectionString(ctx context.Context, connStr string) (*DriverConfig, error) { 142 | args, err := url.ParseQuery(connStr) 143 | if err != nil { 144 | return nil, err 145 | } 146 | 147 | var cfg DriverConfig 148 | 149 | awsConfig, err := config.LoadDefaultConfig(ctx) 150 | if err != nil { 151 | return nil, err 152 | } 153 | if region := args.Get("region"); region != "" { 154 | awsConfig.Region = region 155 | } 156 | cfg.Config = &awsConfig 157 | 158 | cfg.Database = args.Get("db") 159 | cfg.OutputLocation = args.Get("output_location") 160 | 161 | frequencyStr := args.Get("poll_frequency") 162 | if frequencyStr != "" { 163 | cfg.PollFrequency, err = time.ParseDuration(frequencyStr) 164 | if err != nil { 165 | return nil, fmt.Errorf("invalid poll_frequency parameter: %s", frequencyStr) 166 | } 167 | } 168 | 169 | return &cfg, nil 170 | } 171 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/segmentio/go-athena 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/aws/aws-sdk-go-v2 v1.30.4 7 | github.com/aws/aws-sdk-go-v2/config v1.27.30 8 | github.com/aws/aws-sdk-go-v2/service/athena v1.44.5 9 | github.com/aws/aws-sdk-go-v2/service/s3 v1.60.1 10 | github.com/satori/go.uuid v1.2.0 11 | github.com/stretchr/testify v1.9.0 12 | ) 13 | 14 | require ( 15 | github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 // indirect 16 | github.com/aws/aws-sdk-go-v2/credentials v1.17.29 // indirect 17 | github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12 // indirect 18 | github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 // indirect 19 | github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 // indirect 20 |
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect 21 | github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.16 // indirect 22 | github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 // indirect 23 | github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.18 // indirect 24 | github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 // indirect 25 | github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.16 // indirect 26 | github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 // indirect 27 | github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 // indirect 28 | github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 // indirect 29 | github.com/aws/smithy-go v1.20.4 // indirect 30 | github.com/davecgh/go-spew v1.1.1 // indirect 31 | github.com/pmezard/go-difflib v1.0.0 // indirect 32 | gopkg.in/yaml.v3 v3.0.1 // indirect 33 | ) 34 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/aws/aws-sdk-go-v2 v1.30.4 h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8= 2 | github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= 3 | github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 h1:70PVAiL15/aBMh5LThwgXdSQorVr91L127ttckI9QQU= 4 | github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4/go.mod h1:/MQxMqci8tlqDH+pjmoLu1i0tbWCUP1hhyMRuFxpQCw= 5 | github.com/aws/aws-sdk-go-v2/config v1.27.30 h1:AQF3/+rOgeJBQP3iI4vojlPib5X6eeOYoa/af7OxAYg= 6 | github.com/aws/aws-sdk-go-v2/config v1.27.30/go.mod h1:yxqvuubha9Vw8stEgNiStO+yZpP68Wm9hLmcm+R/Qk4= 7 | github.com/aws/aws-sdk-go-v2/credentials v1.17.29 h1:CwGsupsXIlAFYuDVHv1nnK0wnxO0wZ/g1L8DSK/xiIw= 8 | github.com/aws/aws-sdk-go-v2/credentials v1.17.29/go.mod h1:BPJ/yXV92ZVq6G8uYvbU0gSl8q94UB63nMT5ctNO38g= 9 | github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12 h1:yjwoSyDZF8Jth+mUk5lSPJCkMC0lMy6FaCD51jm6ayE= 10 | 
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12/go.mod h1:fuR57fAgMk7ot3WcNQfb6rSEn+SUffl7ri+aa8uKysI= 11 | github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 h1:TNyt/+X43KJ9IJJMjKfa3bNTiZbUP7DeCxfbTROESwY= 12 | github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16/go.mod h1:2DwJF39FlNAUiX5pAc0UNeiz16lK2t7IaFcm0LFHEgc= 13 | github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 h1:jYfy8UPmd+6kJW5YhY0L1/KftReOGxI/4NtVSTh9O/I= 14 | github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16/go.mod h1:7ZfEPZxkW42Afq4uQB8H2E2e6ebh6mXTueEpYzjCzcs= 15 | github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= 16 | github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= 17 | github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.16 h1:mimdLQkIX1zr8GIPY1ZtALdBQGxcASiBd2MOp8m/dMc= 18 | github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.16/go.mod h1:YHk6owoSwrIsok+cAH9PENCOGoH5PU2EllX4vLtSrsY= 19 | github.com/aws/aws-sdk-go-v2/service/athena v1.44.5 h1:l6fpIrGjYc8zfeBo3QHWxQf3d8TwIxITJXCLOKEhMWw= 20 | github.com/aws/aws-sdk-go-v2/service/athena v1.44.5/go.mod h1:JKpavcrQ83Uy6ntM2pIt0vfVpHR9kvI3dkUeAKQstpc= 21 | github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 h1:KypMCbLPPHEmf9DgMGw51jMj77VfGPAN2Kv4cfhlfgI= 22 | github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4/go.mod h1:Vz1JQXliGcQktFTN/LN6uGppAIRoLBR2bMvIMP0gOjc= 23 | github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.18 h1:GckUnpm4EJOAio1c8o25a+b3lVfwVzC9gnSBqiiNmZM= 24 | github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.18/go.mod h1:Br6+bxfG33Dk3ynmkhsW2Z/t9D4+lRqdLDNCKi85w0U= 25 | github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 h1:tJ5RnkHCiSH0jyd6gROjlJtNwov0eGYNz8s8nFcR0jQ= 26 | github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18/go.mod h1:++NHzT+nAF7ZPrHPsA+ENvsXkOO8wEu+C6RXltAG4/c= 27 | 
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.16 h1:jg16PhLPUiHIj8zYIW6bqzeQSuHVEiWnGA0Brz5Xv2I= 28 | github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.16/go.mod h1:Uyk1zE1VVdsHSU7096h/rwnXDzOzYQVl+FNPhPw7ShY= 29 | github.com/aws/aws-sdk-go-v2/service/s3 v1.60.1 h1:mx2ucgtv+MWzJesJY9Ig/8AFHgoE5FwLXwUVgW/FGdI= 30 | github.com/aws/aws-sdk-go-v2/service/s3 v1.60.1/go.mod h1:BSPI0EfnYUuNHPS0uqIo5VrRwzie+Fp+YhQOUs16sKI= 31 | github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 h1:zCsFCKvbj25i7p1u94imVoO447I/sFv8qq+lGJhRN0c= 32 | github.com/aws/aws-sdk-go-v2/service/sso v1.22.5/go.mod h1:ZeDX1SnKsVlejeuz41GiajjZpRSWR7/42q/EyA/QEiM= 33 | github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 h1:SKvPgvdvmiTWoi0GAJ7AsJfOz3ngVkD/ERbs5pUnHNI= 34 | github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5/go.mod h1:20sz31hv/WsPa3HhU3hfrIet2kxM4Pe0r20eBZ20Tac= 35 | github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 h1:OMsEmCyz2i89XwRwPouAJvhj81wINh+4UK+k/0Yo/q8= 36 | github.com/aws/aws-sdk-go-v2/service/sts v1.30.5/go.mod h1:vmSqFK+BVIwVpDAGZB3CoCXHzurt4qBE8lf+I/kRTh0= 37 | github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= 38 | github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= 39 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 40 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 41 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 42 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 43 | github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= 44 | github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= 45 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 46 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 47 | 
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
48 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
49 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
50 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
51 | 
--------------------------------------------------------------------------------
/rows.go:
--------------------------------------------------------------------------------
1 | package athena
2 | 
3 | import (
4 | 	"context"
5 | 	"database/sql/driver"
6 | 	"io"
7 | 
8 | 	"github.com/aws/aws-sdk-go-v2/aws"
9 | 	"github.com/aws/aws-sdk-go-v2/service/athena"
10 | )
11 | 
12 | // rows implements driver.Rows over the paginated results of a single Athena
13 | // query execution. Pages are fetched on demand as Next drains the current one.
14 | type rows struct {
15 | 	athena  athenaAPI
16 | 	queryID string
17 | 
18 | 	// done is set once the results are exhausted or Close has been called.
19 | 	done bool
20 | 	// skipHeaderRow is true until the first page has been fetched, at which
21 | 	// point that page's first row is dropped as the column header.
22 | 	skipHeaderRow bool
23 | 	// out is the most recently fetched page. Rows already returned by Next
24 | 	// are sliced off the front of out.ResultSet.Rows.
25 | 	out *athena.GetQueryResultsOutput
26 | }
27 | 
28 | // rowsConfig holds the parameters used by newRows.
29 | type rowsConfig struct {
30 | 	Athena     athenaAPI
31 | 	QueryID    string
32 | 	SkipHeader bool
33 | }
34 | 
35 | // newRows fetches the first page of results for cfg.QueryID and returns an
36 | // iterator positioned before the first data row.
37 | func newRows(ctx context.Context, cfg rowsConfig) (*rows, error) {
38 | 	r := rows{
39 | 		athena:        cfg.Athena,
40 | 		queryID:       cfg.QueryID,
41 | 		skipHeaderRow: cfg.SkipHeader,
42 | 	}
43 | 
44 | 	// The first page is fetched eagerly so Columns works before Next. Whether
45 | 	// it holds any data rows does not matter here: Next follows NextToken past
46 | 	// empty pages, such as a first page that contained only the header row.
47 | 	if _, err := r.fetchNextPage(ctx, nil); err != nil {
48 | 		return nil, err
49 | 	}
50 | 
51 | 	return &r, nil
52 | }
53 | 
54 | // Columns implements driver.Rows. It returns the column names reported in the
55 | // result set metadata, using "" for any column whose name is missing.
56 | func (r *rows) Columns() []string {
57 | 	colInfos := r.out.ResultSet.ResultSetMetadata.ColumnInfo
58 | 	columns := make([]string, 0, len(colInfos))
59 | 	for _, colInfo := range colInfos {
60 | 		columns = append(columns, aws.ToString(colInfo.Name))
61 | 	}
62 | 
63 | 	return columns
64 | }
65 | 
66 | // ColumnTypeDatabaseTypeName implements driver.RowsColumnTypeDatabaseTypeName.
67 | // It returns the Athena type name of the column at index, or "" if the
68 | // metadata does not report one.
69 | func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
70 | 	return aws.ToString(r.out.ResultSet.ResultSetMetadata.ColumnInfo[index].Type)
71 | }
72 | 
73 | // Next implements driver.Rows. It converts the next row into dest, fetching
74 | // further pages as needed, and returns io.EOF once the results are exhausted.
75 | func (r *rows) Next(dest []driver.Value) error {
76 | 	if r.done {
77 | 		return io.EOF
78 | 	}
79 | 
80 | 	// Keep following NextToken until a page with at least one data row is
81 | 	// loaded; an empty page alone does not mean the results are exhausted.
82 | 	for len(r.out.ResultSet.Rows) == 0 {
83 | 		if aws.ToString(r.out.NextToken) == "" {
84 | 			r.done = true
85 | 			return io.EOF
86 | 		}
87 | 
88 | 		// A context cannot be passed into Next because its signature is fixed
89 | 		// by the database/sql/driver.Rows interface.
90 | 		if _, err := r.fetchNextPage(context.Background(), r.out.NextToken); err != nil {
91 | 			return err
92 | 		}
93 | 	}
94 | 
95 | 	// Convert the row at the head of the page, then drop it.
96 | 	cur := r.out.ResultSet.Rows[0]
97 | 	columns := r.out.ResultSet.ResultSetMetadata.ColumnInfo
98 | 	if err := convertRow(columns, cur.Data, dest); err != nil {
99 | 		return err
100 | 	}
101 | 
102 | 	r.out.ResultSet.Rows = r.out.ResultSet.Rows[1:]
103 | 	return nil
104 | }
105 | 
106 | // fetchNextPage loads the results page identified by token (nil for the first
107 | // page) into r.out, dropping the first page's header row when skipHeaderRow is
108 | // set. It reports whether the loaded page holds any data rows. On error r.out
109 | // is left unchanged.
110 | func (r *rows) fetchNextPage(ctx context.Context, token *string) (bool, error) {
111 | 	out, err := r.athena.GetQueryResults(ctx, &athena.GetQueryResultsInput{
112 | 		QueryExecutionId: aws.String(r.queryID),
113 | 		NextToken:        token,
114 | 	})
115 | 	if err != nil {
116 | 		return false, err
117 | 	}
118 | 
119 | 	// The first row of the first page contains the header if the query is not
120 | 	// DDL. The names are also available in ResultSetMetadata.
121 | 	if r.skipHeaderRow {
122 | 		r.skipHeaderRow = false
123 | 		if len(out.ResultSet.Rows) > 0 {
124 | 			out.ResultSet.Rows = out.ResultSet.Rows[1:]
125 | 		}
126 | 	}
127 | 
128 | 	r.out = out
129 | 	return len(out.ResultSet.Rows) > 0, nil
130 | }
131 | 
132 | // Close implements driver.Rows. It only marks the iterator as exhausted so
133 | // later Next calls return io.EOF; there is nothing else to release.
134 | func (r *rows) Close() error {
135 | 	r.done = true
136 | 	return nil
137 | }
138 | 
--------------------------------------------------------------------------------
/rows_test.go:
--------------------------------------------------------------------------------
1 | package athena
2 | 
3 | import (
4 | 	"context"
5 | 	"database/sql/driver"
6 | 	"errors"
7 | 	"io"
8 | 	"math/rand"
9 | 	"testing"
10 | 
11 | 	"github.com/aws/aws-sdk-go-v2/service/athena"
12 | 	"github.com/aws/aws-sdk-go-v2/service/athena/types"
13 | 	"github.com/stretchr/testify/assert"
14 | )
15 | 
16 | var dummyError = errors.New("dummy error")
17 | 
18 | type genQueryResultsOutputByToken func(token string) (*athena.GetQueryResultsOutput, error)
19 | 
20 | var queryToResultsGenMap = map[string]genQueryResultsOutputByToken{
21 | 	"select":         dummySelectQueryResponse,
22 | 	"show":           dummyShowResponse,
23 | 	"iteration_fail": dummyFailedIterationResponse,
24 | 	"header_only":    dummyHeaderOnlyFirstPageResponse,
25 | }
26 | 
27 | func genColumnInfo(column string) types.ColumnInfo {
28 | 	caseSensitive := true
29 | 	catalogName := "hive"
30 | 	nullable := types.ColumnNullableUnknown
31 | 	precision := int32(2147483647)
32 | 	scale := int32(0)
33 | 	schemaName := ""
34 | 	tableName := ""
35 | 	columnType := "varchar"
36 | 
37 | 	return types.ColumnInfo{
38 | 		CaseSensitive: caseSensitive,
39 | 		CatalogName:   &catalogName,
40 | 		Nullable:      nullable,
41 | 		Precision:     precision,
42 | 		Scale:         scale,
43 | 		SchemaName:    &schemaName,
44 | 		TableName:     &tableName,
45 | 		Type:          &columnType,
46 | 		Label:         &column,
47 | 		Name:          &column,
48 | 	}
49 | }
50 | 
51 | func randomString() string {
52 | 	const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
53 | 	s := make([]byte, 10)
54 | 	for i := 0; i < len(s); i++ {
55 | 		s[i] = alphabet[rand.Intn(len(alphabet))]
56 | 	}
57 | 	return string(s)
58 | }
59 | 
60 | func genRow(isHeader bool, columns []types.ColumnInfo) types.Row {
61 | 	var data []types.Datum
62 | 	for i := 0; i < len(columns); i++ {
63 | 		if isHeader {
64 | 			data = append(data, types.Datum{
65 | 				VarCharValue: columns[i].Name,
66 | 			})
67 | 		} else {
68 | 			s := randomString()
69 | 			data = append(data, types.Datum{
70 | 				VarCharValue: &s,
71 | 			})
72 | 		}
73 | 	}
74 | 	return types.Row{
75 | 		Data: data,
76 | 	}
77 | }
78 | 
79 | func dummySelectQueryResponse(token string) (*athena.GetQueryResultsOutput, error) {
80 | 	switch token {
81 | 	case "":
82 | 		var nextToken = "page_1"
83 | 		columns := []types.ColumnInfo{
84 | 			genColumnInfo("first_name"),
85 | 			genColumnInfo("last_name"),
86 | 		}
87 | 		return &athena.GetQueryResultsOutput{
88 | 			NextToken: &nextToken,
89 | 			ResultSet: &types.ResultSet{
90 | 				ResultSetMetadata: &types.ResultSetMetadata{
91 | 					ColumnInfo: columns,
92 | 				},
93 | 				Rows: []types.Row{
94 | 					genRow(true, columns),
95 | 					genRow(false, columns),
96 | 					genRow(false, columns),
97 | 					genRow(false, columns),
98 | 					genRow(false, columns),
99 | 				},
100 | 			},
101 | 		}, nil
102 | 	case "page_1":
103 | 		columns := []types.ColumnInfo{
104 | 			genColumnInfo("first_name"),
105 | 			genColumnInfo("last_name"),
106 | 		}
107 | 		return &athena.GetQueryResultsOutput{
108 | 			ResultSet: &types.ResultSet{
109 | 				ResultSetMetadata: &types.ResultSetMetadata{
110 | 					ColumnInfo: columns,
111 | 				},
112 | 				Rows: []types.Row{
113 | 					genRow(false, columns),
114 | 					genRow(false, columns),
115 | 					genRow(false, columns),
116 | 					genRow(false, columns),
117 | 					genRow(false, columns),
118 | 				},
119 | 			},
120 | 		}, nil
121 | 	default:
122 | 		return nil, dummyError
123 | 	}
124 | }
125 | 
126 | func dummyShowResponse(_ string) (*athena.GetQueryResultsOutput, error) {
127 | 	columns := []types.ColumnInfo{
128 | 		genColumnInfo("partition"),
129 | 	}
130 | 	return &athena.GetQueryResultsOutput{
131 | 		ResultSet: &types.ResultSet{
132 | 			ResultSetMetadata: &types.ResultSetMetadata{
133 | 				ColumnInfo: columns,
134 | 			},
135 | 			Rows: []types.Row{
136 | 				genRow(false, columns),
137 | 				genRow(false, columns),
138 | 			},
139 | 		},
140 | 	}, nil
141 | }
142 | 
143 | func dummyFailedIterationResponse(token string) (*athena.GetQueryResultsOutput, error) {
144 | 	switch token {
145 | 	case "":
146 | 		var nextToken = "page_1"
147 | 		columns := []types.ColumnInfo{
148 | 			genColumnInfo("first_name"),
149 | 			genColumnInfo("last_name"),
150 | 		}
151 | 		return &athena.GetQueryResultsOutput{
152 | 			NextToken: &nextToken,
153 | 			ResultSet: &types.ResultSet{
154 | 				ResultSetMetadata: &types.ResultSetMetadata{
155 | 					ColumnInfo: columns,
156 | 				},
157 | 				Rows: []types.Row{
158 | 					genRow(true, columns),
159 | 					genRow(false, columns),
160 | 					genRow(false, columns),
161 | 					genRow(false, columns),
162 | 					genRow(false, columns),
163 | 				},
164 | 			},
165 | 		}, nil
166 | 	default:
167 | 		return nil, dummyError
168 | 	}
169 | }
170 | 
171 | // dummyHeaderOnlyFirstPageResponse serves a first page that holds only the
172 | // header row, followed by a second page with three data rows.
173 | func dummyHeaderOnlyFirstPageResponse(token string) (*athena.GetQueryResultsOutput, error) {
174 | 	columns := []types.ColumnInfo{
175 | 		genColumnInfo("first_name"),
176 | 		genColumnInfo("last_name"),
177 | 	}
178 | 	switch token {
179 | 	case "":
180 | 		var nextToken = "page_1"
181 | 		return &athena.GetQueryResultsOutput{
182 | 			NextToken: &nextToken,
183 | 			ResultSet: &types.ResultSet{
184 | 				ResultSetMetadata: &types.ResultSetMetadata{
185 | 					ColumnInfo: columns,
186 | 				},
187 | 				Rows: []types.Row{
188 | 					genRow(true, columns),
189 | 				},
190 | 			},
191 | 		}, nil
192 | 	case "page_1":
193 | 		return &athena.GetQueryResultsOutput{
194 | 			ResultSet: &types.ResultSet{
195 | 				ResultSetMetadata: &types.ResultSetMetadata{
196 | 					ColumnInfo: columns,
197 | 				},
198 | 				Rows: []types.Row{
199 | 					genRow(false, columns),
200 | 					genRow(false, columns),
201 | 					genRow(false, columns),
202 | 				},
203 | 			},
204 | 		}, nil
205 | 	default:
206 | 		return nil, dummyError
207 | 	}
208 | }
209 | 
210 | type mockAthenaClient struct {
211 | 	athenaAPI
212 | }
213 | 
214 | func (m *mockAthenaClient) GetQueryResults(ctx context.Context, query *athena.GetQueryResultsInput, opts ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) {
215 | 	var nextToken = ""
216 | 	if query.NextToken != nil {
217 | 		nextToken = *query.NextToken
218 | 	}
219 | 	return queryToResultsGenMap[*query.QueryExecutionId](nextToken)
220 | }
221 | 
222 | func castToValue(dest ...driver.Value) []driver.Value {
223 | 	return dest
224 | }
225 | 
226 | func TestRows_Next(t *testing.T) {
227 | 	tests := []struct {
228 | 		desc                string
229 | 		queryID             string
230 | 		skipHeader          bool
231 | 		expectedResultsSize int
232 | 		expectedError       error
233 | 	}{
234 | 		{
235 | 			desc:                "show query, no header, 2 rows, no error",
236 | 			queryID:             "show",
237 | 			skipHeader:          false,
238 | 			expectedResultsSize: 2,
239 | 			expectedError:       nil,
240 | 		},
241 | 		{
242 | 			desc:                "select query, header, multipage, 9 rows, no error",
243 | 			queryID:             "select",
244 | 			skipHeader:          true,
245 | 			expectedResultsSize: 9,
246 | 			expectedError:       nil,
247 | 		},
248 | 		{
249 | 			desc:                "select query, header-only first page, 3 rows, no error",
250 | 			queryID:             "header_only",
251 | 			skipHeader:          true,
252 | 			expectedResultsSize: 3,
253 | 			expectedError:       nil,
254 | 		},
255 | 		{
256 | 			desc:          "failed during calling next",
257 | 			queryID:       "iteration_fail",
258 | 			skipHeader:    true,
259 | 			expectedError: dummyError,
260 | 		},
261 | 	}
262 | 	ctx := context.Background()
263 | 	for _, test := range tests {
264 | 		t.Run(test.desc, func(t *testing.T) {
265 | 			r, err := newRows(ctx, rowsConfig{
266 | 				Athena:     new(mockAthenaClient),
267 | 				QueryID:    test.queryID,
268 | 				SkipHeader: test.skipHeader,
269 | 			})
270 | 			if !assert.NoError(t, err) {
271 | 				return
272 | 			}
273 | 
274 | 			var firstName, lastName string
275 | 			cnt := 0
276 | 			for {
277 | 				if err = r.Next(castToValue(&firstName, &lastName)); err != nil {
278 | 					break
279 | 				}
280 | 				cnt++
281 | 			}
282 | 
283 | 			// Iteration must end with io.EOF unless an error is expected, in
284 | 			// which case exactly that error must surface.
285 | 			if test.expectedError == nil {
286 | 				assert.Equal(t, io.EOF, err)
287 | 				assert.Equal(t, test.expectedResultsSize, cnt)
288 | 			} else {
289 | 				assert.Equal(t, test.expectedError, err)
290 | 			}
291 | 		})
292 | 	}
293 | }
294 | 
--------------------------------------------------------------------------------
/value.go:
--------------------------------------------------------------------------------
1 | package athena
2 | 
3 | import (
4 | 	"database/sql/driver"
5 | 	"fmt"
6 | 	"strconv"
7 | 	"time"
8 | 
9 | 	"github.com/aws/aws-sdk-go-v2/service/athena/types"
10 | )
11 | 
12 | const (
13 | 	// TimestampLayout is the Go time layout string for an Athena `timestamp`.
14 | 	TimestampLayout = "2006-01-02 15:04:05.999"
15 | 	// TimestampWithTimeZoneLayout is the Go time layout string for an Athena
16 | 	// `timestamp with time zone`.
17 | 	TimestampWithTimeZoneLayout = "2006-01-02 15:04:05.999 MST"
18 | 	// DateLayout is the Go time layout string for an Athena `date`.
19 | 	DateLayout = "2006-01-02"
20 | )
21 | 
22 | // convertRow converts the raw cells of one result row into driver values,
23 | // storing cell i in ret[i] according to the type of columns[i]. It returns an
24 | // error if the row has more cells than columns or destination slots, or if a
25 | // cell cannot be parsed as its column type.
26 | func convertRow(columns []types.ColumnInfo, in []types.Datum, ret []driver.Value) error {
27 | 	if len(in) > len(columns) || len(in) > len(ret) {
28 | 		return fmt.Errorf("row has %d values but only %d columns and %d destinations", len(in), len(columns), len(ret))
29 | 	}
30 | 
31 | 	for i, val := range in {
32 | 		// A column without a reported type falls through to convertValue's
33 | 		// default case and is returned as a string.
34 | 		athenaType := ""
35 | 		if columns[i].Type != nil {
36 | 			athenaType = *columns[i].Type
37 | 		}
38 | 
39 | 		coerced, err := convertValue(athenaType, val.VarCharValue)
40 | 		if err != nil {
41 | 			return err
42 | 		}
43 | 
44 | 		ret[i] = coerced
45 | 	}
46 | 
47 | 	return nil
48 | }
49 | 
50 | // convertValue parses rawValue, the string form in which each cell arrives,
51 | // into the driver value for athenaType. A nil rawValue (SQL NULL) yields nil.
52 | // Integer types become int64, floating-point and decimal types float64,
53 | // boolean bool, date and timestamp types time.Time, and character types
54 | // string.
55 | //
56 | // Types without a dedicated conversion (for example array, map, row, json or
57 | // varbinary) are returned as the raw string instead of panicking, so callers
58 | // can still scan them into a string or []byte.
59 | func convertValue(athenaType string, rawValue *string) (interface{}, error) {
60 | 	if rawValue == nil {
61 | 		return nil, nil
62 | 	}
63 | 
64 | 	val := *rawValue
65 | 	switch athenaType {
66 | 	case "tinyint":
67 | 		return strconv.ParseInt(val, 10, 8)
68 | 	case "smallint":
69 | 		return strconv.ParseInt(val, 10, 16)
70 | 	case "integer":
71 | 		return strconv.ParseInt(val, 10, 32)
72 | 	case "bigint":
73 | 		return strconv.ParseInt(val, 10, 64)
74 | 	case "boolean":
75 | 		switch val {
76 | 		case "true":
77 | 			return true, nil
78 | 		case "false":
79 | 			return false, nil
80 | 		}
81 | 		return nil, fmt.Errorf("cannot parse '%s' as boolean", val)
82 | 	case "float", "real":
83 | 		return strconv.ParseFloat(val, 32)
84 | 	case "double", "decimal":
85 | 		// decimal is parsed as float64 and may therefore lose precision.
86 | 		return strconv.ParseFloat(val, 64)
87 | 	case "varchar", "string", "char":
88 | 		return val, nil
89 | 	case "timestamp":
90 | 		return time.Parse(TimestampLayout, val)
91 | 	case "timestamp with time zone":
92 | 		return time.Parse(TimestampWithTimeZoneLayout, val)
93 | 	case "date":
94 | 		return time.Parse(DateLayout, val)
95 | 	default:
96 | 		return val, nil
97 | 	}
98 | }
99 | 
--------------------------------------------------------------------------------