├── .github ├── CODEOWNERS └── workflows │ ├── go-checks.yml │ └── parquet-checks.yml ├── .gitignore ├── .goreleaser.yaml ├── LICENSE.txt ├── Makefile ├── README.md ├── clients ├── bigquery │ ├── bigquery.go │ ├── bigquery_dedupe_test.go │ ├── bigquery_suite_test.go │ ├── bigquery_test.go │ ├── converters │ │ ├── converters.go │ │ └── converters_test.go │ ├── dialect │ │ ├── ddl.go │ │ ├── default.go │ │ ├── dialect.go │ │ ├── dialect_test.go │ │ ├── tableid.go │ │ ├── tableid_test.go │ │ ├── typing.go │ │ └── typing_test.go │ ├── errors.go │ ├── merge.go │ ├── merge_test.go │ ├── partition.go │ ├── partition_test.go │ ├── storagewrite.go │ └── storagewrite_test.go ├── databricks │ ├── dialect │ │ ├── ddl.go │ │ ├── dialect.go │ │ ├── dialect_test.go │ │ ├── tableid.go │ │ ├── tableid_test.go │ │ ├── typing.go │ │ └── typing_test.go │ ├── file.go │ ├── file_test.go │ └── store.go ├── iceberg │ ├── dialect │ │ ├── data_types.go │ │ ├── data_types_test.go │ │ ├── dialect.go │ │ ├── dialect_test.go │ │ └── tableid.go │ ├── staging.go │ ├── staging_test.go │ ├── store.go │ └── table.go ├── mssql │ ├── dialect │ │ ├── ddl.go │ │ ├── default.go │ │ ├── dialect.go │ │ ├── dialect_test.go │ │ ├── merge_test.go │ │ ├── tableid.go │ │ ├── tableid_test.go │ │ ├── typing.go │ │ └── typing_test.go │ ├── staging.go │ ├── store.go │ ├── store_test.go │ ├── values.go │ └── values_test.go ├── redshift │ ├── cast.go │ ├── cast_test.go │ ├── dialect │ │ ├── ddl.go │ │ ├── default.go │ │ ├── default_test.go │ │ ├── dialect.go │ │ ├── dialect_test.go │ │ ├── tableid.go │ │ ├── tableid_test.go │ │ ├── typing.go │ │ └── typing_test.go │ ├── redshift.go │ ├── redshift_bench_test.go │ ├── redshift_dedupe_test.go │ ├── redshift_suite_test.go │ ├── redshift_test.go │ └── staging.go ├── s3 │ ├── s3.go │ ├── s3_test.go │ ├── tableid.go │ └── tableid_test.go ├── shared │ ├── append.go │ ├── default_value.go │ ├── default_value_test.go │ ├── merge.go │ ├── multi_step_merge.go │ ├── sweep.go │ 
├── table.go │ ├── table_config.go │ ├── table_config_test.go │ ├── temp_table.go │ └── temp_table_test.go └── snowflake │ ├── ddl_test.go │ ├── dialect │ ├── ddl.go │ ├── default.go │ ├── dialect.go │ ├── dialect_test.go │ ├── tableid.go │ ├── tableid_test.go │ ├── typing.go │ └── typing_test.go │ ├── snowflake.go │ ├── snowflake_dedupe_test.go │ ├── snowflake_suite_test.go │ ├── snowflake_test.go │ ├── staging.go │ ├── staging_test.go │ ├── util.go │ ├── util_test.go │ └── writes.go ├── examples ├── README.md ├── mongodb │ ├── Dockerfile │ ├── README.md │ ├── config.yaml │ ├── connect │ │ ├── Dockerfile │ │ ├── connect-distributed.properties │ │ └── docker-entrypoint.sh │ ├── docker-compose.yaml │ └── register-mongodb-connector.json ├── mysql │ ├── Dockerfile │ ├── README.md │ ├── application.properties │ ├── config.yaml │ └── docker-compose.yaml └── postgres │ ├── Dockerfile │ ├── README.md │ ├── config.yaml │ ├── connect │ ├── Dockerfile │ ├── connect-distributed.properties │ └── docker-entrypoint.sh │ ├── docker-compose.yaml │ └── register-postgres-connector.json ├── go.mod ├── go.sum ├── goreleaser.dockerfile ├── integration_tests ├── destination_append │ └── main.go ├── destination_merge │ └── main.go ├── destination_types │ └── main.go ├── parquet │ ├── main.go │ ├── requirements.txt │ └── verify_parquet.py └── shared │ ├── baseline.go │ ├── checker.go │ ├── destination.go │ ├── destination_types.go │ ├── destination_types.snowflake.go │ ├── destination_types_mssql.go │ └── framework.go ├── lib ├── apachelivy │ ├── client.go │ ├── schema.go │ ├── schema_test.go │ ├── session.go │ ├── types.go │ └── util.go ├── array │ ├── strings.go │ └── strings_test.go ├── artie │ ├── message.go │ └── message_test.go ├── awslib │ ├── config.go │ ├── s3.go │ ├── s3tablesapi.go │ ├── sts.go │ ├── sts_test.go │ └── types.go ├── batch │ ├── batch.go │ └── batch_test.go ├── cdc │ ├── event.go │ ├── format │ │ ├── format.go │ │ └── format_test.go │ ├── mongo │ │ ├── debezium.go 
│ │ ├── debezium_test.go │ │ ├── event.go │ │ └── mongo_bench_test.go │ ├── relational │ │ ├── debezium.go │ │ ├── debezium_test.go │ │ └── relation_suite_test.go │ └── util │ │ ├── decimal.json │ │ ├── money.json │ │ ├── numbers.json │ │ ├── numeric.json │ │ ├── optional_schema.go │ │ ├── optional_schema_test.go │ │ ├── relational_data_test.go │ │ ├── relational_event.go │ │ ├── relational_event_decimal_test.go │ │ └── relational_event_test.go ├── config │ ├── config.go │ ├── config_test.go │ ├── config_validate_test.go │ ├── constants │ │ └── constants.go │ ├── destination_types.go │ ├── destinations.go │ ├── destinations_test.go │ ├── mssql_test.go │ ├── settings.go │ ├── settings_test.go │ └── types.go ├── cryptography │ ├── cryptography.go │ └── cryptography_test.go ├── csvwriter │ ├── gzip.go │ └── gzip_test.go ├── db │ ├── db.go │ ├── errors.go │ └── errors_test.go ├── debezium │ ├── converters │ │ ├── basic.go │ │ ├── basic_test.go │ │ ├── bytes.go │ │ ├── bytes_test.go │ │ ├── converters.go │ │ ├── date.go │ │ ├── date_test.go │ │ ├── decimal.go │ │ ├── decimal_test.go │ │ ├── geometry.go │ │ ├── geometry_test.go │ │ ├── string.go │ │ ├── string_test.go │ │ ├── time.go │ │ ├── time_test.go │ │ ├── timestamp.go │ │ └── timestamp_test.go │ ├── keys.go │ ├── keys_test.go │ ├── schema.go │ ├── schema_test.go │ ├── types.go │ ├── types_bench_test.go │ └── types_test.go ├── destination │ ├── ddl │ │ ├── ddl.go │ │ ├── ddl_alter_delete_test.go │ │ ├── ddl_bq_test.go │ │ ├── ddl_sflk_test.go │ │ ├── ddl_suite_test.go │ │ ├── ddl_temp_test.go │ │ ├── ddl_test.go │ │ ├── expiry.go │ │ └── expiry_test.go │ ├── destination.go │ ├── types │ │ ├── table_config.go │ │ ├── table_config_test.go │ │ ├── types.go │ │ └── types_test.go │ └── utils │ │ └── load.go ├── environ │ ├── environment.go │ └── environment_test.go ├── jitter │ ├── sleep.go │ └── sleep_test.go ├── jsonutil │ ├── jsonutil.go │ └── jsonutil_test.go ├── kafkalib │ ├── connection.go │ ├── connection_test.go 
│ ├── consumer.go │ ├── partition │ │ ├── settings.go │ │ └── settings_test.go │ ├── topic.go │ └── topic_test.go ├── logger │ └── log.go ├── maputil │ ├── map.go │ ├── map_test.go │ ├── ordered_map.go │ └── ordered_map_test.go ├── mocks │ └── generate.go ├── numbers │ ├── numbers.go │ └── numbers_test.go ├── optimization │ ├── event_bench_test.go │ ├── event_insert_test.go │ ├── table_data.go │ ├── table_data_merge_columns_test.go │ └── table_data_test.go ├── parquetutil │ ├── generate_schema.go │ ├── parse_values.go │ └── parse_values_test.go ├── retry │ ├── retry.go │ └── retry_test.go ├── size │ ├── size.go │ ├── size_bench_test.go │ └── size_test.go ├── sql │ ├── columns.go │ ├── dialect.go │ ├── rows.go │ ├── tests │ │ ├── columns_test.go │ │ └── util_test.go │ ├── util.go │ └── util_test.go ├── stringutil │ ├── strings.go │ └── strings_test.go ├── telemetry │ └── metrics │ │ ├── base │ │ └── provider.go │ │ ├── datadog │ │ ├── datadog.go │ │ ├── datadog_test.go │ │ └── tags.go │ │ ├── null_provider.go │ │ ├── stats.go │ │ └── stats_test.go └── typing │ ├── README.md │ ├── assert.go │ ├── assert_test.go │ ├── columns │ ├── columns.go │ ├── columns_test.go │ ├── diff.go │ └── diff_test.go │ ├── converters │ ├── primitives │ │ ├── converter.go │ │ └── converter_test.go │ ├── string_converter.go │ ├── string_converter_test.go │ ├── util.go │ └── util_test.go │ ├── decimal │ ├── base.go │ ├── base_test.go │ ├── decimal.go │ ├── decimal_test.go │ ├── details.go │ └── details_test.go │ ├── errors.go │ ├── errors_test.go │ ├── ext │ ├── parse.go │ ├── parse_test.go │ ├── variables.go │ └── variables_test.go │ ├── mongo │ ├── bson.go │ ├── bson_bench_test.go │ └── bson_test.go │ ├── numeric.go │ ├── numeric_test.go │ ├── parquet.go │ ├── parquet_test.go │ ├── parse.go │ ├── parse_test.go │ ├── ptr.go │ ├── ptr_test.go │ ├── typing.go │ ├── typing_bench_test.go │ ├── typing_test.go │ └── values │ ├── string.go │ └── string_test.go ├── main.go ├── models ├── event │ 
├── event.go │ ├── event_save_test.go │ ├── event_test.go │ └── events_suite_test.go ├── memory.go ├── memory_flush_test.go └── memory_test.go └── processes ├── consumer ├── configs.go ├── flush.go ├── flush_suite_test.go ├── flush_test.go ├── kafka.go ├── process.go └── process_test.go └── pool └── writes.go /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @artie-labs/engineering 2 | -------------------------------------------------------------------------------- /.github/workflows/go-checks.yml: -------------------------------------------------------------------------------- 1 | name: Go checks 2 | 3 | on: [push] 4 | 5 | jobs: 6 | test: 7 | runs-on: ubuntu-24.04 8 | 9 | steps: 10 | - uses: actions/checkout@v4 11 | 12 | - name: Set up Go 13 | uses: actions/setup-go@v5 14 | with: 15 | go-version-file: go.mod 16 | 17 | - name: Download dependencies 18 | run: | 19 | go mod download 20 | go mod tidy -diff 21 | 22 | - name: Generate mocks 23 | run: make generate 24 | 25 | - name: Run vet 26 | run: make vet 27 | 28 | - uses: dominikh/staticcheck-action@fe1dd0c3658873b46f8c9bb3291096a617310ca6 # v1.3.1 29 | with: 30 | version: "2025.1.1" 31 | install-go: false 32 | 33 | - name: Run tests + race condition check 34 | run: make race 35 | 36 | - name: Check Go files are properly formatted 37 | run: test -z $(gofmt -l .) 
38 | -------------------------------------------------------------------------------- /.github/workflows/parquet-checks.yml: -------------------------------------------------------------------------------- 1 | name: Parquet checks 2 | 3 | on: [push] 4 | 5 | permissions: 6 | contents: read 7 | 8 | jobs: 9 | test: 10 | runs-on: ubuntu-24.04 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Set up Go 16 | uses: actions/setup-go@v5 17 | with: 18 | go-version-file: go.mod 19 | 20 | - name: Set up Python 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: '3.11' 24 | 25 | - name: Download Go dependencies 26 | run: | 27 | go mod download 28 | go mod tidy -diff 29 | 30 | - name: Run Parquet integration tests 31 | run: make test-parquet 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # GoReleaser 2 | dist/* 3 | 4 | # Ignore mocks 5 | **/*.mock.go 6 | 7 | .envrc-personal 8 | 9 | # Terraform stuff 10 | config/.terraform* 11 | 12 | 13 | .personal/ 14 | scratch/ 15 | 16 | # Local .terraform directories 17 | **/.terraform.lock.hcl 18 | **/.terraform/* 19 | # .tfstate files 20 | *.tfstate 21 | *.tfstate.* 22 | 23 | ## Ignore virtual env from the integration tests 24 | integration_tests/parquet/venv/ 25 | 26 | # Ignore the output of the integration tests 27 | integration_tests/parquet/output/ 28 | -------------------------------------------------------------------------------- /.goreleaser.yaml: -------------------------------------------------------------------------------- 1 | project_name: artie-transfer 2 | 3 | version: 2 4 | 5 | before: 6 | hooks: 7 | # You may remove this if you don't use go modules. 
8 | - go mod tidy 9 | builds: 10 | - binary: transfer 11 | env: 12 | - CGO_ENABLED=0 13 | goos: 14 | - linux 15 | - darwin 16 | 17 | dockers: 18 | - image_templates: 19 | - "artielabs/transfer:latest" 20 | - "artielabs/transfer:{{ .Tag }}" 21 | # You can have multiple Docker images. 22 | # GOOS of the built binaries/packages that should be used. 23 | # Default: `linux`. 24 | goos: linux 25 | 26 | # GOARCH of the built binaries/packages that should be used. 27 | # Default: `amd64`. 28 | goarch: amd64 29 | 30 | # Skips the docker push. 31 | skip_push: false 32 | 33 | # Path to the Dockerfile (from the project root). 34 | # Defaults to `Dockerfile`. 35 | dockerfile: goreleaser.dockerfile 36 | 37 | # Set the "backend" for the Docker pipe. 38 | # Valid options are: docker, buildx, podman. 39 | # Defaults to docker. 40 | use: docker 41 | build_flag_templates: 42 | - "--pull" 43 | - "--label=org.opencontainers.image.created={{.Date}}" 44 | - "--label=org.opencontainers.image.title={{.ProjectName}}" 45 | - "--label=org.opencontainers.image.revision={{.FullCommit}}" 46 | - "--label=org.opencontainers.image.version={{.Version}}" 47 | - "--platform=linux/amd64" 48 | 49 | 50 | archives: 51 | - formats: tar.gz 52 | # this name template makes the OS and Arch compatible with the results of uname. 
53 | name_template: >- 54 | {{ .ProjectName }}_ 55 | {{- title .Os }}_ 56 | {{- if eq .Arch "amd64" }}x86_64 57 | {{- else if eq .Arch "386" }}i386 58 | {{- else }}{{ .Arch }}{{ end }} 59 | {{- if .Arm }}v{{ .Arm }}{{ end }} 60 | checksum: 61 | name_template: 'checksums.txt' 62 | snapshot: 63 | version_template: "{{ incpatch .Version }}-next" 64 | changelog: 65 | sort: asc 66 | filters: 67 | exclude: 68 | - '^docs:' 69 | - '^test:' 70 | -------------------------------------------------------------------------------- /clients/bigquery/bigquery_suite_test.go: -------------------------------------------------------------------------------- 1 | package bigquery 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/artie-labs/transfer/lib/config" 7 | 8 | "github.com/artie-labs/transfer/lib/db" 9 | "github.com/artie-labs/transfer/lib/mocks" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/suite" 12 | ) 13 | 14 | type BigQueryTestSuite struct { 15 | suite.Suite 16 | fakeStore *mocks.FakeStore 17 | store *Store 18 | } 19 | 20 | func (b *BigQueryTestSuite) SetupTest() { 21 | cfg := config.Config{ 22 | BigQuery: &config.BigQuery{ 23 | ProjectID: "artie", 24 | }, 25 | } 26 | 27 | b.fakeStore = &mocks.FakeStore{} 28 | store := db.Store(b.fakeStore) 29 | var err error 30 | b.store, err = LoadBigQuery(b.T().Context(), cfg, &store) 31 | assert.NoError(b.T(), err) 32 | } 33 | 34 | func TestBigQueryTestSuite(t *testing.T) { 35 | suite.Run(t, new(BigQueryTestSuite)) 36 | } 37 | -------------------------------------------------------------------------------- /clients/bigquery/bigquery_test.go: -------------------------------------------------------------------------------- 1 | package bigquery 2 | 3 | import ( 4 | "strconv" 5 | "strings" 6 | "testing" 7 | "time" 8 | 9 | "github.com/artie-labs/transfer/clients/shared" 10 | "github.com/artie-labs/transfer/lib/config" 11 | "github.com/artie-labs/transfer/lib/kafkalib" 12 | 
"github.com/artie-labs/transfer/lib/optimization" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func TestTempTableIDWithSuffix(t *testing.T) { 17 | trimTTL := func(tableName string) string { 18 | lastUnderscore := strings.LastIndex(tableName, "_") 19 | assert.GreaterOrEqual(t, lastUnderscore, 0) 20 | epoch, err := strconv.ParseInt(tableName[lastUnderscore+1:len(tableName)-1], 10, 64) 21 | assert.NoError(t, err) 22 | assert.Greater(t, time.Unix(epoch, 0), time.Now().Add(5*time.Hour)) // default TTL is 6 hours from now 23 | return tableName[:lastUnderscore] + string(tableName[len(tableName)-1]) 24 | } 25 | 26 | store := &Store{config: config.Config{BigQuery: &config.BigQuery{ProjectID: "123454321"}}} 27 | tableData := optimization.NewTableData(nil, config.Replication, nil, kafkalib.TopicConfig{Database: "db", Schema: "schema"}, "table") 28 | tableID := store.IdentifierFor(tableData.TopicConfig().BuildDatabaseAndSchemaPair(), tableData.Name()) 29 | tempTableName := shared.TempTableIDWithSuffix(tableID, "sUfFiX").FullyQualifiedName() 30 | assert.Equal(t, "`123454321`.`db`.`table___artie_sUfFiX`", trimTTL(tempTableName)) 31 | } 32 | -------------------------------------------------------------------------------- /clients/bigquery/dialect/ddl.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "time" 7 | 8 | "github.com/artie-labs/transfer/lib/config/constants" 9 | "github.com/artie-labs/transfer/lib/sql" 10 | "github.com/artie-labs/transfer/lib/typing" 11 | ) 12 | 13 | func (BigQueryDialect) BuildCreateTableQuery(tableID sql.TableIdentifier, temporary bool, colSQLParts []string) string { 14 | query := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", tableID.FullyQualifiedName(), strings.Join(colSQLParts, ",")) 15 | 16 | if temporary { 17 | return fmt.Sprintf( 18 | `%s OPTIONS (expiration_timestamp = TIMESTAMP("%s"))`, 19 | query, 20 | 
BQExpiresDate(time.Now().UTC().Add(constants.TemporaryTableTTL)), 21 | ) 22 | } else { 23 | return query 24 | } 25 | } 26 | 27 | func (BigQueryDialect) BuildDropTableQuery(tableID sql.TableIdentifier) string { 28 | return "DROP TABLE IF EXISTS " + tableID.FullyQualifiedName() 29 | } 30 | 31 | func (BigQueryDialect) BuildTruncateTableQuery(tableID sql.TableIdentifier) string { 32 | return "TRUNCATE TABLE " + tableID.FullyQualifiedName() 33 | } 34 | 35 | func (BigQueryDialect) BuildAddColumnQuery(tableID sql.TableIdentifier, sqlPart string) string { 36 | return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", tableID.FullyQualifiedName(), sqlPart) 37 | } 38 | 39 | func (BigQueryDialect) BuildDropColumnQuery(tableID sql.TableIdentifier, colName string) string { 40 | return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableID.FullyQualifiedName(), colName) 41 | } 42 | 43 | func (BigQueryDialect) BuildDescribeTableQuery(tableID sql.TableIdentifier) (string, []interface{}, error) { 44 | bqTableID, err := typing.AssertType[TableIdentifier](tableID) 45 | if err != nil { 46 | return "", nil, err 47 | } 48 | 49 | query := fmt.Sprintf("SELECT column_name, data_type, description FROM `%s.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS` WHERE table_name = ?;", bqTableID.Dataset()) 50 | return query, []any{bqTableID.Table()}, nil 51 | } 52 | -------------------------------------------------------------------------------- /clients/bigquery/dialect/default.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import "github.com/artie-labs/transfer/lib/sql" 4 | 5 | func (BigQueryDialect) GetDefaultValueStrategy() sql.DefaultValueStrategy { 6 | return sql.Backfill 7 | } 8 | -------------------------------------------------------------------------------- /clients/bigquery/dialect/tableid.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "fmt" 5 | 6 | 
"github.com/artie-labs/transfer/lib/sql" 7 | ) 8 | 9 | var _dialect = BigQueryDialect{} 10 | 11 | type TableIdentifier struct { 12 | projectID string 13 | dataset string 14 | table string 15 | disableDropProtection bool 16 | } 17 | 18 | func NewTableIdentifier(projectID, dataset, table string) TableIdentifier { 19 | return TableIdentifier{ 20 | projectID: projectID, 21 | dataset: dataset, 22 | table: table, 23 | } 24 | } 25 | 26 | func (ti TableIdentifier) ProjectID() string { 27 | return ti.projectID 28 | } 29 | 30 | func (ti TableIdentifier) Dataset() string { 31 | return ti.dataset 32 | } 33 | 34 | func (ti TableIdentifier) EscapedTable() string { 35 | return _dialect.QuoteIdentifier(ti.table) 36 | } 37 | 38 | func (ti TableIdentifier) Table() string { 39 | return ti.table 40 | } 41 | 42 | func (ti TableIdentifier) WithTable(table string) sql.TableIdentifier { 43 | return NewTableIdentifier(ti.projectID, ti.dataset, table) 44 | } 45 | 46 | func (ti TableIdentifier) FullyQualifiedName() string { 47 | // The fully qualified name for BigQuery is: project_id.dataset.tableName. 48 | // We are escaping the project_id, dataset, and table because there could be special characters. 
49 | return fmt.Sprintf("%s.%s.%s", 50 | _dialect.QuoteIdentifier(ti.projectID), 51 | _dialect.QuoteIdentifier(ti.dataset), 52 | ti.EscapedTable(), 53 | ) 54 | } 55 | 56 | func (ti TableIdentifier) WithDisableDropProtection(disableDropProtection bool) sql.TableIdentifier { 57 | ti.disableDropProtection = disableDropProtection 58 | return ti 59 | } 60 | 61 | func (ti TableIdentifier) AllowToDrop() bool { 62 | return ti.disableDropProtection 63 | } 64 | -------------------------------------------------------------------------------- /clients/bigquery/dialect/tableid_test.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestTableIdentifier_WithTable(t *testing.T) { 10 | tableID := NewTableIdentifier("project", "dataset", "foo") 11 | tableID2 := tableID.WithTable("bar") 12 | typedTableID2, ok := tableID2.(TableIdentifier) 13 | assert.True(t, ok) 14 | assert.Equal(t, "project", typedTableID2.ProjectID()) 15 | assert.Equal(t, "dataset", typedTableID2.Dataset()) 16 | assert.Equal(t, "bar", tableID2.Table()) 17 | } 18 | 19 | func TestTableIdentifier_FullyQualifiedName(t *testing.T) { 20 | // Table name that is not a reserved word: 21 | assert.Equal(t, "`project`.`dataset`.`foo`", NewTableIdentifier("project", "dataset", "foo").FullyQualifiedName()) 22 | 23 | // Table name that is a reserved word: 24 | assert.Equal(t, "`project`.`dataset`.`table`", NewTableIdentifier("project", "dataset", "table").FullyQualifiedName()) 25 | } 26 | 27 | func TestTableIdentifier_EscapedTable(t *testing.T) { 28 | // Table name that is not a reserved word: 29 | assert.Equal(t, "`foo`", NewTableIdentifier("project", "dataset", "foo").EscapedTable()) 30 | 31 | // Table name that is a reserved word: 32 | assert.Equal(t, "`table`", NewTableIdentifier("project", "dataset", "table").EscapedTable()) 33 | } 34 | 
-------------------------------------------------------------------------------- /clients/bigquery/errors.go: -------------------------------------------------------------------------------- 1 | package bigquery 2 | 3 | import "strings" 4 | 5 | func isTableQuotaError(err error) bool { 6 | return strings.Contains(err.Error(), "Exceeded rate limits: too many table update operations for this table") 7 | } 8 | 9 | func (s *Store) IsRetryableError(err error) bool { 10 | if isTableQuotaError(err) { 11 | return true 12 | } 13 | 14 | return s.Store.IsRetryableError(err) 15 | } 16 | -------------------------------------------------------------------------------- /clients/bigquery/merge.go: -------------------------------------------------------------------------------- 1 | package bigquery 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/artie-labs/transfer/clients/shared" 9 | "github.com/artie-labs/transfer/lib/config/constants" 10 | "github.com/artie-labs/transfer/lib/destination/types" 11 | "github.com/artie-labs/transfer/lib/kafkalib/partition" 12 | "github.com/artie-labs/transfer/lib/optimization" 13 | "github.com/artie-labs/transfer/lib/sql" 14 | "github.com/artie-labs/transfer/lib/typing" 15 | "github.com/artie-labs/transfer/lib/typing/columns" 16 | ) 17 | 18 | func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) (bool, error) { 19 | var additionalEqualityStrings []string 20 | if tableData.TopicConfig().BigQueryPartitionSettings != nil { 21 | distinctDates, err := buildDistinctDates(tableData.TopicConfig().BigQueryPartitionSettings.PartitionField, tableData.Rows()) 22 | if err != nil { 23 | return false, fmt.Errorf("failed to generate distinct dates: %w", err) 24 | } 25 | 26 | mergeString, err := generateMergeString(tableData.TopicConfig().BigQueryPartitionSettings, s.Dialect(), distinctDates) 27 | if err != nil { 28 | return false, fmt.Errorf("failed to generate merge string: %w", err) 29 | } 30 | 31 | 
additionalEqualityStrings = []string{mergeString} 32 | } 33 | 34 | err := shared.Merge(ctx, s, tableData, types.MergeOpts{ 35 | AdditionalEqualityStrings: additionalEqualityStrings, 36 | ColumnSettings: s.config.SharedDestinationSettings.ColumnSettings, 37 | // BigQuery has DDL quotas. 38 | RetryColBackfill: true, 39 | }) 40 | if err != nil { 41 | return false, fmt.Errorf("failed to merge: %w", err) 42 | } 43 | 44 | return true, nil 45 | } 46 | 47 | func generateMergeString(bqSettings *partition.BigQuerySettings, dialect sql.Dialect, values []string) (string, error) { 48 | if err := bqSettings.Valid(); err != nil { 49 | return "", fmt.Errorf("failed to validate bigQuerySettings: %w", err) 50 | } 51 | 52 | if len(values) == 0 { 53 | return "", fmt.Errorf("values cannot be empty") 54 | } 55 | 56 | switch bqSettings.PartitionType { 57 | case "time": 58 | switch bqSettings.PartitionBy { 59 | case "daily": 60 | return fmt.Sprintf(`DATE(%s) IN (%s)`, 61 | sql.QuoteTableAliasColumn( 62 | constants.TargetAlias, 63 | columns.NewColumn(bqSettings.PartitionField, typing.Invalid), 64 | dialect, 65 | ), 66 | strings.Join(sql.QuoteLiterals(values), ",")), nil 67 | } 68 | } 69 | 70 | return "", fmt.Errorf("unexpected partitionType: %s and/or partitionBy: %s", bqSettings.PartitionType, bqSettings.PartitionBy) 71 | } 72 | -------------------------------------------------------------------------------- /clients/bigquery/partition.go: -------------------------------------------------------------------------------- 1 | package bigquery 2 | 3 | import ( 4 | "fmt" 5 | "maps" 6 | "slices" 7 | "time" 8 | 9 | "github.com/artie-labs/transfer/lib/typing/ext" 10 | ) 11 | 12 | func buildDistinctDates(colName string, rows []map[string]any) ([]string, error) { 13 | dateMap := make(map[string]bool) 14 | for _, row := range rows { 15 | val, isOk := row[colName] 16 | if !isOk { 17 | return nil, fmt.Errorf("column %q does not exist in row: %v", colName, row) 18 | } 19 | 20 | _time, err := 
ext.ParseDateFromAny(val) 21 | if err != nil { 22 | return nil, fmt.Errorf("column %q is not a time column, value: %v, err: %w", colName, val, err) 23 | } 24 | 25 | dateMap[_time.Format(time.DateOnly)] = true 26 | } 27 | 28 | return slices.Collect(maps.Keys(dateMap)), nil 29 | } 30 | -------------------------------------------------------------------------------- /clients/bigquery/partition_test.go: -------------------------------------------------------------------------------- 1 | package bigquery 2 | 3 | import ( 4 | "slices" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestDistinctDates(t *testing.T) { 12 | { 13 | // Invalid date 14 | dates, err := buildDistinctDates("ts", []map[string]any{ 15 | {"ts": time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC).Format(time.RFC3339Nano)}, 16 | {"ts": nil}, 17 | }) 18 | assert.ErrorContains(t, err, `column "ts" is not a time column`) 19 | assert.Empty(t, dates) 20 | } 21 | { 22 | // No dates 23 | dates, err := buildDistinctDates("", nil) 24 | assert.NoError(t, err) 25 | assert.Empty(t, dates) 26 | } 27 | { 28 | // One date 29 | dates, err := buildDistinctDates("ts", []map[string]any{ 30 | {"ts": time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC).Format(time.RFC3339Nano)}, 31 | }) 32 | assert.NoError(t, err) 33 | assert.Equal(t, []string{"2020-01-01"}, dates) 34 | } 35 | { 36 | // Two dates 37 | dates, err := buildDistinctDates("ts", []map[string]any{ 38 | {"ts": time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC).Format(time.RFC3339Nano)}, 39 | {"ts": time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC).Format(time.RFC3339Nano)}, 40 | }) 41 | assert.NoError(t, err) 42 | equalLists(t, []string{"2020-01-01", "2020-01-02"}, dates) 43 | } 44 | { 45 | // Three days, two unique 46 | dates, err := buildDistinctDates("ts", []map[string]any{ 47 | {"ts": time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC).Format(time.RFC3339Nano)}, 48 | {"ts": time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC).Format(time.RFC3339Nano)}, 49 | {"ts": 
time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC).Format(time.RFC3339Nano)}, 50 | }) 51 | assert.NoError(t, err) 52 | equalLists(t, []string{"2020-01-01", "2020-01-02"}, dates) 53 | } 54 | } 55 | 56 | func equalLists(t *testing.T, list1 []string, list2 []string) { 57 | // Sort the two lists prior to comparison 58 | slices.Sort(list1) 59 | slices.Sort(list2) 60 | assert.Equal(t, list1, list2) 61 | } 62 | -------------------------------------------------------------------------------- /clients/databricks/dialect/ddl.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/artie-labs/transfer/lib/sql" 8 | ) 9 | 10 | func (DatabricksDialect) BuildCreateTableQuery(tableID sql.TableIdentifier, _ bool, colSQLParts []string) string { 11 | // Databricks doesn't have a concept of temporary tables. 12 | return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", tableID.FullyQualifiedName(), strings.Join(colSQLParts, ", ")) 13 | } 14 | 15 | func (DatabricksDialect) BuildDropTableQuery(tableID sql.TableIdentifier) string { 16 | return "DROP TABLE IF EXISTS " + tableID.FullyQualifiedName() 17 | } 18 | 19 | func (DatabricksDialect) BuildTruncateTableQuery(tableID sql.TableIdentifier) string { 20 | return "TRUNCATE TABLE " + tableID.FullyQualifiedName() 21 | } 22 | 23 | func (DatabricksDialect) BuildAddColumnQuery(tableID sql.TableIdentifier, sqlPart string) string { 24 | return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", tableID.FullyQualifiedName(), sqlPart) 25 | } 26 | 27 | func (DatabricksDialect) BuildDropColumnQuery(tableID sql.TableIdentifier, colName string) string { 28 | return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableID.FullyQualifiedName(), colName) 29 | } 30 | 31 | func (DatabricksDialect) BuildDescribeTableQuery(tableID sql.TableIdentifier) (string, []any, error) { 32 | return fmt.Sprintf("DESCRIBE TABLE %s", tableID.FullyQualifiedName()), nil, nil 33 | } 34 | 
-------------------------------------------------------------------------------- /clients/databricks/dialect/tableid.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/artie-labs/transfer/lib/sql" 7 | ) 8 | 9 | var _dialect = DatabricksDialect{} 10 | 11 | type TableIdentifier struct { 12 | database string 13 | schema string 14 | table string 15 | disableDropProtection bool 16 | } 17 | 18 | func NewTableIdentifier(database, schema, table string) TableIdentifier { 19 | return TableIdentifier{ 20 | database: database, 21 | schema: schema, 22 | table: table, 23 | } 24 | } 25 | 26 | func (ti TableIdentifier) Database() string { 27 | return ti.database 28 | } 29 | 30 | func (ti TableIdentifier) Schema() string { 31 | return ti.schema 32 | } 33 | 34 | func (ti TableIdentifier) EscapedTable() string { 35 | return _dialect.QuoteIdentifier(ti.table) 36 | } 37 | 38 | func (ti TableIdentifier) Table() string { 39 | return ti.table 40 | } 41 | 42 | func (ti TableIdentifier) WithTable(table string) sql.TableIdentifier { 43 | return NewTableIdentifier(ti.database, ti.schema, table) 44 | } 45 | 46 | func (ti TableIdentifier) FullyQualifiedName() string { 47 | return fmt.Sprintf("%s.%s.%s", _dialect.QuoteIdentifier(ti.database), _dialect.QuoteIdentifier(ti.schema), ti.EscapedTable()) 48 | } 49 | 50 | func (ti TableIdentifier) WithDisableDropProtection(disableDropProtection bool) sql.TableIdentifier { 51 | ti.disableDropProtection = disableDropProtection 52 | return ti 53 | } 54 | 55 | func (ti TableIdentifier) AllowToDrop() bool { 56 | return ti.disableDropProtection 57 | } 58 | -------------------------------------------------------------------------------- /clients/databricks/dialect/tableid_test.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 
8 | 9 | func TestTableIdentifier_WithTable(t *testing.T) { 10 | tableID := NewTableIdentifier("database", "schema", "foo") 11 | tableID2 := tableID.WithTable("bar") 12 | typedTableID2, ok := tableID2.(TableIdentifier) 13 | assert.True(t, ok) 14 | assert.Equal(t, "database", typedTableID2.Database()) 15 | assert.Equal(t, "schema", typedTableID2.Schema()) 16 | assert.Equal(t, "bar", tableID2.Table()) 17 | } 18 | 19 | func TestTableIdentifier_FullyQualifiedName(t *testing.T) { 20 | // Table name that is not a reserved word: 21 | assert.Equal(t, "`database`.`schema`.`foo`", NewTableIdentifier("database", "schema", "foo").FullyQualifiedName()) 22 | 23 | // Table name that is a reserved word: 24 | assert.Equal(t, "`database`.`schema`.`table`", NewTableIdentifier("database", "schema", "table").FullyQualifiedName()) 25 | } 26 | 27 | func TestTableIdentifier_EscapedTable(t *testing.T) { 28 | // Table name that is not a reserved word: 29 | assert.Equal(t, "`foo`", NewTableIdentifier("database", "schema", "foo").EscapedTable()) 30 | 31 | // Table name that is a reserved word: 32 | assert.Equal(t, "`table`", NewTableIdentifier("database", "schema", "table").EscapedTable()) 33 | } 34 | -------------------------------------------------------------------------------- /clients/databricks/file.go: -------------------------------------------------------------------------------- 1 | package databricks 2 | 3 | import ( 4 | "fmt" 5 | "path/filepath" 6 | "strings" 7 | 8 | "github.com/artie-labs/transfer/clients/databricks/dialect" 9 | "github.com/artie-labs/transfer/lib/destination/ddl" 10 | "github.com/artie-labs/transfer/lib/maputil" 11 | "github.com/artie-labs/transfer/lib/stringutil" 12 | ) 13 | 14 | type File struct { 15 | name string 16 | fp string 17 | } 18 | 19 | func NewFile(fileRow map[string]any) (File, error) { 20 | name, err := maputil.GetTypeFromMap[string](fileRow, "name") 21 | if err != nil { 22 | return File{}, err 23 | } 24 | 25 | fp, err := 
maputil.GetTypeFromMap[string](fileRow, "path") 26 | if err != nil { 27 | return File{}, err 28 | } 29 | 30 | return File{name: name, fp: fp}, nil 31 | } 32 | 33 | func NewFileFromTableID(tableID dialect.TableIdentifier, volume string) File { 34 | name := fmt.Sprintf("%s_%s.csv.gz", tableID.Table(), stringutil.Random(10)) 35 | return File{ 36 | name: name, 37 | fp: fmt.Sprintf("/Volumes/%s/%s/%s/%s", tableID.Database(), tableID.Schema(), volume, name), 38 | } 39 | } 40 | 41 | func (f File) Name() string { 42 | return f.name 43 | } 44 | 45 | func (f File) ShouldDelete() bool { 46 | return ddl.ShouldDeleteFromName(strings.TrimSuffix(f.name, ".csv")) 47 | } 48 | 49 | func (f File) DBFSFilePath() string { 50 | return filepath.Join("dbfs:", f.fp) 51 | } 52 | 53 | func (f File) FilePath() string { 54 | return f.fp 55 | } 56 | -------------------------------------------------------------------------------- /clients/iceberg/dialect/data_types.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/artie-labs/transfer/lib/config" 8 | "github.com/artie-labs/transfer/lib/sql" 9 | "github.com/artie-labs/transfer/lib/typing" 10 | ) 11 | 12 | // Ref: https://iceberg.apache.org/docs/latest/spark-getting-started/#iceberg-type-to-spark-type 13 | 14 | func (IcebergDialect) DataTypeForKind(kindDetails typing.KindDetails, _ bool, _ config.SharedDestinationColumnSettings) string { 15 | switch kindDetails.Kind { 16 | case typing.Boolean.Kind: 17 | return "BOOLEAN" 18 | case 19 | typing.Array.Kind, 20 | typing.Struct.Kind, 21 | typing.String.Kind, 22 | typing.Time.Kind: 23 | return "STRING" 24 | case typing.Float.Kind: 25 | return "DOUBLE" 26 | case typing.EDecimal.Kind: 27 | return kindDetails.ExtendedDecimalDetails.IcebergKind() 28 | case typing.Integer.Kind: 29 | if kindDetails.OptionalIntegerKind != nil { 30 | switch *kindDetails.OptionalIntegerKind { 31 | case 
typing.SmallIntegerKind, typing.IntegerKind: 32 | return "INTEGER" 33 | } 34 | } 35 | return "LONG" 36 | case typing.Date.Kind: 37 | return "DATE" 38 | case typing.TimestampNTZ.Kind: 39 | return "TIMESTAMP_NTZ" 40 | case typing.TimestampTZ.Kind: 41 | return "TIMESTAMP" 42 | default: 43 | return kindDetails.Kind 44 | } 45 | } 46 | 47 | func (IcebergDialect) KindForDataType(rawType string) (typing.KindDetails, error) { 48 | rawType = strings.ToLower(rawType) 49 | if strings.HasPrefix(rawType, "decimal") { 50 | _, parameters, err := sql.ParseDataTypeDefinition(rawType) 51 | if err != nil { 52 | return typing.Invalid, err 53 | } 54 | return typing.ParseNumeric(parameters) 55 | } 56 | 57 | switch rawType { 58 | case "boolean": 59 | return typing.Boolean, nil 60 | case "integer": 61 | return typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.IntegerKind)}, nil 62 | case "long", "bigint": 63 | return typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.BigIntegerKind)}, nil 64 | case "double", "float": 65 | return typing.Float, nil 66 | case "string", "binary", "uuid", "fixed": 67 | return typing.String, nil 68 | case "date": 69 | return typing.Date, nil 70 | case "timestamp": 71 | return typing.TimestampTZ, nil 72 | case "timestamp_ntz": 73 | return typing.TimestampNTZ, nil 74 | default: 75 | return typing.Invalid, fmt.Errorf("unsupported data type: %q", rawType) 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /clients/iceberg/dialect/tableid.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/artie-labs/transfer/lib/sql" 8 | ) 9 | 10 | var _dialect = IcebergDialect{} 11 | 12 | type TableIdentifier struct { 13 | catalog string 14 | namespace string 15 | table string 16 | disableDropProtection bool 17 | } 18 | 19 | func 
NewTableIdentifier(catalog, namespace, table string) TableIdentifier { 20 | return TableIdentifier{catalog: catalog, namespace: namespace, table: table} 21 | } 22 | 23 | func (ti TableIdentifier) Namespace() string { 24 | return strings.ToLower(ti.namespace) 25 | } 26 | 27 | func (ti TableIdentifier) EscapedTable() string { 28 | return _dialect.QuoteIdentifier(ti.table) 29 | } 30 | 31 | func (ti TableIdentifier) Table() string { 32 | return strings.ToLower(ti.table) 33 | } 34 | 35 | func (ti TableIdentifier) WithTable(table string) sql.TableIdentifier { 36 | return NewTableIdentifier(ti.catalog, ti.namespace, table) 37 | } 38 | 39 | func (ti TableIdentifier) FullyQualifiedName() string { 40 | return fmt.Sprintf("%s.%s.%s", _dialect.QuoteIdentifier(ti.catalog), _dialect.QuoteIdentifier(ti.namespace), ti.EscapedTable()) 41 | } 42 | 43 | func (ti TableIdentifier) WithDisableDropProtection(disableDropProtection bool) sql.TableIdentifier { 44 | ti.disableDropProtection = disableDropProtection 45 | return ti 46 | } 47 | 48 | func (ti TableIdentifier) AllowToDrop() bool { 49 | return ti.disableDropProtection 50 | } 51 | -------------------------------------------------------------------------------- /clients/mssql/dialect/ddl.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | mssql "github.com/microsoft/go-mssqldb" 8 | 9 | "github.com/artie-labs/transfer/lib/sql" 10 | "github.com/artie-labs/transfer/lib/typing" 11 | ) 12 | 13 | const describeTableQuery = ` 14 | SELECT 15 | COLUMN_NAME, 16 | CASE 17 | WHEN DATA_TYPE IN ('numeric', 'decimal') THEN 18 | DATA_TYPE + '(' + CAST(NUMERIC_PRECISION AS VARCHAR) + ',' + CAST(NUMERIC_SCALE AS VARCHAR) + ')' 19 | WHEN DATA_TYPE IN ('varchar', 'nvarchar', 'char', 'nchar', 'ntext', 'text') THEN 20 | DATA_TYPE + '(' + CAST(CHARACTER_MAXIMUM_LENGTH AS VARCHAR) + ')' 21 | WHEN DATA_TYPE IN ('datetime2', 'time') THEN 22 | DATA_TYPE + '(' 
+ CAST(DATETIME_PRECISION AS VARCHAR) + ')' 23 | ELSE 24 | DATA_TYPE 25 | END AS DATA_TYPE, 26 | COLUMN_DEFAULT AS DEFAULT_VALUE 27 | FROM 28 | INFORMATION_SCHEMA.COLUMNS 29 | WHERE 30 | LOWER(TABLE_SCHEMA) = LOWER(?) AND LOWER(TABLE_NAME) = LOWER(?); 31 | ` 32 | 33 | func (MSSQLDialect) BuildDescribeTableQuery(tableID sql.TableIdentifier) (string, []any, error) { 34 | mssqlTableID, err := typing.AssertType[TableIdentifier](tableID) 35 | if err != nil { 36 | return "", nil, err 37 | } 38 | 39 | return describeTableQuery, []any{mssql.VarChar(mssqlTableID.Schema()), mssql.VarChar(mssqlTableID.Table())}, nil 40 | } 41 | 42 | func (MSSQLDialect) BuildAddColumnQuery(tableID sql.TableIdentifier, sqlPart string) string { 43 | return fmt.Sprintf("ALTER TABLE %s ADD %s", tableID.FullyQualifiedName(), sqlPart) 44 | } 45 | 46 | func (MSSQLDialect) BuildDropColumnQuery(tableID sql.TableIdentifier, colName string) string { 47 | return fmt.Sprintf("ALTER TABLE %s DROP %s", tableID.FullyQualifiedName(), colName) 48 | } 49 | 50 | func (MSSQLDialect) BuildCreateTableQuery(tableID sql.TableIdentifier, _ bool, colSQLParts []string) string { 51 | // Microsoft SQL Server uses the same syntax for temporary and permanent tables. 
52 | // Microsoft SQL Server doesn't support IF NOT EXISTS 53 | return fmt.Sprintf("CREATE TABLE %s (%s);", tableID.FullyQualifiedName(), strings.Join(colSQLParts, ",")) 54 | } 55 | 56 | func (MSSQLDialect) BuildDropTableQuery(tableID sql.TableIdentifier) string { 57 | return "DROP TABLE IF EXISTS " + tableID.FullyQualifiedName() 58 | } 59 | 60 | func (MSSQLDialect) BuildTruncateTableQuery(tableID sql.TableIdentifier) string { 61 | return "TRUNCATE TABLE " + tableID.FullyQualifiedName() 62 | } 63 | -------------------------------------------------------------------------------- /clients/mssql/dialect/default.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import "github.com/artie-labs/transfer/lib/sql" 4 | 5 | func (MSSQLDialect) GetDefaultValueStrategy() sql.DefaultValueStrategy { 6 | return sql.Native 7 | } 8 | -------------------------------------------------------------------------------- /clients/mssql/dialect/tableid.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/artie-labs/transfer/lib/sql" 7 | ) 8 | 9 | var _dialect = MSSQLDialect{} 10 | 11 | type TableIdentifier struct { 12 | schema string 13 | table string 14 | disableDropProtection bool 15 | } 16 | 17 | func NewTableIdentifier(schema, table string) TableIdentifier { 18 | return TableIdentifier{schema: schema, table: table} 19 | } 20 | 21 | func (ti TableIdentifier) Schema() string { 22 | return ti.schema 23 | } 24 | 25 | func (ti TableIdentifier) EscapedTable() string { 26 | return _dialect.QuoteIdentifier(ti.table) 27 | } 28 | 29 | func (ti TableIdentifier) Table() string { 30 | return ti.table 31 | } 32 | 33 | func (ti TableIdentifier) WithTable(table string) sql.TableIdentifier { 34 | return NewTableIdentifier(ti.schema, table) 35 | } 36 | 37 | func (ti TableIdentifier) FullyQualifiedName() string { 38 | return fmt.Sprintf("%s.%s", 
_dialect.QuoteIdentifier(ti.schema), ti.EscapedTable()) 39 | } 40 | 41 | func (ti TableIdentifier) WithDisableDropProtection(disableDropProtection bool) sql.TableIdentifier { 42 | ti.disableDropProtection = disableDropProtection 43 | return ti 44 | } 45 | 46 | func (ti TableIdentifier) AllowToDrop() bool { 47 | return ti.disableDropProtection 48 | } 49 | -------------------------------------------------------------------------------- /clients/mssql/dialect/tableid_test.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestTableIdentifier_WithTable(t *testing.T) { 10 | tableID := NewTableIdentifier("schema", "foo") 11 | tableID2 := tableID.WithTable("bar") 12 | typedTableID2, ok := tableID2.(TableIdentifier) 13 | assert.True(t, ok) 14 | assert.Equal(t, "schema", typedTableID2.Schema()) 15 | assert.Equal(t, "bar", tableID2.Table()) 16 | } 17 | 18 | func TestTableIdentifier_FullyQualifiedName(t *testing.T) { 19 | // Table name that is not a reserved word: 20 | assert.Equal(t, `"schema"."foo"`, NewTableIdentifier("schema", "foo").FullyQualifiedName()) 21 | 22 | // Table name that is a reserved word: 23 | assert.Equal(t, `"schema"."table"`, NewTableIdentifier("schema", "table").FullyQualifiedName()) 24 | } 25 | 26 | func TestTableIdentifier_EscapedTable(t *testing.T) { 27 | // Table name that is not a reserved word: 28 | assert.Equal(t, `"foo"`, NewTableIdentifier("schema", "foo").EscapedTable()) 29 | 30 | // Table name that is a reserved word: 31 | assert.Equal(t, `"table"`, NewTableIdentifier("schema", "table").EscapedTable()) 32 | } 33 | -------------------------------------------------------------------------------- /clients/mssql/staging.go: -------------------------------------------------------------------------------- 1 | package mssql 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | mssql 
"github.com/microsoft/go-mssqldb"

	"github.com/artie-labs/transfer/clients/shared"
	"github.com/artie-labs/transfer/lib/destination/types"
	"github.com/artie-labs/transfer/lib/optimization"
	"github.com/artie-labs/transfer/lib/sql"
	"github.com/artie-labs/transfer/lib/typing/columns"
)

// PrepareTemporaryTable bulk-loads tableData into tempTableID using go-mssqldb's
// native bulk-copy statement (mssql.CopyIn), inside a single transaction.
// When createTempTable is true, the staging table is created first.
// It returns an error if any row fails to cast, if the bulk insert fails, or if
// the number of rows reported loaded does not match tableData.NumberOfRows().
func (s *Store) PrepareTemporaryTable(ctx context.Context, tableData *optimization.TableData, dwh *types.DestinationTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, opts types.AdditionalSettings, createTempTable bool) error {
	if createTempTable {
		if err := shared.CreateTempTable(ctx, s, tableData, dwh, opts.ColumnSettings, tempTableID); err != nil {
			return err
		}
	}

	tx, err := s.Begin()
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}

	// Roll back on any early return; flipped to true only after a successful commit.
	var txCommitted bool
	defer func() {
		if !txCommitted {
			tx.Rollback()
		}
	}()

	cols := tableData.ReadOnlyInMemoryCols().ValidColumns()
	// mssql.CopyIn prepares an INSERT BULK statement: each ExecContext call with
	// arguments buffers one row, and a final argument-less ExecContext flushes the batch.
	stmt, err := tx.Prepare(mssql.CopyIn(tempTableID.FullyQualifiedName(), mssql.BulkOptions{}, columns.ColumnNames(cols)...))
	if err != nil {
		return fmt.Errorf("failed to prepare bulk insert: %w", err)
	}

	defer stmt.Close()

	for _, value := range tableData.Rows() {
		var row []any
		for _, col := range cols {
			// Cast each in-memory value to a driver-compatible value for this column.
			castedValue, castErr := parseValue(value[col.Name()], col)
			if castErr != nil {
				return castErr
			}

			row = append(row, castedValue)
		}

		if _, err = stmt.ExecContext(ctx, row...); err != nil {
			return fmt.Errorf("failed to copy row: %w", err)
		}
	}

	// The argument-less Exec finalizes the bulk copy and reports the loaded row count.
	results, err := stmt.ExecContext(ctx)
	if err != nil {
		return fmt.Errorf("failed to finalize bulk insert: %w", err)
	}

	rowsLoaded, err := results.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}

	// Guard against silent partial loads before committing.
	if expectedRows := int64(tableData.NumberOfRows()); rowsLoaded != expectedRows {
		return fmt.Errorf("expected %d rows to be loaded, but got %d", expectedRows, rowsLoaded)
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}

	txCommitted = true
	return nil
}
--------------------------------------------------------------------------------
/clients/mssql/store_test.go:
--------------------------------------------------------------------------------
package mssql

import (
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/artie-labs/transfer/clients/shared"
	"github.com/artie-labs/transfer/lib/config"
	"github.com/artie-labs/transfer/lib/kafkalib"
	"github.com/artie-labs/transfer/lib/optimization"
	"github.com/stretchr/testify/assert"
)

func TestTempTableIDWithSuffix(t *testing.T) {
	// trimTTL strips the trailing "_<epoch>" TTL segment from the temp table name
	// after asserting the epoch is in the future (default TTL is 6 hours from now).
	trimTTL := func(tableName string) string {
		lastUnderscore := strings.LastIndex(tableName, "_")
		assert.GreaterOrEqual(t, lastUnderscore, 0)
		epoch, err := strconv.ParseInt(tableName[lastUnderscore+1:len(tableName)-1], 10, 64)
		assert.NoError(t, err)
		assert.Greater(t, time.Unix(epoch, 0), time.Now().Add(5*time.Hour)) // default TTL is 6 hours from now
		return tableName[:lastUnderscore] + string(tableName[len(tableName)-1])
	}

	store := Store{}
	{
		// Schema is "schema":
		tableData := optimization.NewTableData(nil, config.Replication, nil, kafkalib.TopicConfig{Database: "db", Schema: "schema"}, "table")
		tableID := store.IdentifierFor(tableData.TopicConfig().BuildDatabaseAndSchemaPair(), tableData.Name())
		tempTableName := shared.TempTableIDWithSuffix(tableID, "sUfFiX").FullyQualifiedName()
		assert.Equal(t, `"schema"."table___artie_sUfFiX"`, trimTTL(tempTableName))
	}
	{
		// Schema is "public" -> "dbo":
		tableData := optimization.NewTableData(nil, config.Replication,
nil, kafkalib.TopicConfig{Database: "db", Schema: "public"}, "table") 37 | tableID := store.IdentifierFor(tableData.TopicConfig().BuildDatabaseAndSchemaPair(), tableData.Name()) 38 | tempTableName := shared.TempTableIDWithSuffix(tableID, "sUfFiX").FullyQualifiedName() 39 | assert.Equal(t, `"dbo"."table___artie_sUfFiX"`, trimTTL(tempTableName)) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /clients/redshift/dialect/ddl.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/artie-labs/transfer/lib/sql" 8 | "github.com/artie-labs/transfer/lib/typing" 9 | ) 10 | 11 | const describeTableQuery = `SELECT 12 | c.column_name, 13 | CASE 14 | WHEN c.data_type IN ('numeric') THEN 15 | c.data_type + '(' + CAST(c.numeric_precision AS VARCHAR) + ',' + CAST(c.numeric_scale AS VARCHAR) + ')' 16 | WHEN c.data_type IN ('character varying') THEN 17 | c.data_type + '(' + CAST(c.character_maximum_length AS VARCHAR) + ')' 18 | ELSE 19 | c.data_type 20 | END AS data_type, 21 | d.description 22 | FROM 23 | INFORMATION_SCHEMA.COLUMNS c 24 | LEFT JOIN 25 | PG_CLASS c1 ON c.table_name = c1.relname 26 | LEFT JOIN 27 | PG_CATALOG.PG_NAMESPACE n ON c.table_schema = n.nspname AND c1.relnamespace = n.oid 28 | LEFT JOIN 29 | PG_CATALOG.PG_DESCRIPTION d ON d.objsubid = c.ordinal_position AND d.objoid = c1.oid 30 | WHERE 31 | LOWER(c.table_schema) = LOWER($1) AND LOWER(c.table_name) = LOWER($2);` 32 | 33 | func (RedshiftDialect) BuildDescribeTableQuery(tableID sql.TableIdentifier) (string, []any, error) { 34 | redshiftTableID, err := typing.AssertType[TableIdentifier](tableID) 35 | if err != nil { 36 | return "", nil, err 37 | } 38 | 39 | // This query is a modified fork from: https://gist.github.com/alexanderlz/7302623 40 | return describeTableQuery, []any{redshiftTableID.Schema(), redshiftTableID.Table()}, nil 41 | } 42 | 43 | func 
(RedshiftDialect) BuildAddColumnQuery(tableID sql.TableIdentifier, sqlPart string) string { 44 | return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", tableID.FullyQualifiedName(), sqlPart) 45 | } 46 | 47 | func (RedshiftDialect) BuildDropColumnQuery(tableID sql.TableIdentifier, colName string) string { 48 | return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableID.FullyQualifiedName(), colName) 49 | } 50 | 51 | func (RedshiftDialect) BuildCreateTableQuery(tableID sql.TableIdentifier, _ bool, colSQLParts []string) string { 52 | // Redshift uses the same syntax for temporary and permanent tables. 53 | return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s);", tableID.FullyQualifiedName(), strings.Join(colSQLParts, ",")) 54 | } 55 | 56 | func (RedshiftDialect) BuildDropTableQuery(tableID sql.TableIdentifier) string { 57 | return "DROP TABLE IF EXISTS " + tableID.FullyQualifiedName() 58 | } 59 | 60 | func (RedshiftDialect) BuildTruncateTableQuery(tableID sql.TableIdentifier) string { 61 | return "TRUNCATE TABLE " + tableID.FullyQualifiedName() 62 | } 63 | -------------------------------------------------------------------------------- /clients/redshift/dialect/default.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/artie-labs/transfer/lib/sql" 7 | ) 8 | 9 | func (RedshiftDialect) GetDefaultValueStrategy() sql.DefaultValueStrategy { 10 | return sql.Backfill 11 | } 12 | 13 | func (RedshiftDialect) BuildBackfillQuery(tableID sql.TableIdentifier, escapedColumn string, defaultValue any) string { 14 | return fmt.Sprintf(`UPDATE %s SET %s = %v WHERE %s IS NULL;`, tableID.FullyQualifiedName(), escapedColumn, defaultValue, escapedColumn) 15 | } 16 | -------------------------------------------------------------------------------- /clients/redshift/dialect/default_test.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 
| import ( 4 | "testing" 5 | 6 | "github.com/artie-labs/transfer/lib/typing" 7 | "github.com/artie-labs/transfer/lib/typing/columns" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestRedshiftDialect_BuildBackfillQuery(t *testing.T) { 12 | _dialect := RedshiftDialect{} 13 | 14 | tableID := NewTableIdentifier("{SCHEMA}", "{TABLE}") 15 | col := columns.NewColumn("{COLUMN}", typing.String) 16 | 17 | assert.Equal(t, `UPDATE {SCHEMA}."{table}" SET "{column}" = {DEFAULT_VALUE} WHERE "{column}" IS NULL;`, _dialect.BuildBackfillQuery(tableID, _dialect.QuoteIdentifier(col.Name()), "{DEFAULT_VALUE}")) 18 | } 19 | -------------------------------------------------------------------------------- /clients/redshift/dialect/tableid.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/artie-labs/transfer/lib/sql" 7 | ) 8 | 9 | var _dialect = RedshiftDialect{} 10 | 11 | type TableIdentifier struct { 12 | schema string 13 | table string 14 | disableDropProtection bool 15 | } 16 | 17 | func NewTableIdentifier(schema, table string) TableIdentifier { 18 | return TableIdentifier{schema: schema, table: table} 19 | } 20 | 21 | func (ti TableIdentifier) Schema() string { 22 | return ti.schema 23 | } 24 | 25 | func (ti TableIdentifier) EscapedTable() string { 26 | return _dialect.QuoteIdentifier(ti.table) 27 | } 28 | 29 | func (ti TableIdentifier) Table() string { 30 | return ti.table 31 | } 32 | 33 | func (ti TableIdentifier) WithTable(table string) sql.TableIdentifier { 34 | return NewTableIdentifier(ti.schema, table) 35 | } 36 | 37 | func (ti TableIdentifier) FullyQualifiedName() string { 38 | // Redshift is Postgres compatible, so when establishing a connection, we'll specify a database. 39 | // Thus, we only need to specify schema and table name here. 
40 | return fmt.Sprintf("%s.%s", ti.schema, ti.EscapedTable()) 41 | } 42 | 43 | func (ti TableIdentifier) WithDisableDropProtection(disableDropProtection bool) sql.TableIdentifier { 44 | ti.disableDropProtection = disableDropProtection 45 | return ti 46 | } 47 | 48 | func (ti TableIdentifier) AllowToDrop() bool { 49 | return ti.disableDropProtection 50 | } 51 | -------------------------------------------------------------------------------- /clients/redshift/dialect/tableid_test.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestTableIdentifier_WithTable(t *testing.T) { 10 | tableID := NewTableIdentifier("schema", "foo") 11 | tableID2 := tableID.WithTable("bar") 12 | typedTableID2, ok := tableID2.(TableIdentifier) 13 | assert.True(t, ok) 14 | assert.Equal(t, "schema", typedTableID2.Schema()) 15 | assert.Equal(t, "bar", tableID2.Table()) 16 | } 17 | 18 | func TestTableIdentifier_FullyQualifiedName(t *testing.T) { 19 | // Table name that is not a reserved word: 20 | assert.Equal(t, `schema."foo"`, NewTableIdentifier("schema", "foo").FullyQualifiedName()) 21 | 22 | // Table name that is a reserved word: 23 | assert.Equal(t, `schema."table"`, NewTableIdentifier("schema", "table").FullyQualifiedName()) 24 | } 25 | 26 | func TestTableIdentifier_EscapedTable(t *testing.T) { 27 | // Table name that is not a reserved word: 28 | assert.Equal(t, `"foo"`, NewTableIdentifier("schema", "foo").EscapedTable()) 29 | 30 | // Table name that is a reserved word: 31 | assert.Equal(t, `"table"`, NewTableIdentifier("schema", "table").EscapedTable()) 32 | } 33 | -------------------------------------------------------------------------------- /clients/redshift/redshift_bench_test.go: -------------------------------------------------------------------------------- 1 | package redshift 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "testing" 
7 | 8 | "github.com/artie-labs/transfer/lib/stringutil" 9 | 10 | "github.com/artie-labs/transfer/lib/config/constants" 11 | "github.com/artie-labs/transfer/lib/typing" 12 | "github.com/artie-labs/transfer/lib/typing/columns" 13 | ) 14 | 15 | func BenchmarkMethods(b *testing.B) { 16 | // Random string of length [500, 100,000) 17 | colVal := stringutil.Random(rand.Intn(100000) + 500) // use the same value for both benchmarks 18 | colKind := columns.NewColumn("foo", typing.String) // use the same column kind for both benchmarks 19 | 20 | b.Run("OldMethod", func(b *testing.B) { 21 | for i := 0; i < b.N; i++ { 22 | replaceExceededValuesOld(colVal, colKind) 23 | } 24 | }) 25 | 26 | b.Run("NewMethod", func(b *testing.B) { 27 | for i := 0; i < b.N; i++ { 28 | replaceExceededValuesNew(colVal, colKind) 29 | } 30 | }) 31 | } 32 | 33 | func replaceExceededValuesOld(colVal any, colKind columns.Column) any { 34 | colValString := fmt.Sprint(colVal) 35 | switch colKind.KindDetails.Kind { 36 | case typing.Struct.Kind: 37 | if int32(len(colValString)) > maxStringLength { 38 | return map[string]any{ 39 | "key": constants.ExceededValueMarker, 40 | } 41 | } 42 | case typing.String.Kind: 43 | if int32(len(colValString)) > maxStringLength { 44 | return constants.ExceededValueMarker 45 | } 46 | } 47 | 48 | return colVal 49 | } 50 | 51 | func replaceExceededValuesNew(colVal any, colKind columns.Column) any { 52 | colValString := fmt.Sprint(colVal) 53 | colValBytes := int32(len(colValString)) 54 | switch colKind.KindDetails.Kind { 55 | case typing.Struct.Kind: 56 | if colValBytes > maxStringLength { 57 | return map[string]any{ 58 | "key": constants.ExceededValueMarker, 59 | } 60 | } 61 | case typing.String.Kind: 62 | if colValBytes > maxStringLength { 63 | return constants.ExceededValueMarker 64 | } 65 | } 66 | 67 | return colVal 68 | } 69 | -------------------------------------------------------------------------------- /clients/redshift/redshift_suite_test.go: 
-------------------------------------------------------------------------------- 1 | package redshift 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/artie-labs/transfer/lib/config" 7 | "github.com/artie-labs/transfer/lib/db" 8 | "github.com/artie-labs/transfer/lib/mocks" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/suite" 11 | ) 12 | 13 | type RedshiftTestSuite struct { 14 | suite.Suite 15 | fakeStore *mocks.FakeStore 16 | store *Store 17 | } 18 | 19 | func (r *RedshiftTestSuite) SetupTest() { 20 | cfg := config.Config{ 21 | Redshift: &config.Redshift{}, 22 | } 23 | 24 | r.fakeStore = &mocks.FakeStore{} 25 | store := db.Store(r.fakeStore) 26 | var err error 27 | r.store, err = LoadRedshift(r.T().Context(), cfg, &store) 28 | assert.NoError(r.T(), err) 29 | } 30 | 31 | func TestRedshiftTestSuite(t *testing.T) { 32 | suite.Run(t, new(RedshiftTestSuite)) 33 | } 34 | -------------------------------------------------------------------------------- /clients/redshift/redshift_test.go: -------------------------------------------------------------------------------- 1 | package redshift 2 | 3 | import ( 4 | "strconv" 5 | "strings" 6 | "testing" 7 | "time" 8 | 9 | "github.com/artie-labs/transfer/clients/shared" 10 | "github.com/artie-labs/transfer/lib/config" 11 | "github.com/artie-labs/transfer/lib/kafkalib" 12 | "github.com/artie-labs/transfer/lib/optimization" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func TestTempTableIDWithSuffix(t *testing.T) { 17 | trimTTL := func(tableName string) string { 18 | lastUnderscore := strings.LastIndex(tableName, "_") 19 | assert.GreaterOrEqual(t, lastUnderscore, 0) 20 | epoch, err := strconv.ParseInt(tableName[lastUnderscore+1:len(tableName)-1], 10, 64) 21 | assert.NoError(t, err) 22 | assert.Greater(t, time.Unix(epoch, 0), time.Now().Add(5*time.Hour)) // default TTL is 6 hours from now 23 | return tableName[:lastUnderscore] + string(tableName[len(tableName)-1]) 24 | } 25 | 26 | tableData 
:= optimization.NewTableData(nil, config.Replication, nil, kafkalib.TopicConfig{Database: "db", Schema: "schema"}, "table") 27 | tableID := (&Store{}).IdentifierFor(tableData.TopicConfig().BuildDatabaseAndSchemaPair(), tableData.Name()) 28 | tempTableName := shared.TempTableIDWithSuffix(tableID, "sUfFiX").FullyQualifiedName() 29 | assert.Equal(t, `schema."table___artie_suffix"`, trimTTL(tempTableName)) 30 | } 31 | -------------------------------------------------------------------------------- /clients/s3/s3_test.go: -------------------------------------------------------------------------------- 1 | package s3 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/assert" 10 | 11 | "github.com/artie-labs/transfer/lib/config" 12 | "github.com/artie-labs/transfer/lib/config/constants" 13 | "github.com/artie-labs/transfer/lib/kafkalib" 14 | "github.com/artie-labs/transfer/lib/optimization" 15 | ) 16 | 17 | func TestBuildTemporaryFilePath(t *testing.T) { 18 | ts := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) 19 | fp := buildTemporaryFilePath(&optimization.TableData{LatestCDCTs: ts}) 20 | assert.True(t, strings.HasPrefix(fp, "/tmp/1577836800000_"), fp) 21 | assert.True(t, strings.HasSuffix(fp, ".parquet"), fp) 22 | } 23 | 24 | func TestObjectPrefix(t *testing.T) { 25 | td := optimization.NewTableData(nil, config.Replication, nil, kafkalib.TopicConfig{Database: "db", TableName: "table", Schema: "public"}, "table") 26 | { 27 | // Valid - No Folder 28 | store, err := LoadStore(t.Context(), config.Config{S3: &config.S3Settings{ 29 | Bucket: "bucket", 30 | AwsSecretAccessKey: "foo", 31 | AwsAccessKeyID: "bar", 32 | AwsRegion: "us-east-1", 33 | OutputFormat: constants.ParquetFormat, 34 | }}) 35 | 36 | assert.NoError(t, err) 37 | assert.Equal(t, fmt.Sprintf("db.public.table/date=%s", time.Now().Format(time.DateOnly)), store.ObjectPrefix(td)) 38 | } 39 | { 40 | // Valid - With Folder 41 | store, err := LoadStore(t.Context(), 
config.Config{S3: &config.S3Settings{ 42 | Bucket: "bucket", 43 | AwsSecretAccessKey: "foo", 44 | AwsAccessKeyID: "bar", 45 | AwsRegion: "us-east-1", 46 | FolderName: "foo", 47 | OutputFormat: constants.ParquetFormat, 48 | }}) 49 | 50 | assert.NoError(t, err) 51 | assert.Equal(t, fmt.Sprintf("foo/db.public.table/date=%s", time.Now().Format(time.DateOnly)), store.ObjectPrefix(td)) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /clients/s3/tableid.go: -------------------------------------------------------------------------------- 1 | package s3 2 | 3 | import ( 4 | "cmp" 5 | "strings" 6 | 7 | "github.com/artie-labs/transfer/lib/sql" 8 | ) 9 | 10 | type TableIdentifier struct { 11 | database string 12 | schema string 13 | table string 14 | nameSeparator string 15 | disableDropProtection bool 16 | } 17 | 18 | func NewTableIdentifier(database, schema, table string, nameSeparator string) TableIdentifier { 19 | return TableIdentifier{database: database, schema: schema, table: table, nameSeparator: cmp.Or(nameSeparator, ".")} 20 | } 21 | 22 | func (ti TableIdentifier) Database() string { 23 | return ti.database 24 | } 25 | 26 | func (ti TableIdentifier) Schema() string { 27 | return ti.schema 28 | } 29 | 30 | func (ti TableIdentifier) EscapedTable() string { 31 | // S3 doesn't require escaping 32 | return ti.table 33 | } 34 | 35 | func (ti TableIdentifier) Table() string { 36 | return ti.table 37 | } 38 | 39 | func (ti TableIdentifier) WithTable(table string) sql.TableIdentifier { 40 | return NewTableIdentifier(ti.database, ti.schema, table, ti.nameSeparator) 41 | } 42 | 43 | func (ti TableIdentifier) FullyQualifiedName() string { 44 | return strings.Join([]string{ti.database, ti.schema, ti.EscapedTable()}, ti.nameSeparator) 45 | } 46 | 47 | func (ti TableIdentifier) WithDisableDropProtection(disableDropProtection bool) sql.TableIdentifier { 48 | ti.disableDropProtection = disableDropProtection 49 | return ti 50 | } 51 
| 52 | func (ti TableIdentifier) AllowToDrop() bool { 53 | return ti.disableDropProtection 54 | } 55 | -------------------------------------------------------------------------------- /clients/s3/tableid_test.go: -------------------------------------------------------------------------------- 1 | package s3 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestTableIdentifier_WithTable(t *testing.T) { 10 | tableID := NewTableIdentifier("database", "schema", "foo", "") 11 | tableID2 := tableID.WithTable("bar") 12 | typedTableID2, ok := tableID2.(TableIdentifier) 13 | assert.True(t, ok) 14 | assert.Equal(t, "database", typedTableID2.Database()) 15 | assert.Equal(t, "schema", typedTableID2.Schema()) 16 | assert.Equal(t, "bar", tableID2.Table()) 17 | } 18 | 19 | func TestTableIdentifier_FullyQualifiedName(t *testing.T) { 20 | { 21 | // S3 doesn't escape the table name. 22 | tableID := NewTableIdentifier("database", "schema", "table", "") 23 | assert.Equal(t, "database.schema.table", tableID.FullyQualifiedName()) 24 | } 25 | { 26 | // Separator via `/` 27 | tableID := NewTableIdentifier("database", "schema", "table", "/") 28 | assert.Equal(t, "database/schema/table", tableID.FullyQualifiedName()) 29 | } 30 | } 31 | 32 | func TestTableIdentifier_EscapedTable(t *testing.T) { 33 | // S3 doesn't escape the table name. 
34 | tableID := NewTableIdentifier("database", "schema", "table", "") 35 | assert.Equal(t, "table", tableID.EscapedTable()) 36 | } 37 | -------------------------------------------------------------------------------- /clients/shared/append.go: -------------------------------------------------------------------------------- 1 | package shared 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/artie-labs/transfer/lib/destination" 8 | "github.com/artie-labs/transfer/lib/destination/types" 9 | "github.com/artie-labs/transfer/lib/optimization" 10 | "github.com/artie-labs/transfer/lib/typing/columns" 11 | ) 12 | 13 | func Append(ctx context.Context, dest destination.Destination, tableData *optimization.TableData, opts types.AdditionalSettings) error { 14 | if tableData.ShouldSkipUpdate() { 15 | return nil 16 | } 17 | 18 | tableID := dest.IdentifierFor(tableData.TopicConfig().BuildDatabaseAndSchemaPair(), tableData.Name()) 19 | tableConfig, err := dest.GetTableConfig(tableID, tableData.TopicConfig().DropDeletedColumns) 20 | if err != nil { 21 | return fmt.Errorf("failed to get table config: %w", err) 22 | } 23 | 24 | // We don't care about srcKeysMissing because we don't drop columns when we append. 
25 | _, targetKeysMissing := columns.DiffAndFilter( 26 | tableData.ReadOnlyInMemoryCols().GetColumns(), 27 | tableConfig.GetColumns(), 28 | tableData.BuildColumnsToKeep(), 29 | ) 30 | 31 | if tableConfig.CreateTable() { 32 | if err = CreateTable(ctx, dest, tableData.Mode(), tableConfig, opts.ColumnSettings, tableID, false, targetKeysMissing); err != nil { 33 | return fmt.Errorf("failed to create table: %w", err) 34 | } 35 | } else { 36 | if err = AlterTableAddColumns(ctx, dest, tableConfig, opts.ColumnSettings, tableID, targetKeysMissing); err != nil { 37 | return fmt.Errorf("failed to alter table: %w", err) 38 | } 39 | } 40 | 41 | if err = tableData.MergeColumnsFromDestination(tableConfig.GetColumns()...); err != nil { 42 | return fmt.Errorf("failed to merge columns from destination: %w", err) 43 | } 44 | 45 | tempTableID := tableID 46 | if opts.UseTempTable { 47 | // Override tableID with tempTableID if we're using a temporary table 48 | tempTableID = opts.TempTableID 49 | } 50 | 51 | return dest.PrepareTemporaryTable( 52 | ctx, 53 | tableData, 54 | tableConfig, 55 | tempTableID, 56 | tableID, 57 | opts, 58 | opts.UseTempTable, 59 | ) 60 | } 61 | -------------------------------------------------------------------------------- /clients/shared/sweep.go: -------------------------------------------------------------------------------- 1 | package shared 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log/slog" 7 | 8 | "github.com/artie-labs/transfer/lib/destination" 9 | "github.com/artie-labs/transfer/lib/destination/ddl" 10 | "github.com/artie-labs/transfer/lib/kafkalib" 11 | ) 12 | 13 | type GetQueryFunc func(dbName string, schemaName string) (string, []any) 14 | 15 | func Sweep(ctx context.Context, dest destination.Destination, topicConfigs []*kafkalib.TopicConfig, getQueryFunc GetQueryFunc) error { 16 | slog.Info("Looking to see if there are any dangling artie temporary tables to delete...") 17 | for _, dbAndSchemaPair := range 
kafkalib.GetUniqueDatabaseAndSchemaPairs(topicConfigs) { 18 | query, args := getQueryFunc(dbAndSchemaPair.Database, dbAndSchemaPair.Schema) 19 | rows, err := dest.QueryContext(ctx, query, args...) 20 | if err != nil { 21 | return err 22 | } 23 | 24 | for rows.Next() { 25 | var tableSchema, tableName string 26 | if err = rows.Scan(&tableSchema, &tableName); err != nil { 27 | return err 28 | } 29 | 30 | if ddl.ShouldDeleteFromName(tableName) { 31 | if err = ddl.DropTemporaryTable(ctx, dest, dest.IdentifierFor(dbAndSchemaPair, tableName), true); err != nil { 32 | return err 33 | } 34 | } 35 | } 36 | 37 | if err = rows.Err(); err != nil { 38 | return fmt.Errorf("failed to iterate over rows: %w", err) 39 | } 40 | } 41 | 42 | return nil 43 | } 44 | -------------------------------------------------------------------------------- /clients/shared/table_config_test.go: -------------------------------------------------------------------------------- 1 | package shared 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/artie-labs/transfer/lib/destination/types" 8 | "github.com/artie-labs/transfer/lib/mocks" 9 | "github.com/artie-labs/transfer/lib/typing" 10 | "github.com/artie-labs/transfer/lib/typing/columns" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestGetTableConfig(t *testing.T) { 15 | // Return early because table is found in configMap. 
16 | var cols []columns.Column 17 | for i := range 100 { 18 | cols = append(cols, columns.NewColumn(fmt.Sprintf("col-%v", i), typing.Invalid)) 19 | } 20 | 21 | cm := &types.DestinationTableConfigMap{} 22 | fakeTableID := &mocks.FakeTableIdentifier{} 23 | fakeTableID.FullyQualifiedNameReturns("dusty_the_mini_aussie") 24 | 25 | tableCfg := types.NewDestinationTableConfig(cols, false) 26 | cm.AddTable(fakeTableID, tableCfg) 27 | 28 | actualTableCfg, err := GetTableCfgArgs{ 29 | Destination: &mocks.FakeDestination{}, 30 | TableID: fakeTableID, 31 | ConfigMap: cm, 32 | }.GetTableConfig() 33 | 34 | assert.NoError(t, err) 35 | assert.Equal(t, tableCfg, actualTableCfg) 36 | } 37 | -------------------------------------------------------------------------------- /clients/snowflake/dialect/default.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import "github.com/artie-labs/transfer/lib/sql" 4 | 5 | func (SnowflakeDialect) GetDefaultValueStrategy() sql.DefaultValueStrategy { 6 | return sql.Backfill 7 | } 8 | -------------------------------------------------------------------------------- /clients/snowflake/dialect/tableid.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/artie-labs/transfer/lib/sql" 7 | ) 8 | 9 | var _dialect = SnowflakeDialect{} 10 | 11 | type TableIdentifier struct { 12 | database string 13 | schema string 14 | table string 15 | disableDropProtection bool 16 | } 17 | 18 | func NewTableIdentifier(database, schema, table string) TableIdentifier { 19 | return TableIdentifier{ 20 | database: database, 21 | schema: schema, 22 | table: table, 23 | } 24 | } 25 | 26 | func (ti TableIdentifier) Database() string { 27 | return ti.database 28 | } 29 | 30 | func (ti TableIdentifier) Schema() string { 31 | return ti.schema 32 | } 33 | 34 | func (ti TableIdentifier) EscapedTable() string { 35 | return 
_dialect.QuoteIdentifier(ti.table) 36 | } 37 | 38 | func (ti TableIdentifier) Table() string { 39 | return ti.table 40 | } 41 | 42 | func (ti TableIdentifier) WithTable(table string) sql.TableIdentifier { 43 | return NewTableIdentifier(ti.database, ti.schema, table) 44 | } 45 | 46 | func (ti TableIdentifier) FullyQualifiedName() string { 47 | return fmt.Sprintf("%s.%s.%s", _dialect.QuoteIdentifier(ti.database), _dialect.QuoteIdentifier(ti.schema), ti.EscapedTable()) 48 | } 49 | 50 | func (ti TableIdentifier) WithDisableDropProtection(disableDropProtection bool) sql.TableIdentifier { 51 | ti.disableDropProtection = disableDropProtection 52 | return ti 53 | } 54 | 55 | func (ti TableIdentifier) AllowToDrop() bool { 56 | return ti.disableDropProtection 57 | } 58 | -------------------------------------------------------------------------------- /clients/snowflake/dialect/tableid_test.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestTableIdentifier_WithTable(t *testing.T) { 10 | tableID := NewTableIdentifier("database", "schema", "foo") 11 | tableID2 := tableID.WithTable("bar") 12 | typedTableID2, ok := tableID2.(TableIdentifier) 13 | assert.True(t, ok) 14 | assert.Equal(t, "database", typedTableID2.Database()) 15 | assert.Equal(t, "schema", typedTableID2.Schema()) 16 | assert.Equal(t, "bar", tableID2.Table()) 17 | } 18 | 19 | func TestTableIdentifier_FullyQualifiedName(t *testing.T) { 20 | // Table name that is not a reserved word: 21 | assert.Equal(t, `"DATABASE"."SCHEMA"."FOO"`, NewTableIdentifier("database", "schema", "foo").FullyQualifiedName()) 22 | 23 | // Table name that is a reserved word: 24 | assert.Equal(t, `"DATABASE"."SCHEMA"."TABLE"`, NewTableIdentifier("database", "schema", "table").FullyQualifiedName()) 25 | } 26 | 27 | func TestTableIdentifier_EscapedTable(t *testing.T) { 28 | // Table name that is 
not a reserved word: 29 | assert.Equal(t, `"FOO"`, NewTableIdentifier("database", "schema", "foo").EscapedTable()) 30 | 31 | // Table name that is a reserved word: 32 | assert.Equal(t, `"TABLE"`, NewTableIdentifier("database", "schema", "table").EscapedTable()) 33 | } 34 | -------------------------------------------------------------------------------- /clients/snowflake/snowflake_suite_test.go: -------------------------------------------------------------------------------- 1 | package snowflake 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/DATA-DOG/go-sqlmock" 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/suite" 9 | 10 | "github.com/artie-labs/transfer/lib/config" 11 | "github.com/artie-labs/transfer/lib/db" 12 | "github.com/artie-labs/transfer/lib/typing" 13 | ) 14 | 15 | type SnowflakeTestSuite struct { 16 | suite.Suite 17 | mockDB sqlmock.Sqlmock 18 | stageStore *Store 19 | } 20 | 21 | func (s *SnowflakeTestSuite) SetupTest() { 22 | s.ResetStore() 23 | } 24 | 25 | func (s *SnowflakeTestSuite) ResetStore() { 26 | _db, mock, err := sqlmock.New() 27 | assert.NoError(s.T(), err) 28 | 29 | s.mockDB = mock 30 | s.stageStore, err = LoadSnowflake(s.T().Context(), config.Config{Snowflake: &config.Snowflake{}}, typing.ToPtr(db.NewStoreWrapperForTest(_db))) 31 | assert.NoError(s.T(), err) 32 | } 33 | 34 | func TestSnowflakeTestSuite(t *testing.T) { 35 | suite.Run(t, new(SnowflakeTestSuite)) 36 | } 37 | -------------------------------------------------------------------------------- /clients/snowflake/util.go: -------------------------------------------------------------------------------- 1 | package snowflake 2 | 3 | import ( 4 | "github.com/artie-labs/transfer/lib/sql" 5 | ) 6 | 7 | // addPrefixToTableName will take a [sql.TableIdentifier] and add a prefix in front of the table. 8 | // This is necessary for `PUT` commands. 
9 | func addPrefixToTableName(tableID sql.TableIdentifier, prefix string) string { 10 | return tableID.WithTable(prefix + tableID.Table()).FullyQualifiedName() 11 | } 12 | -------------------------------------------------------------------------------- /clients/snowflake/util_test.go: -------------------------------------------------------------------------------- 1 | package snowflake 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | 8 | "github.com/artie-labs/transfer/clients/snowflake/dialect" 9 | ) 10 | 11 | func TestAddPrefixToTableName(t *testing.T) { 12 | const prefix = "%" 13 | { 14 | // Database, schema and table name 15 | assert.Equal(t, `"DATABASE"."SCHEMA"."%TABLENAME"`, addPrefixToTableName(dialect.NewTableIdentifier("database", "schema", "tableName"), prefix)) 16 | } 17 | { 18 | // Table name 19 | assert.Equal(t, `"".""."%ORDERS"`, addPrefixToTableName(dialect.NewTableIdentifier("", "", "orders"), prefix)) 20 | } 21 | { 22 | // Schema and table name 23 | assert.Equal(t, `""."PUBLIC"."%ORDERS"`, addPrefixToTableName(dialect.NewTableIdentifier("", "public", "orders"), prefix)) 24 | } 25 | { 26 | // Database and table name 27 | assert.Equal(t, `"DB".""."%TABLENAME"`, addPrefixToTableName(dialect.NewTableIdentifier("db", "", "tableName"), prefix)) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /clients/snowflake/writes.go: -------------------------------------------------------------------------------- 1 | package snowflake 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/artie-labs/transfer/clients/shared" 8 | "github.com/artie-labs/transfer/lib/config/constants" 9 | "github.com/artie-labs/transfer/lib/destination/types" 10 | "github.com/artie-labs/transfer/lib/optimization" 11 | "github.com/artie-labs/transfer/lib/sql" 12 | "github.com/artie-labs/transfer/lib/typing" 13 | "github.com/artie-labs/transfer/lib/typing/columns" 14 | ) 15 | 16 | func (s *Store) 
Append(ctx context.Context, tableData *optimization.TableData, _ bool) error { 17 | // TODO: For history mode - in the future, we could also have a separate stage name for history mode so we can enable parallel processing. 18 | return shared.Append(ctx, s, tableData, types.AdditionalSettings{ 19 | AdditionalCopyClause: fmt.Sprintf(`FILE_FORMAT = (TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='%s' EMPTY_FIELD_AS_NULL=FALSE) PURGE = TRUE`, constants.NullValuePlaceholder), 20 | }) 21 | } 22 | 23 | func (s *Store) additionalEqualityStrings(tableData *optimization.TableData) []string { 24 | cols := make([]columns.Column, len(tableData.TopicConfig().AdditionalMergePredicates)) 25 | for i, additionalMergePredicate := range tableData.TopicConfig().AdditionalMergePredicates { 26 | cols[i] = columns.NewColumn(additionalMergePredicate.PartitionField, typing.Invalid) 27 | } 28 | return sql.BuildColumnComparisons(cols, constants.TargetAlias, constants.StagingAlias, sql.Equal, s.Dialect()) 29 | } 30 | 31 | func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) (bool, error) { 32 | mergeOpts := types.MergeOpts{ 33 | AdditionalEqualityStrings: s.additionalEqualityStrings(tableData), 34 | } 35 | 36 | if tableData.MultiStepMergeSettings().Enabled { 37 | return shared.MultiStepMerge(ctx, s, tableData, mergeOpts) 38 | } 39 | 40 | if err := shared.Merge(ctx, s, tableData, mergeOpts); err != nil { 41 | return false, fmt.Errorf("failed to merge: %w", err) 42 | } 43 | 44 | return true, nil 45 | } 46 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | There are currently 3 examples that work end-to-end.
4 | * [MongoDB](https://github.com/artie-labs/transfer/tree/master/examples/mongodb) 5 | * [Postgres](https://github.com/artie-labs/transfer/tree/master/examples/postgres) 6 | * [MySQL](https://github.com/artie-labs/transfer/tree/master/examples/mysql) 7 | 8 | Go into each folder to see the exact instructions on how to run it. 9 | -------------------------------------------------------------------------------- /examples/mongodb/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM artielabs/transfer:latest 2 | 3 | COPY . . 4 | 5 | CMD ["/transfer", "--config", "config.yaml"] 6 | -------------------------------------------------------------------------------- /examples/mongodb/README.md: -------------------------------------------------------------------------------- 1 | # MongoDB Example 2 | 3 | ## Running 4 | 5 | To run this, you'll need to install Docker. We will be running 5 images. 6 | 7 | 1. zookeeper 8 | 2. kafka 9 | 3. MongoDB 10 | 4. Debezium (pulling the data from Mongo and publishing to Kafka) 11 | 5. Transfer (pulling Kafka and writing against a test DB) 12 | 13 | _Note: Snowflake does not have a development Docker image, so the Mock DB will just output the function calls_ 14 | 15 | ### Initial set up 16 | ```sh 17 | docker-compose build 18 | 19 | docker-compose up 20 | 21 | ``` 22 | 23 | ### Registering the connector 24 | ```sh 25 | curl -i -X POST -H "Accept:application/json" -H "Content-Type:application/json" http://localhost:8083/connectors/ -d @register-mongodb-connector.json 26 | 27 | # Now initiate the replica set and insert some dummy data. 
28 | docker-compose -f docker-compose.yaml exec mongodb bash -c '/usr/local/bin/init-inventory.sh' 29 | 30 | # Now, if you want to connect to the Mongo shell and insert more data, go right ahead 31 | docker-compose -f docker-compose.yaml exec mongodb bash -c 'mongo -u $MONGODB_USER -p $MONGODB_PASSWORD --authenticationDatabase admin inventory' 32 | db.customers.insert([ 33 | { _id : NumberLong("1020"), first_name : 'Robin', 34 | last_name : 'Tang', email : 'robin@example.com', unique_id : UUID(), 35 | test_bool_false: false, test_bool_true: true, new_id: ObjectId(), 36 | test_decimal: NumberDecimal("13.37"), test_int: NumberInt("1337"), 37 | test_decimal_2: 13.37, test_list: [1, 2, 3, 4, "hello"], test_null: null, test_ts: Timestamp(42, 1), test_nested_object: {a: { b: { c: "hello"}}}} 38 | ]); 39 | ``` 40 | -------------------------------------------------------------------------------- /examples/mongodb/config.yaml: -------------------------------------------------------------------------------- 1 | outputSource: test 2 | 3 | kafka: 4 | bootstrapServer: kafka:9092 5 | groupID: abc1234 6 | topicConfigs: 7 | - db: customers_mongo 8 | tableName: customers 9 | schema: public 10 | topic: "dbserver1.inventory.customers" 11 | cdcFormat: debezium.mongodb 12 | # Turn this on if you plan to use JSON converter (see connect-distributed.properties) 13 | cdcKeyFormat: org.apache.kafka.connect.json.JsonConverter 14 | 15 | snowflake: 16 | account: 123 17 | username: foo 18 | password: bar 19 | warehouse: dwh 20 | region: us-east-2.aws 21 | 22 | telemetry: 23 | metrics: 24 | provider: datadog 25 | settings: 26 | tags: 27 | - env:production 28 | namespace: "transfer." 
29 | addr: "127.0.0.1:8125" 30 | -------------------------------------------------------------------------------- /examples/mongodb/connect/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM quay.io/debezium/connect:2.0 2 | 3 | COPY docker-entrypoint.sh /docker-entrypoint.sh 4 | COPY connect-distributed.properties /kafka/config/ 5 | 6 | ENV KAFKA_LOG4J_OPTS=-Dlog4j.configuration=file:/kafka/config/log4j.properties 7 | ENV CONFIG_STORAGE_TOPIC=my_connect_configs 8 | ENV OFFSET_STORAGE_TOPIC=my_connect_offsets 9 | ENV STATUS_STORAGE_TOPIC=my_connect_statuses 10 | 11 | ENTRYPOINT ["/docker-entrypoint.sh", "start"] 12 | -------------------------------------------------------------------------------- /examples/mongodb/connect/connect-distributed.properties: -------------------------------------------------------------------------------- 1 | # A list of host/port pairs to use for establishing the initial connection to the Kafka cluster. 2 | bootstrap.servers=kafka:9092 3 | 4 | # unique name for the cluster, used in forming the Connect cluster group. Note that this must not conflict with consumer group IDs 5 | group.id=connect-cluster 6 | 7 | # The converters specify the format of data in Kafka and how to translate it into Connect data. 
Every Connect user will 8 | # need to configure these based on the format they want their data in when loaded from or stored into Kafka 9 | # key.converter=org.apache.kafka.connect.storage.StringConverter 10 | key.converter=org.apache.kafka.connect.json.JsonConverter 11 | value.converter=org.apache.kafka.connect.json.JsonConverter 12 | 13 | key.converter.schemas.enable=false 14 | value.converter.schemas.enable=true 15 | 16 | offset.storage.topic=connect-offsets 17 | offset.storage.replication.factor=1 18 | 19 | config.storage.topic=connect-configs 20 | config.storage.replication.factor=1 21 | 22 | status.storage.topic=connect-status 23 | status.storage.replication.factor=1 24 | 25 | # Flush much faster than normal, which is useful for testing/debugging 26 | offset.flush.interval.ms=10000 27 | 28 | # Load the connnectors 29 | plugin.path=/kafka/connect 30 | -------------------------------------------------------------------------------- /examples/mongodb/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: '2' 2 | services: 3 | zookeeper: 4 | image: quay.io/debezium/zookeeper:2.0 5 | ports: 6 | - 2181:2181 7 | - 2888:2888 8 | - 3888:3888 9 | kafka: 10 | image: quay.io/debezium/kafka:2.0 11 | ports: 12 | - 9092:9092 13 | links: 14 | - zookeeper 15 | environment: 16 | - ZOOKEEPER_CONNECT=zookeeper:2181 17 | mongodb: 18 | image: quay.io/debezium/example-mongodb:2.0 19 | hostname: mongodb 20 | ports: 21 | - 27017:27017 22 | environment: 23 | - MONGODB_USER=debezium 24 | - MONGODB_PASSWORD=dbz 25 | connect: 26 | build: 27 | context: ./connect 28 | dockerfile: Dockerfile 29 | ports: 30 | - 8083:8083 31 | links: 32 | - kafka 33 | - mongodb 34 | environment: 35 | - BOOTSTRAP_SERVERS=kafka:9092 36 | - GROUP_ID=1 37 | - CONFIG_STORAGE_TOPIC=my_connect_configs 38 | - OFFSET_STORAGE_TOPIC=my_connect_offsets 39 | - STATUS_STORAGE_TOPIC=my_connect_statuses 40 | transfer: 41 | build: 42 | context: . 
43 | dockerfile: Dockerfile 44 | links: 45 | - kafka 46 | 47 | -------------------------------------------------------------------------------- /examples/mongodb/register-mongodb-connector.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "mongodb-connector", 3 | "config": { 4 | "connector.class" : "io.debezium.connector.mongodb.MongoDbConnector", 5 | "tasks.max" : "1", 6 | "mongodb.hosts" : "rs0/mongodb:27017", 7 | "topic.prefix" : "dbserver1", 8 | "mongodb.user" : "debezium", 9 | "mongodb.password" : "dbz" 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /examples/mysql/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM artielabs/transfer:latest 2 | 3 | COPY . . 4 | 5 | CMD ["/transfer", "--config", "config.yaml"] 6 | -------------------------------------------------------------------------------- /examples/mysql/README.md: -------------------------------------------------------------------------------- 1 | # MySQL example 2 | 3 | This example does the following: 4 | 1. Run Debezium server 5 | 2. Runs MySQL image 6 | 3. Runs Transfer which outputs the merge commands to stdout 7 | 8 | # Running 9 | ``` 10 | docker-compose build 11 | docker-compose up 12 | ``` 13 | 14 | # Connecting to MySQL 15 | Once you have docker-compose running, the MySQL instance can be reached by the following command. 16 | 17 | ```bash 18 | mysql -h 0.0.0.0 -u mysqluser -p 19 | # Password is mysqlpw 20 | ``` 21 | 22 | ```sql 23 | UPDATE inventory.customers SET first_name = 'Artie' where id = 1001; 24 | -- Do any DML and DDL you would like. 
25 | -- And it'll show up in the console :) 26 | ``` 27 | -------------------------------------------------------------------------------- /examples/mysql/application.properties: -------------------------------------------------------------------------------- 1 | # Offset storage 2 | debezium.source.offset.storage.file.filename=/tmp/foo 3 | debezium.source.offset.flush.interval.ms=0 4 | 5 | # Kafka setup. 6 | debezium.sink.type=kafka 7 | debezium.sink.kafka.producer.bootstrap.servers=kafka:9092 8 | debezium.sink.kafka.producer.group.id=connect-cluster 9 | debezium.sink.kafka.producer.key.converter=org.apache.kafka.connect.storage.StringConverter 10 | debezium.sink.kafka.producer.value.converter=org.apache.kafka.connect.json.JsonConverter 11 | debezium.sink.kafka.producer.key.converter.schemas.enable=false 12 | debezium.sink.kafka.producer.value.converter.schemas.enable=true 13 | debezium.sink.kafka.producer.key.serializer=org.apache.kafka.common.serialization.StringSerializer 14 | debezium.sink.kafka.producer.key.serializer.schemas.enable=false 15 | debezium.sink.kafka.producer.value.serializer=org.apache.kafka.common.serialization.StringSerializer 16 | 17 | # MySQL (Local) 18 | debezium.source.connector.class=io.debezium.connector.mysql.MySqlConnector 19 | debezium.source.database.hostname=mysql 20 | debezium.source.database.port=3306 21 | debezium.source.database.user=debezium 22 | debezium.source.database.password=dbz 23 | debezium.source.database.server.id=1234 24 | debezium.source.topic.prefix=mysql1 25 | debezium.source.decimal.handling.mode=double 26 | 27 | debezium.source.schema.history.internal.kafka.bootstrap.servers=kafka:9092 28 | debezium.source.schema.history.internal.kafka.topic=schema-changes.inventory 29 | -------------------------------------------------------------------------------- /examples/mysql/config.yaml: -------------------------------------------------------------------------------- 1 | outputSource: test 2 | 3 | kafka: 4 | 
bootstrapServer: kafka:9092 5 | groupID: abc1234 6 | topicConfigs: 7 | - db: customers 8 | tableName: customers 9 | schema: public 10 | topic: "mysql1.inventory.customers" 11 | cdcFormat: debezium.mysql 12 | # cdcKeyFormat: org.apache.kafka.connect.json.JsonConverter 13 | # If you turn this on, make sure to check connect-distributed.properties for key.converter 14 | # cdcKeyFormat: org.apache.kafka.connect.storage.StringConverter 15 | cdcKeyFormat: org.apache.kafka.connect.json.JsonConverter 16 | dropDeletedColumns: true 17 | softDelete: false 18 | -------------------------------------------------------------------------------- /examples/mysql/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: '3.9' 2 | services: 3 | zookeeper: 4 | image: quay.io/debezium/zookeeper:2.0 5 | ports: 6 | - 2181:2181 7 | - 2888:2888 8 | - 3888:3888 9 | kafka: 10 | image: quay.io/debezium/kafka:2.0 11 | ports: 12 | - 9092:9092 13 | - 29092:29092 14 | links: 15 | - zookeeper 16 | environment: 17 | KAFKA_LISTENERS: EXTERNAL_SAME_HOST://:29092,INTERNAL://:9092 18 | KAFKA_ADVERTISED_LISTENERS: INTERNAL://kafka:9092,EXTERNAL_SAME_HOST://localhost:29092 19 | KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: INTERNAL:PLAINTEXT,EXTERNAL_SAME_HOST:PLAINTEXT 20 | KAFKA_INTER_BROKER_LISTENER_NAME: INTERNAL 21 | ZOOKEEPER_CONNECT: zookeeper:2181 22 | mysql: 23 | image: quay.io/debezium/example-mysql:2.0 24 | ports: 25 | - 3306:3306 26 | environment: 27 | - MYSQL_ROOT_PASSWORD=debezium 28 | - MYSQL_USER=mysqluser 29 | - MYSQL_PASSWORD=mysqlpw 30 | debezium-server: 31 | image: quay.io/debezium/server:2.0 32 | container_name: debezium-server 33 | command: sh -c "sleep 15 && /debezium/run.sh" 34 | links: 35 | - kafka 36 | - mysql 37 | ports: 38 | - 8080:8080 39 | volumes: 40 | - ./application.properties:/debezium/conf/application.properties 41 | depends_on: 42 | - kafka 43 | - mysql 44 | transfer: 45 | build: 46 | context: . 
47 | dockerfile: Dockerfile 48 | links: 49 | - kafka 50 | - mysql 51 | -------------------------------------------------------------------------------- /examples/postgres/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM artielabs/transfer:latest 2 | 3 | COPY . . 4 | 5 | CMD ["/transfer", "--config", "config.yaml"] 6 | -------------------------------------------------------------------------------- /examples/postgres/README.md: -------------------------------------------------------------------------------- 1 | # Postgres Example 2 | 3 | ## Running 4 | 5 | To run this, you'll need to install Docker. We will be running 5 images. 6 | 7 | 1. zookeeper 8 | 2. kafka 9 | 3. Postgres 10 | 4. Debezium (pulling the data from Postgres and publishing to Kafka) 11 | 5. Transfer (pulling Kafka and writing against a test DB) 12 | 13 | _Note: Snowflake does not have a development Docker image, so the Mock DB will just output the function calls_ 14 | 15 | ### Initial set up 16 | ```sh 17 | docker-compose build 18 | 19 | docker-compose up 20 | 21 | ``` 22 | 23 | ### Registering the connector 24 | ```sh 25 | curl -i -X POST -H "Accept:application/json" -H "Content-Type:application/json" http://localhost:8083/connectors/ -d @register-postgres-connector.json 26 | 27 | ## Play around within the Postgres server (insert, update, delete) will now all work. 
28 | docker-compose -f docker-compose.yaml exec postgres env PGOPTIONS="--search_path=inventory" bash -c 'psql -U $POSTGRES_USER postgres' 29 | 30 | ``` 31 | -------------------------------------------------------------------------------- /examples/postgres/config.yaml: -------------------------------------------------------------------------------- 1 | outputSource: test 2 | 3 | kafka: 4 | bootstrapServer: kafka:9092 5 | groupID: abc1234 6 | topicConfigs: 7 | - db: customers 8 | tableName: customers 9 | schema: public 10 | topic: "dbserver1.inventory.customers" 11 | cdcFormat: debezium.postgres.wal2json 12 | # cdcKeyFormat: org.apache.kafka.connect.json.JsonConverter 13 | # If you turn this on, make sure to check connect-distributed.properties for key.converter 14 | cdcKeyFormat: org.apache.kafka.connect.storage.StringConverter 15 | 16 | 17 | snowflake: 18 | account: 123 19 | username: foo 20 | password: bar 21 | warehouse: dwh 22 | region: us-east-2.aws 23 | 24 | telemetry: 25 | metrics: 26 | provider: datadog 27 | settings: 28 | tags: 29 | - env:production 30 | namespace: "transfer." 
31 | addr: "127.0.0.1:8125" 32 | -------------------------------------------------------------------------------- /examples/postgres/connect/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM quay.io/debezium/connect:2.0 2 | 3 | COPY docker-entrypoint.sh /docker-entrypoint.sh 4 | COPY connect-distributed.properties /kafka/config/ 5 | 6 | ENV KAFKA_LOG4J_OPTS=-Dlog4j.configuration=file:/kafka/config/log4j.properties 7 | ENV CONFIG_STORAGE_TOPIC=my_connect_configs 8 | ENV OFFSET_STORAGE_TOPIC=my_connect_offsets 9 | ENV STATUS_STORAGE_TOPIC=my_connect_statuses 10 | 11 | ENTRYPOINT ["/docker-entrypoint.sh", "start"] 12 | -------------------------------------------------------------------------------- /examples/postgres/connect/connect-distributed.properties: -------------------------------------------------------------------------------- 1 | # A list of host/port pairs to use for establishing the initial connection to the Kafka cluster. 2 | bootstrap.servers=kafka:9092 3 | 4 | # unique name for the cluster, used in forming the Connect cluster group. Note that this must not conflict with consumer group IDs 5 | group.id=connect-cluster 6 | 7 | # The converters specify the format of data in Kafka and how to translate it into Connect data. 
Every Connect user will 8 | # need to configure these based on the format they want their data in when loaded from or stored into Kafka 9 | # key.converter=org.apache.kafka.connect.json.JsonConverter 10 | key.converter=org.apache.kafka.connect.storage.StringConverter 11 | value.converter=org.apache.kafka.connect.json.JsonConverter 12 | 13 | key.converter.schemas.enable=false 14 | value.converter.schemas.enable=true 15 | 16 | offset.storage.topic=connect-offsets 17 | offset.storage.replication.factor=1 18 | 19 | config.storage.topic=connect-configs 20 | config.storage.replication.factor=1 21 | 22 | status.storage.topic=connect-status 23 | status.storage.replication.factor=1 24 | 25 | offset.flush.interval.ms=10000 26 | 27 | # Load the connnectors 28 | plugin.path=/kafka/connect 29 | -------------------------------------------------------------------------------- /examples/postgres/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: '2' 2 | services: 3 | zookeeper: 4 | image: quay.io/debezium/zookeeper:2.0 5 | ports: 6 | - 2181:2181 7 | - 2888:2888 8 | - 3888:3888 9 | kafka: 10 | image: quay.io/debezium/kafka:2.0 11 | ports: 12 | - 9092:9092 13 | links: 14 | - zookeeper 15 | environment: 16 | - ZOOKEEPER_CONNECT=zookeeper:2181 17 | postgres: 18 | image: quay.io/debezium/example-postgres:2.0 19 | ports: 20 | - 5432:5432 21 | environment: 22 | - POSTGRES_USER=postgres 23 | - POSTGRES_PASSWORD=postgres 24 | connect: 25 | build: 26 | context: ./connect 27 | dockerfile: Dockerfile 28 | ports: 29 | - 8083:8083 30 | links: 31 | - kafka 32 | - postgres 33 | environment: 34 | - BOOTSTRAP_SERVERS=kafka:9092 35 | - GROUP_ID=1 36 | - CONFIG_STORAGE_TOPIC=my_connect_configs 37 | - OFFSET_STORAGE_TOPIC=my_connect_offsets 38 | - STATUS_STORAGE_TOPIC=my_connect_statuses 39 | transfer: 40 | build: 41 | context: . 
42 | dockerfile: Dockerfile 43 | links: 44 | - kafka 45 | 46 | -------------------------------------------------------------------------------- /examples/postgres/register-postgres-connector.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "postgres-connector", 3 | "config": { 4 | "connector.class": "io.debezium.connector.postgresql.PostgresConnector", 5 | "tasks.max": "1", 6 | "database.hostname": "postgres", 7 | "database.port": "5432", 8 | "database.user": "postgres", 9 | "database.password": "postgres", 10 | "database.dbname" : "postgres", 11 | "topic.prefix": "dbserver1", 12 | "schema.include.list": "inventory", 13 | "tombstones.on.delete": "false", 14 | "decimal.handling.mode": "double", 15 | "after.state.only": "false" 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /goreleaser.dockerfile: -------------------------------------------------------------------------------- 1 | FROM --platform=linux/amd64 alpine:3.21 2 | RUN apk add --no-cache tzdata 3 | COPY transfer /transfer 4 | ENTRYPOINT ["/transfer"] 5 | -------------------------------------------------------------------------------- /integration_tests/parquet/requirements.txt: -------------------------------------------------------------------------------- 1 | pandas>=2.0.0 2 | pyarrow>=14.0.0 3 | -------------------------------------------------------------------------------- /integration_tests/shared/destination.go: -------------------------------------------------------------------------------- 1 | package shared 2 | 3 | import "fmt" 4 | 5 | func (tf *TestFramework) verifyRowCountDestination(expected int) error { 6 | rows, err := tf.dest.Query(fmt.Sprintf("SELECT COUNT(*) FROM %s", tf.tableID.FullyQualifiedName())) 7 | if err != nil { 8 | return fmt.Errorf("failed to query table: %w", err) 9 | } 10 | 11 | var count int 12 | if rows.Next() { 13 | if err := rows.Scan(&count); err != nil { 14 | 
return fmt.Errorf("failed to scan count: %w", err) 15 | } 16 | } 17 | 18 | if err := rows.Err(); err != nil { 19 | return fmt.Errorf("failed to get rows: %w", err) 20 | } 21 | 22 | if count != expected { 23 | return fmt.Errorf("unexpected row count: expected %d, got %d", expected, count) 24 | } 25 | 26 | return nil 27 | } 28 | 29 | func (tf *TestFramework) verifyDataContentDestination(rowCount int) error { 30 | baseQuery := fmt.Sprintf("SELECT id, name, value, json_data, json_array, json_string, json_boolean, json_number FROM %s ORDER BY id", tf.tableID.FullyQualifiedName()) 31 | 32 | if tf.BigQuery() { 33 | // BigQuery does not support booleans, numbers and strings in a JSON column. 34 | baseQuery = fmt.Sprintf("SELECT id, name, value, TO_JSON_STRING(json_data), TO_JSON_STRING(json_array) FROM %s ORDER BY id", tf.tableID.FullyQualifiedName()) 35 | } 36 | 37 | rows, err := tf.dest.Query(baseQuery) 38 | if err != nil { 39 | return fmt.Errorf("failed to query table data: %w", err) 40 | } 41 | 42 | for i := 0; i < rowCount; i++ { 43 | if !rows.Next() { 44 | return fmt.Errorf("expected more rows: expected %d, got %d", rowCount, i) 45 | } 46 | 47 | if err := tf.scanAndCheckRow(rows, i); err != nil { 48 | return fmt.Errorf("failed to check row %d: %w", i, err) 49 | } 50 | } 51 | 52 | if rows.Next() { 53 | return fmt.Errorf("unexpected extra rows found") 54 | } 55 | 56 | if err := rows.Err(); err != nil { 57 | return fmt.Errorf("failed to get rows: %w", err) 58 | } 59 | 60 | return nil 61 | } 62 | -------------------------------------------------------------------------------- /lib/apachelivy/schema.go: -------------------------------------------------------------------------------- 1 | package apachelivy 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/artie-labs/transfer/lib/typing" 7 | ) 8 | 9 | // SparkSQL does not support primary keys. 
10 | type Column struct { 11 | Name string 12 | DataType string 13 | Comment string 14 | } 15 | 16 | func (g GetSchemaResponse) BuildColumns() ([]Column, error) { 17 | colNameIndex := -1 18 | colTypeIndex := -1 19 | colCommentIndex := -1 20 | 21 | for i, field := range g.Schema.Fields { 22 | switch field.Name { 23 | case "col_name": 24 | colNameIndex = i 25 | case "data_type": 26 | colTypeIndex = i 27 | case "comment": 28 | colCommentIndex = i 29 | } 30 | } 31 | 32 | if colNameIndex == -1 || colTypeIndex == -1 || colCommentIndex == -1 { 33 | return nil, fmt.Errorf("col_name, data_type, or comment not found") 34 | } 35 | 36 | var cols []Column 37 | for _, row := range g.Data { 38 | name, err := typing.AssertTypeOptional[string](row[colNameIndex]) 39 | if err != nil { 40 | return nil, fmt.Errorf("col_name is not a string, type: %T", row[colNameIndex]) 41 | } 42 | 43 | dataType, err := typing.AssertTypeOptional[string](row[colTypeIndex]) 44 | if err != nil { 45 | return nil, fmt.Errorf("data_type is not a string, type: %T", row[colTypeIndex]) 46 | } 47 | 48 | comment, err := typing.AssertTypeOptional[string](row[colCommentIndex]) 49 | if err != nil { 50 | return nil, fmt.Errorf("comment is not a string, type: %T", row[colCommentIndex]) 51 | } 52 | 53 | cols = append(cols, Column{ 54 | Name: name, 55 | DataType: dataType, 56 | Comment: comment, 57 | }) 58 | } 59 | 60 | return cols, nil 61 | } 62 | -------------------------------------------------------------------------------- /lib/apachelivy/schema_test.go: -------------------------------------------------------------------------------- 1 | package apachelivy 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestGetSchemaResponse_BuildColumns(t *testing.T) { 10 | { 11 | // Test case #1 - No columns 12 | resp := GetSchemaResponse{ 13 | Schema: GetSchemaStructResponse{ 14 | Fields: []GetSchemaFieldResponse{}, 15 | }, 16 | Data: [][]any{}, 17 | } 18 | 19 | _, err := 
resp.BuildColumns() 20 | assert.ErrorContains(t, err, "col_name, data_type, or comment not found") 21 | } 22 | { 23 | // Test case #2 - With columns 24 | resp := GetSchemaResponse{ 25 | Schema: GetSchemaStructResponse{ 26 | Fields: []GetSchemaFieldResponse{ 27 | { 28 | Name: "col_name", 29 | Type: "STRING", 30 | Nullable: false, 31 | }, 32 | { 33 | Name: "data_type", 34 | Type: "STRING", 35 | Nullable: false, 36 | }, 37 | { 38 | Name: "comment", 39 | Type: "STRING", 40 | Nullable: true, 41 | }, 42 | }, 43 | }, 44 | Data: [][]any{ 45 | { 46 | "id", 47 | "bigint", 48 | "", 49 | }, 50 | { 51 | "first_name", 52 | "string", 53 | "", 54 | }, 55 | { 56 | "last_name", 57 | "string", 58 | "", 59 | }, 60 | { 61 | "email", 62 | "string", 63 | "", 64 | }, 65 | }, 66 | } 67 | 68 | cols, err := resp.BuildColumns() 69 | assert.NoError(t, err) 70 | assert.Equal(t, 4, len(cols)) 71 | 72 | assert.Equal(t, "id", cols[0].Name) 73 | assert.Equal(t, "bigint", cols[0].DataType) 74 | assert.Equal(t, "", cols[0].Comment) 75 | 76 | assert.Equal(t, "first_name", cols[1].Name) 77 | assert.Equal(t, "string", cols[1].DataType) 78 | assert.Equal(t, "", cols[1].Comment) 79 | 80 | assert.Equal(t, "last_name", cols[2].Name) 81 | assert.Equal(t, "string", cols[2].DataType) 82 | assert.Equal(t, "", cols[2].Comment) 83 | 84 | assert.Equal(t, "email", cols[3].Name) 85 | assert.Equal(t, "string", cols[3].DataType) 86 | assert.Equal(t, "", cols[3].Comment) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /lib/apachelivy/util.go: -------------------------------------------------------------------------------- 1 | package apachelivy 2 | 3 | import ( 4 | "github.com/artie-labs/transfer/clients/iceberg/dialect" 5 | ) 6 | 7 | func shouldRetry(err error) bool { 8 | if err == nil { 9 | return false 10 | } 11 | 12 | _dialect := dialect.IcebergDialect{} 13 | if _dialect.IsTableDoesNotExistErr(err) { 14 | return false 15 | } else if 
_dialect.IsColumnAlreadyExistsErr(err) { 16 | return false 17 | } 18 | 19 | return true 20 | } 21 | -------------------------------------------------------------------------------- /lib/array/strings.go: -------------------------------------------------------------------------------- 1 | package array 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "reflect" 7 | 8 | "github.com/artie-labs/transfer/lib/stringutil" 9 | ) 10 | 11 | func InterfaceToArrayString(val any, recastAsArray bool) ([]string, error) { 12 | if val == nil { 13 | return nil, nil 14 | } 15 | 16 | list := reflect.ValueOf(val) 17 | if list.Kind() != reflect.Slice { 18 | if recastAsArray { 19 | // Since it's not a slice, let's cast it as a slice and re-enter this function. 20 | return InterfaceToArrayString([]any{val}, recastAsArray) 21 | } else { 22 | return nil, fmt.Errorf("wrong data type, kind: %v", list.Kind()) 23 | } 24 | 25 | } 26 | 27 | var vals []string 28 | for i := 0; i < list.Len(); i++ { 29 | kind := list.Index(i).Kind() 30 | value := list.Index(i).Interface() 31 | if stringValue, ok := value.(string); ok { 32 | vals = append(vals, stringValue) 33 | continue 34 | } 35 | 36 | var shouldParse bool 37 | if kind == reflect.Interface { 38 | valMap, isOk := value.(map[string]any) 39 | if isOk { 40 | value = valMap 41 | } 42 | 43 | shouldParse = true 44 | } 45 | 46 | if kind == reflect.Map || kind == reflect.Struct || shouldParse { 47 | bytes, err := json.Marshal(value) 48 | if err != nil { 49 | return nil, err 50 | } 51 | 52 | vals = append(vals, string(bytes)) 53 | } else { 54 | // TODO: Do we need to escape backslashes? 
55 | vals = append(vals, stringutil.EscapeBackslashes(fmt.Sprint(value))) 56 | } 57 | } 58 | 59 | return vals, nil 60 | } 61 | -------------------------------------------------------------------------------- /lib/array/strings_test.go: -------------------------------------------------------------------------------- 1 | package array 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestToArrayString(t *testing.T) { 10 | { 11 | // Test nil input 12 | value, err := InterfaceToArrayString(nil, false) 13 | assert.NoError(t, err) 14 | var expected []string 15 | assert.Equal(t, expected, value) 16 | } 17 | { 18 | // Test wrong data type 19 | _, err := InterfaceToArrayString(true, false) 20 | assert.ErrorContains(t, err, "wrong data type, kind: bool") 21 | } 22 | { 23 | // Test list of numbers 24 | value, err := InterfaceToArrayString([]int{1, 2, 3, 4, 5}, false) 25 | assert.NoError(t, err) 26 | assert.Equal(t, []string{"1", "2", "3", "4", "5"}, value) 27 | } 28 | { 29 | // Test list of strings 30 | value, err := InterfaceToArrayString([]string{"abc", "def", "ghi"}, false) 31 | assert.NoError(t, err) 32 | assert.Equal(t, []string{"abc", "def", "ghi"}, value) 33 | } 34 | { 35 | // Test list of booleans 36 | value, err := InterfaceToArrayString([]bool{true, false, true}, false) 37 | assert.NoError(t, err) 38 | assert.Equal(t, []string{"true", "false", "true"}, value) 39 | } 40 | { 41 | // Test array of nested objects 42 | value, err := InterfaceToArrayString([]map[string]any{{"foo": "bar"}, {"hello": "world"}}, false) 43 | assert.NoError(t, err) 44 | assert.Equal(t, []string{`{"foo":"bar"}`, `{"hello":"world"}`}, value) 45 | } 46 | { 47 | // Test array of nested lists 48 | value, err := InterfaceToArrayString([][]string{ 49 | { 50 | "foo", "bar", 51 | }, 52 | { 53 | "abc", "def", 54 | }, 55 | }, false) 56 | assert.NoError(t, err) 57 | assert.Equal(t, []string{"[foo bar]", "[abc def]"}, value) 58 | } 59 | { 60 | value, err := 
InterfaceToArrayString([]any{`{"foo":"bar"}`}, true) 61 | assert.NoError(t, err) 62 | assert.Equal(t, []string{`{"foo":"bar"}`}, value) 63 | } 64 | { 65 | // Test boolean recast as array 66 | value, err := InterfaceToArrayString(true, true) 67 | assert.NoError(t, err) 68 | assert.Equal(t, []string{"true"}, value) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /lib/artie/message_test.go: -------------------------------------------------------------------------------- 1 | package artie 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/segmentio/kafka-go" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | const keyString = "Struct{id=12}" 11 | 12 | func TestNewMessage(t *testing.T) { 13 | kafkaMsg := &kafka.Message{ 14 | Topic: "test_topic", 15 | Partition: 5, 16 | Key: []byte(keyString), 17 | Value: []byte("kafka_value"), 18 | } 19 | 20 | msg := NewMessage(kafkaMsg, "") 21 | assert.Equal(t, "test_topic", msg.Topic()) 22 | assert.Equal(t, "5", msg.Partition()) 23 | assert.Equal(t, keyString, string(msg.Key())) 24 | assert.Equal(t, "kafka_value", string(msg.Value())) 25 | } 26 | -------------------------------------------------------------------------------- /lib/awslib/config.go: -------------------------------------------------------------------------------- 1 | package awslib 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/aws/aws-sdk-go-v2/aws" 7 | "github.com/aws/aws-sdk-go-v2/config" 8 | "github.com/aws/aws-sdk-go-v2/credentials" 9 | ) 10 | 11 | func NewConfigWithCredentialsAndRegion(credentials credentials.StaticCredentialsProvider, region string) aws.Config { 12 | return aws.Config{ 13 | Region: region, 14 | Credentials: credentials, 15 | } 16 | } 17 | 18 | func NewDefaultConfig(ctx context.Context, region string) (aws.Config, error) { 19 | return config.LoadDefaultConfig(ctx, config.WithRegion(region)) 20 | } 21 | -------------------------------------------------------------------------------- 
/lib/awslib/sts_test.go: -------------------------------------------------------------------------------- 1 | package awslib 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestCredentials_IsExpired(t *testing.T) { 11 | creds := Credentials{ 12 | awsAccessKeyID: "test", 13 | awsSecretAccessKey: "test", 14 | awsSessionToken: "test", 15 | } 16 | 17 | { 18 | // Expiration = true, because there's no expiration set 19 | assert.True(t, creds.isExpired()) 20 | } 21 | { 22 | // Expiration = true, because is in the past, so it has expired 23 | creds.expiresAt = time.Now().Add(-1 * time.Minute) 24 | assert.True(t, creds.isExpired()) 25 | } 26 | { 27 | // Expiration = true, because is in the future (but less than the buffer) 28 | creds.expiresAt = time.Now().Add(1 * time.Minute) 29 | assert.True(t, creds.isExpired()) 30 | } 31 | { 32 | // Expiration = false, because is in the future (but more than the buffer) 33 | creds.expiresAt = time.Now().Add(100 * time.Minute) 34 | assert.False(t, creds.isExpired()) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /lib/awslib/types.go: -------------------------------------------------------------------------------- 1 | package awslib 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/google/uuid" 7 | ) 8 | 9 | // Ref: https://iceberg.apache.org/spec/#table-metadata-and-snapshots 10 | type S3TableSchema struct { 11 | FormatVersion int `json:"format-version"` 12 | TableUUID uuid.UUID `json:"table-uuid"` 13 | Location string `json:"location"` 14 | LastSequenceNumber int `json:"last-sequence-number"` 15 | LastUpdatedMS int `json:"last-updated-ms"` 16 | CurrentSchemaID int `json:"current-schema-id"` 17 | Schemas []InnerSchemaObject `json:"schemas"` 18 | } 19 | 20 | func (s S3TableSchema) RetrieveCurrentSchema() (InnerSchemaObject, error) { 21 | for _, schema := range s.Schemas { 22 | if schema.SchemaID == s.CurrentSchemaID { 23 | return 
// KeyFunction is implemented by items that can identify themselves with a
// key; the key is used purely for log context when an item is skipped.
type KeyFunction interface {
	Key() string
}

// hasKeyFunction reports whether [item] implements KeyFunction.
func hasKeyFunction[T any](item T) (KeyFunction, bool) {
	keyed, isOk := any(item).(KeyFunction)
	return keyed, isOk
}

// BySize takes a series of elements [in], encodes them using [encode], groups them into batches of bytes that sum to at
// most [maxSizeBytes], and then passes each batch to the [yield] function.
func BySize[T any](in []T, maxSizeBytes int, failIfRowExceedsMaxSizeBytes bool, encode func(T) ([]byte, error), yield func([][]byte, []T) error) error {
	var pendingBytes [][]byte
	var pendingRows []T
	pendingSize := 0

	// flush hands the accumulated batch to [yield] and resets the accumulator.
	flush := func() error {
		err := yield(pendingBytes, pendingRows)
		pendingBytes, pendingRows, pendingSize = nil, nil, 0
		return err
	}

	for idx, item := range in {
		encoded, err := encode(item)
		if err != nil {
			return fmt.Errorf("failed to encode item %d: %w", idx, err)
		}

		// A single row bigger than the cap can never fit in any batch.
		if len(encoded) > maxSizeBytes {
			if failIfRowExceedsMaxSizeBytes {
				return fmt.Errorf("item %d is larger (%d bytes) than maxSizeBytes (%d bytes)", idx, len(encoded), maxSizeBytes)
			}

			logFields := []any{slog.Int("index", idx), slog.Int("bytes", len(encoded))}
			if keyed, isOk := hasKeyFunction[T](item); isOk {
				logFields = append(logFields, slog.String("key", keyed.Key()))
			}

			slog.Warn("Skipping item as the row is larger than maxSizeBytes", logFields...)
			continue
		}

		switch newSize := pendingSize + len(encoded); {
		case newSize < maxSizeBytes:
			pendingBytes = append(pendingBytes, encoded)
			pendingRows = append(pendingRows, item)
			pendingSize = newSize
		case newSize == maxSizeBytes:
			// Exactly full: include the item, then emit the batch.
			pendingBytes = append(pendingBytes, encoded)
			pendingRows = append(pendingRows, item)
			if err := flush(); err != nil {
				return err
			}
		default:
			// Adding this item would overflow; emit what we have and start a
			// new batch seeded with the current item.
			if err := flush(); err != nil {
				return err
			}
			pendingBytes = append(pendingBytes, encoded)
			pendingRows = append(pendingRows, item)
			pendingSize = len(encoded)
		}
	}

	// Emit any trailing partial batch.
	if len(pendingBytes) > 0 {
		return flush()
	}

	return nil
}
"github.com/artie-labs/transfer/lib/typing/columns" 9 | ) 10 | 11 | type Format interface { 12 | Labels() []string // Labels() to return a list of strings to maintain backward compatibility. 13 | GetPrimaryKey(key []byte, tc kafkalib.TopicConfig) (map[string]any, error) 14 | GetEventFromBytes(bytes []byte) (Event, error) 15 | } 16 | 17 | type Event interface { 18 | GetExecutionTime() time.Time 19 | Operation() string 20 | DeletePayload() bool 21 | GetTableName() string 22 | GetData(tc kafkalib.TopicConfig) (map[string]any, error) 23 | GetOptionalSchema() (map[string]typing.KindDetails, error) 24 | // GetColumns will inspect the envelope's payload right now and return. 25 | GetColumns() (*columns.Columns, error) 26 | } 27 | -------------------------------------------------------------------------------- /lib/cdc/format/format.go: -------------------------------------------------------------------------------- 1 | package format 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/artie-labs/transfer/lib/cdc" 7 | "github.com/artie-labs/transfer/lib/cdc/mongo" 8 | "github.com/artie-labs/transfer/lib/cdc/relational" 9 | "github.com/artie-labs/transfer/lib/logger" 10 | ) 11 | 12 | func GetFormatParser(label, topic string) cdc.Format { 13 | for _, validFormat := range []cdc.Format{relational.Debezium{}, mongo.Debezium{}} { 14 | for _, fmtLabel := range validFormat.Labels() { 15 | if fmtLabel == label { 16 | slog.Info("Loaded CDC Format parser...", 17 | slog.String("label", label), 18 | slog.String("topic", topic), 19 | ) 20 | return validFormat 21 | } 22 | } 23 | } 24 | 25 | logger.Panic("Failed to fetch CDC format parser", slog.String("label", label)) 26 | return nil 27 | } 28 | -------------------------------------------------------------------------------- /lib/cdc/format/format_test.go: -------------------------------------------------------------------------------- 1 | package format 2 | 3 | import ( 4 | "os" 5 | "os/exec" 6 | "testing" 7 | 8 | 
"github.com/stretchr/testify/assert" 9 | 10 | "github.com/artie-labs/transfer/lib/cdc/mongo" 11 | "github.com/artie-labs/transfer/lib/cdc/relational" 12 | "github.com/artie-labs/transfer/lib/config/constants" 13 | "github.com/artie-labs/transfer/lib/typing" 14 | ) 15 | 16 | func TestGetFormatParser(t *testing.T) { 17 | { 18 | // Relational 19 | for _, format := range []string{constants.DBZPostgresAltFormat, constants.DBZPostgresFormat} { 20 | formatParser := GetFormatParser(format, "topicA") 21 | assert.NotNil(t, formatParser) 22 | 23 | _, err := typing.AssertType[relational.Debezium](formatParser) 24 | assert.NoError(t, err) 25 | } 26 | } 27 | { 28 | // Mongo 29 | formatParser := GetFormatParser(constants.DBZMongoFormat, "topicA") 30 | assert.NotNil(t, formatParser) 31 | 32 | _, err := typing.AssertType[mongo.Debezium](formatParser) 33 | assert.NoError(t, err) 34 | } 35 | } 36 | 37 | func testOsExit(t *testing.T, testFunc func(*testing.T)) { 38 | if os.Getenv(t.Name()) == "1" { 39 | testFunc(t) 40 | return 41 | } 42 | 43 | cmd := exec.Command(os.Args[0], "-test.run="+t.Name()) 44 | cmd.Env = append(os.Environ(), t.Name()+"=1") 45 | err := cmd.Run() 46 | if e, ok := err.(*exec.ExitError); ok && !e.Success() { 47 | return 48 | } 49 | 50 | t.Fatal("subprocess ran successfully, want non-zero exit status") 51 | } 52 | 53 | func TestGetFormatParserFatal(t *testing.T) { 54 | // This test cannot be iterated because it forks a separate process to do `go test -test.run=...` 55 | testOsExit(t, func(t *testing.T) { 56 | GetFormatParser("foo", "topicB") 57 | }) 58 | } 59 | -------------------------------------------------------------------------------- /lib/cdc/mongo/event.go: -------------------------------------------------------------------------------- 1 | package mongo 2 | 3 | import ( 4 | "github.com/artie-labs/transfer/lib/debezium" 5 | ) 6 | 7 | // SchemaEventPayload is our struct for an event with schema enabled. 
For reference, this is an example payload https://gist.github.com/Tang8330/d0998d8d1ebcbeaa4ecb8e098445cc3a 8 | type SchemaEventPayload struct { 9 | Schema debezium.Schema `json:"schema"` 10 | Payload Payload `json:"payload"` 11 | } 12 | 13 | type Payload struct { 14 | Before *string `json:"before"` 15 | After *string `json:"after"` 16 | 17 | Source Source `json:"source"` 18 | Operation string `json:"op"` 19 | 20 | // These maps are used to store the before and after JSONE as a map, since `before` and `after` come in as a JSONE string. 21 | beforeMap map[string]any 22 | afterMap map[string]any 23 | } 24 | 25 | type Source struct { 26 | Connector string `json:"connector"` 27 | TsMs int64 `json:"ts_ms"` 28 | Database string `json:"db"` 29 | Collection string `json:"collection"` 30 | } 31 | -------------------------------------------------------------------------------- /lib/cdc/mongo/mongo_bench_test.go: -------------------------------------------------------------------------------- 1 | package mongo 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "go.mongodb.org/mongo-driver/bson/primitive" 8 | 9 | "github.com/artie-labs/transfer/lib/kafkalib" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func BenchmarkGetPrimaryKey(b *testing.B) { 14 | var dbz Debezium 15 | 16 | for i := 0; i < b.N; i++ { 17 | newObjectID := primitive.NewObjectID().Hex() 18 | 19 | pkMap, err := dbz.GetPrimaryKey( 20 | []byte(fmt.Sprintf(`{"schema":{"type":"struct","fields":[{"type":"string","optional":false,"field":"id"}],"optional":false,"name":"1a75f632-29d2-419b-9ffe-d18fa12d74d5.38d5d2db-870a-4a38-a76c-9891b0e5122d.myFirstDatabase.stock.Key"},"payload":{"id":"{\"$oid\": \"%s\"}"}}`, newObjectID)), 21 | kafkalib.TopicConfig{ 22 | CDCKeyFormat: kafkalib.JSONKeyFmt, 23 | }, 24 | ) 25 | assert.NoError(b, err) 26 | 27 | pkVal, isOk := pkMap["_id"] 28 | assert.True(b, isOk) 29 | assert.Equal(b, pkVal, newObjectID) 30 | } 31 | } 32 | 
-------------------------------------------------------------------------------- /lib/cdc/relational/debezium.go: -------------------------------------------------------------------------------- 1 | package relational 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/artie-labs/transfer/lib/cdc" 8 | "github.com/artie-labs/transfer/lib/cdc/util" 9 | "github.com/artie-labs/transfer/lib/config/constants" 10 | "github.com/artie-labs/transfer/lib/debezium" 11 | "github.com/artie-labs/transfer/lib/kafkalib" 12 | ) 13 | 14 | type Debezium struct{} 15 | 16 | func (Debezium) GetEventFromBytes(bytes []byte) (cdc.Event, error) { 17 | if len(bytes) == 0 { 18 | return nil, fmt.Errorf("empty message") 19 | } 20 | 21 | var event util.SchemaEventPayload 22 | if err := json.Unmarshal(bytes, &event); err != nil { 23 | return nil, err 24 | } 25 | 26 | return &event, nil 27 | } 28 | 29 | func (Debezium) Labels() []string { 30 | return []string{ 31 | constants.DBZPostgresFormat, 32 | constants.DBZPostgresAltFormat, 33 | constants.DBZMySQLFormat, 34 | constants.DBZRelationalFormat, 35 | } 36 | } 37 | 38 | func (Debezium) GetPrimaryKey(key []byte, tc kafkalib.TopicConfig) (map[string]any, error) { 39 | return debezium.ParsePartitionKey(key, tc.CDCKeyFormat) 40 | } 41 | -------------------------------------------------------------------------------- /lib/cdc/relational/relation_suite_test.go: -------------------------------------------------------------------------------- 1 | package relational 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/suite" 7 | ) 8 | 9 | type RelationTestSuite struct { 10 | suite.Suite 11 | *Debezium 12 | } 13 | 14 | func (r *RelationTestSuite) SetupTest() { 15 | var debezium Debezium 16 | r.Debezium = &debezium 17 | } 18 | 19 | func TestRelationTestSuite(t *testing.T) { 20 | suite.Run(t, new(RelationTestSuite)) 21 | } 22 | -------------------------------------------------------------------------------- 
/lib/cdc/util/optional_schema.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "github.com/artie-labs/transfer/lib/debezium" 5 | "github.com/artie-labs/transfer/lib/typing" 6 | ) 7 | 8 | func (s *SchemaEventPayload) GetOptionalSchema() (map[string]typing.KindDetails, error) { 9 | fieldsObject := s.Schema.GetSchemaFromLabel(debezium.After) 10 | if fieldsObject == nil { 11 | return nil, nil 12 | } 13 | 14 | schema := make(map[string]typing.KindDetails) 15 | for _, field := range fieldsObject.Fields { 16 | kd, err := field.ToKindDetails() 17 | if err != nil { 18 | return nil, err 19 | } 20 | 21 | schema[field.FieldName] = kd 22 | } 23 | 24 | return schema, nil 25 | } 26 | -------------------------------------------------------------------------------- /lib/config/destinations_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestDatabricks_DSN(t *testing.T) { 10 | d := Databricks{ 11 | Host: "foo", 12 | HttpPath: "/api/def", 13 | Port: 443, 14 | Catalog: "catalogName", 15 | PersonalAccessToken: "pat", 16 | } 17 | 18 | assert.Equal(t, "token:pat@foo:443/api/def?catalog=catalogName", d.DSN()) 19 | } 20 | -------------------------------------------------------------------------------- /lib/config/mssql_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/artie-labs/transfer/lib/config/constants" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestValidateMSSQL(t *testing.T) { 12 | var cfg Config 13 | assert.ErrorContains(t, cfg.ValidateMSSQL(), "output is not mssql") 14 | cfg.Output = constants.MSSQL 15 | assert.ErrorContains(t, cfg.ValidateMSSQL(), "mssql config is nil") 16 | cfg.MSSQL = &MSSQL{} 17 | 
assert.ErrorContains(t, cfg.ValidateMSSQL(), "one of mssql settings is empty (host, username, password, database)") 18 | 19 | cfg.MSSQL = &MSSQL{ 20 | Host: "localhost", 21 | Port: 1433, 22 | Username: "sa", 23 | Password: "password", 24 | Database: "test", 25 | } 26 | assert.NoError(t, cfg.ValidateMSSQL()) 27 | } 28 | -------------------------------------------------------------------------------- /lib/config/settings.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/jessevdk/go-flags" 7 | ) 8 | 9 | type Settings struct { 10 | Config Config 11 | VerboseLogging bool 12 | } 13 | 14 | // LoadSettings will take the flags and then parse, loadConfig is optional for testing purposes. 15 | func LoadSettings(args []string, loadConfig bool) (*Settings, error) { 16 | var opts struct { 17 | ConfigFilePath string `short:"c" long:"config" description:"path to the config file"` 18 | Verbose bool `short:"v" long:"verbose" description:"debug logging" optional:"true"` 19 | } 20 | 21 | _, err := flags.ParseArgs(&opts, args) 22 | if err != nil { 23 | return nil, fmt.Errorf("failed to parse args: %w", err) 24 | } 25 | 26 | settings := &Settings{ 27 | VerboseLogging: opts.Verbose, 28 | } 29 | 30 | if loadConfig { 31 | config, err := readFileToConfig(opts.ConfigFilePath) 32 | if err != nil { 33 | return nil, fmt.Errorf("failed to parse config file: %w", err) 34 | } 35 | 36 | tcs, err := config.TopicConfigs() 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | for _, tc := range tcs { 42 | tc.Load() 43 | } 44 | 45 | if err = config.Validate(); err != nil { 46 | return nil, fmt.Errorf("failed to validate config: %w", err) 47 | } 48 | 49 | settings.Config = *config 50 | } 51 | 52 | return settings, nil 53 | } 54 | -------------------------------------------------------------------------------- /lib/config/settings_test.go: 
-------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestParseArgs(t *testing.T) { 10 | settings, err := LoadSettings([]string{}, false) 11 | assert.NoError(t, err) 12 | assert.Equal(t, settings.VerboseLogging, false) 13 | 14 | settings, err = LoadSettings([]string{"-v"}, false) 15 | assert.NoError(t, err) 16 | assert.Equal(t, settings.VerboseLogging, true) 17 | } 18 | -------------------------------------------------------------------------------- /lib/cryptography/cryptography.go: -------------------------------------------------------------------------------- 1 | package cryptography 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/rsa" 6 | "crypto/sha256" 7 | "crypto/x509" 8 | "encoding/hex" 9 | "encoding/pem" 10 | "fmt" 11 | "math/big" 12 | "os" 13 | 14 | "github.com/artie-labs/transfer/lib/typing" 15 | ) 16 | 17 | // HashValue - Hashes a value using SHA256 18 | func HashValue(value any) any { 19 | if value == nil { 20 | return nil 21 | } 22 | 23 | hash := sha256.New() 24 | hash.Write([]byte(fmt.Sprint(value))) 25 | return hex.EncodeToString(hash.Sum(nil)) 26 | } 27 | 28 | func LoadRSAKey(filePath string) (*rsa.PrivateKey, error) { 29 | keyBytes, err := os.ReadFile(filePath) 30 | if err != nil { 31 | return nil, fmt.Errorf("failed to read file: %w", err) 32 | } 33 | 34 | return ParseRSAPrivateKey(keyBytes) 35 | } 36 | 37 | func ParseRSAPrivateKey(keyBytes []byte) (*rsa.PrivateKey, error) { 38 | block, _ := pem.Decode(keyBytes) 39 | if block == nil { 40 | return nil, fmt.Errorf("failed to decode PEM block containing private key") 41 | } 42 | 43 | key, err := x509.ParsePKCS8PrivateKey(block.Bytes) 44 | if err != nil { 45 | return nil, fmt.Errorf("failed to parse private key: %v", err) 46 | } 47 | 48 | rsaKey, err := typing.AssertType[*rsa.PrivateKey](key) 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | return 
rsaKey, nil 54 | } 55 | 56 | func RandomInt64n(n int64) (int64, error) { 57 | randN, err := rand.Int(rand.Reader, big.NewInt(n)) 58 | if err != nil { 59 | return 0, fmt.Errorf("failed to generate random number: %w", err) 60 | } 61 | 62 | return randN.Int64(), nil 63 | } 64 | -------------------------------------------------------------------------------- /lib/cryptography/cryptography_test.go: -------------------------------------------------------------------------------- 1 | package cryptography 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestHashValue(t *testing.T) { 10 | { 11 | // If we pass nil in, we should get nil out. 12 | assert.Equal(t, nil, HashValue(nil)) 13 | } 14 | { 15 | // Pass in an empty string 16 | assert.Equal(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", HashValue("")) 17 | } 18 | { 19 | // Pass in a string 20 | assert.Equal(t, "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9", HashValue("hello world")) 21 | } 22 | { 23 | // Value should be deterministic. 
24 | for range 50 { 25 | assert.Equal(t, "b9a40320d82075681b2500e38160538e5e912bd8f49c03e87367fe82c1fa35d2", HashValue("dusty the mini aussie")) 26 | } 27 | } 28 | } 29 | 30 | func BenchmarkHashValue(b *testing.B) { 31 | for i := 0; i < b.N; i++ { 32 | assert.Equal(b, "b9a40320d82075681b2500e38160538e5e912bd8f49c03e87367fe82c1fa35d2", HashValue("dusty the mini aussie")) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /lib/csvwriter/gzip.go: -------------------------------------------------------------------------------- 1 | package csvwriter 2 | 3 | import ( 4 | "compress/gzip" 5 | "encoding/csv" 6 | "os" 7 | "path/filepath" 8 | ) 9 | 10 | type GzipWriter struct { 11 | file *os.File 12 | gzip *gzip.Writer 13 | writer *csv.Writer 14 | } 15 | 16 | func NewGzipWriter(fp string) (*GzipWriter, error) { 17 | file, err := os.Create(fp) 18 | if err != nil { 19 | return nil, err 20 | } 21 | 22 | gzipWriter := gzip.NewWriter(file) 23 | csvWriter := csv.NewWriter(gzipWriter) 24 | csvWriter.Comma = '\t' 25 | return &GzipWriter{ 26 | file: file, 27 | gzip: gzipWriter, 28 | writer: csvWriter, 29 | }, nil 30 | } 31 | 32 | func (g *GzipWriter) FileName() string { 33 | return filepath.Base(g.file.Name()) 34 | } 35 | 36 | func (g *GzipWriter) Write(row []string) error { 37 | return g.writer.Write(row) 38 | } 39 | 40 | func (g *GzipWriter) Flush() error { 41 | g.writer.Flush() 42 | return g.writer.Error() 43 | } 44 | 45 | func (g *GzipWriter) Close() error { 46 | if err := g.gzip.Close(); err != nil { 47 | // If closing the gzip writer fails, we should still try to close the file. 
48 | _ = g.file.Close() 49 | return err 50 | } 51 | return g.file.Close() 52 | } 53 | -------------------------------------------------------------------------------- /lib/db/errors.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "log/slog" 7 | "net" 8 | "syscall" 9 | ) 10 | 11 | var retryableErrs = []error{ 12 | syscall.ECONNRESET, 13 | syscall.ECONNREFUSED, 14 | io.EOF, 15 | syscall.ETIMEDOUT, 16 | } 17 | 18 | func isRetryableError(err error) bool { 19 | if err == nil { 20 | return false 21 | } 22 | 23 | for _, retryableErr := range retryableErrs { 24 | if errors.Is(err, retryableErr) { 25 | return true 26 | } 27 | } 28 | 29 | if netErr, ok := err.(net.Error); ok { 30 | if netErr.Timeout() { 31 | slog.Warn("caught a net.Error in isRetryableError", slog.Any("err", err)) 32 | return true 33 | } 34 | } 35 | 36 | return false 37 | } 38 | -------------------------------------------------------------------------------- /lib/db/errors_test.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "fmt" 5 | "syscall" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestIsRetryable_Errors(t *testing.T) { 12 | { 13 | // Test nil error case 14 | var err error 15 | assert.False(t, isRetryableError(err), "nil error should not be retryable") 16 | } 17 | { 18 | // Test irrelevant error case 19 | assert.False(t, isRetryableError(fmt.Errorf("random error")), "irrelevant error should not be retryable") 20 | } 21 | { 22 | // Test direct connection refused error 23 | assert.True(t, isRetryableError(syscall.ECONNREFUSED), "direct connection refused error should be retryable") 24 | } 25 | { 26 | // Test direct connection reset error 27 | assert.True(t, isRetryableError(syscall.ECONNRESET), "direct connection reset error should be retryable") 28 | } 29 | { 30 | // Test wrapped connection refused 
error 31 | assert.True(t, isRetryableError(fmt.Errorf("foo: %w", syscall.ECONNREFUSED)), "wrapped connection refused error should be retryable") 32 | } 33 | { 34 | // Test wrapped connection reset error 35 | assert.True(t, isRetryableError(fmt.Errorf("foo: %w", syscall.ECONNRESET)), "wrapped connection reset error should be retryable") 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /lib/debezium/converters/bytes.go: -------------------------------------------------------------------------------- 1 | package converters 2 | 3 | import ( 4 | "encoding/base64" 5 | "fmt" 6 | ) 7 | 8 | type Bytes struct{} 9 | 10 | // Convert attempts to convert a value (type []byte, or string) to a slice of bytes. 11 | // - If value is already a slice of bytes it will be directly returned. 12 | // - If value is a string we will attempt to base64 decode it. 13 | func (Bytes) Convert(value any) (any, error) { 14 | if bytes, isOk := value.([]byte); isOk { 15 | return bytes, nil 16 | } 17 | 18 | if stringValue, isOk := value.(string); isOk { 19 | data, err := base64.StdEncoding.DecodeString(stringValue) 20 | if err != nil { 21 | return nil, fmt.Errorf("failed to base64 decode: %w", err) 22 | } 23 | return data, nil 24 | } 25 | 26 | return nil, fmt.Errorf("expected []byte or string, got %T", value) 27 | } 28 | -------------------------------------------------------------------------------- /lib/debezium/converters/bytes_test.go: -------------------------------------------------------------------------------- 1 | package converters 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestBytes_Convert(t *testing.T) { 10 | { 11 | // []byte 12 | actual, err := Bytes{}.Convert([]byte{40, 39, 38}) 13 | assert.NoError(t, err) 14 | assert.Equal(t, []byte{40, 39, 38}, actual) 15 | } 16 | { 17 | // base64 encoded string 18 | actual, err := Bytes{}.Convert("aGVsbG8gd29ybGQK") 19 | assert.NoError(t, err) 20 | 
assert.Equal(t, []byte{0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0xa}, actual) 21 | } 22 | { 23 | // malformed string 24 | _, err := Bytes{}.Convert("asdf$$$") 25 | assert.ErrorContains(t, err, "failed to base64 decode") 26 | } 27 | { 28 | // type that is not string or []byte 29 | _, err := Bytes{}.Convert(map[string]any{}) 30 | assert.ErrorContains(t, err, "expected []byte or string, got map[string]interface {}") 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /lib/debezium/converters/converters.go: -------------------------------------------------------------------------------- 1 | package converters 2 | 3 | import ( 4 | "github.com/artie-labs/transfer/lib/typing" 5 | ) 6 | 7 | type ValueConverter interface { 8 | ToKindDetails() typing.KindDetails 9 | Convert(value any) (any, error) 10 | } 11 | -------------------------------------------------------------------------------- /lib/debezium/converters/date.go: -------------------------------------------------------------------------------- 1 | package converters 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/artie-labs/transfer/lib/typing" 8 | ) 9 | 10 | type Date struct{} 11 | 12 | func (d Date) ToKindDetails() typing.KindDetails { 13 | return typing.Date 14 | } 15 | 16 | func (d Date) Convert(value any) (any, error) { 17 | valueInt64, isOk := value.(int64) 18 | if !isOk { 19 | return nil, fmt.Errorf("expected int64 got '%v' with type %T", value, value) 20 | } 21 | 22 | // Represents the number of days since the epoch. 
23 | return time.UnixMilli(0).In(time.UTC).AddDate(0, 0, int(valueInt64)).Format(time.DateOnly), nil 24 | } 25 | -------------------------------------------------------------------------------- /lib/debezium/converters/date_test.go: -------------------------------------------------------------------------------- 1 | package converters 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestDate_Convert(t *testing.T) { 10 | { 11 | // Invalid data type 12 | _, err := Date{}.Convert("invalid") 13 | assert.ErrorContains(t, err, "expected int64 got 'invalid' with type string") 14 | } 15 | { 16 | val, err := Date{}.Convert(int64(19401)) 17 | assert.NoError(t, err) 18 | assert.Equal(t, "2023-02-13", val.(string)) 19 | } 20 | { 21 | val, err := Date{}.Convert(int64(19429)) 22 | assert.NoError(t, err) 23 | assert.Equal(t, "2023-03-13", val.(string)) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /lib/debezium/converters/geometry_test.go: -------------------------------------------------------------------------------- 1 | package converters 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestParseGeometryPoint(t *testing.T) { 10 | { 11 | geoJSON, err := GeometryPoint{}.Convert(map[string]any{ 12 | "x": 2.2945, 13 | "y": 48.8584, 14 | "wkb": "AQEAAABCYOXQIlsCQHZxGw3gbUhA", 15 | "srid": nil, 16 | }) 17 | 18 | geoJSONString, isOk := geoJSON.(string) 19 | assert.True(t, isOk) 20 | assert.NoError(t, err) 21 | assert.Equal(t, `{"type":"Feature","geometry":{"type":"Point","coordinates":[2.2945,48.8584]}}`, geoJSONString) 22 | } 23 | } 24 | 25 | func TestGeometryWkb(t *testing.T) { 26 | { 27 | geoJSONString, err := Geometry{}.Convert(map[string]any{ 28 | "wkb": "AQEAAAAAAAAAAADwPwAAAAAAAPA/", 29 | "srid": nil, 30 | }) 31 | 32 | assert.NoError(t, err) 33 | assert.Equal(t, 
`{"type":"Feature","geometry":{"type":"Point","coordinates":[1,1]},"properties":null}`, geoJSONString) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /lib/debezium/converters/string.go: -------------------------------------------------------------------------------- 1 | package converters 2 | 3 | import "github.com/artie-labs/transfer/lib/typing" 4 | 5 | type StringPassthrough struct{} 6 | 7 | func (StringPassthrough) Convert(value any) (any, error) { 8 | castedValue, err := typing.AssertType[string](value) 9 | if err != nil { 10 | return nil, err 11 | } 12 | 13 | return castedValue, nil 14 | } 15 | 16 | func (StringPassthrough) ToKindDetails() typing.KindDetails { 17 | return typing.String 18 | } 19 | -------------------------------------------------------------------------------- /lib/debezium/converters/string_test.go: -------------------------------------------------------------------------------- 1 | package converters 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestStringPassthrough_Convert(t *testing.T) { 10 | { 11 | // Non string 12 | _, err := StringPassthrough{}.Convert(1) 13 | assert.ErrorContains(t, err, "expected type string, got int") 14 | } 15 | { 16 | // String 17 | value, err := StringPassthrough{}.Convert("test") 18 | assert.Nil(t, err) 19 | assert.Equal(t, "test", value) 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /lib/debezium/converters/timestamp.go: -------------------------------------------------------------------------------- 1 | package converters 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/artie-labs/transfer/lib/typing" 7 | ) 8 | 9 | type Timestamp struct{} 10 | 11 | func (t Timestamp) ToKindDetails() typing.KindDetails { 12 | return typing.TimestampNTZ 13 | } 14 | 15 | func (t Timestamp) Convert(value any) (any, error) { 16 | castedValue, err := 
typing.AssertType[int64](value) 17 | if err != nil { 18 | return nil, err 19 | } 20 | 21 | // Represents the number of milliseconds since the epoch, and does not include timezone information. 22 | return time.UnixMilli(castedValue).In(time.UTC), nil 23 | } 24 | 25 | type MicroTimestamp struct{} 26 | 27 | func (mt MicroTimestamp) ToKindDetails() typing.KindDetails { 28 | return typing.TimestampNTZ 29 | } 30 | 31 | func (mt MicroTimestamp) Convert(value any) (any, error) { 32 | castedValue, err := typing.AssertType[int64](value) 33 | if err != nil { 34 | return nil, err 35 | } 36 | 37 | // Represents the number of microseconds since the epoch, and does not include timezone information. 38 | return time.UnixMicro(castedValue).In(time.UTC), nil 39 | } 40 | 41 | type NanoTimestamp struct{} 42 | 43 | func (nt NanoTimestamp) ToKindDetails() typing.KindDetails { 44 | return typing.TimestampNTZ 45 | } 46 | 47 | func (nt NanoTimestamp) Convert(value any) (any, error) { 48 | castedValue, err := typing.AssertType[int64](value) 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | // Represents the number of nanoseconds since the epoch, and does not include timezone information. 
54 | return time.UnixMicro(castedValue / 1_000).In(time.UTC), nil 55 | } 56 | -------------------------------------------------------------------------------- /lib/debezium/converters/timestamp_test.go: -------------------------------------------------------------------------------- 1 | package converters 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/artie-labs/transfer/lib/typing" 8 | "github.com/artie-labs/transfer/lib/typing/ext" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestTimestamp_Converter(t *testing.T) { 13 | assert.Equal(t, typing.TimestampNTZ, Timestamp{}.ToKindDetails()) 14 | { 15 | // Invalid conversion 16 | _, err := Timestamp{}.Convert("invalid") 17 | assert.ErrorContains(t, err, "expected type int64, got string") 18 | } 19 | { 20 | // Valid conversion 21 | converted, err := Timestamp{}.Convert(int64(1_725_058_799_089)) 22 | assert.NoError(t, err) 23 | assert.Equal(t, "2024-08-30T22:59:59.089", converted.(time.Time).Format(ext.RFC3339NoTZ)) 24 | } 25 | } 26 | 27 | func TestMicroTimestamp_Converter(t *testing.T) { 28 | assert.Equal(t, typing.TimestampNTZ, MicroTimestamp{}.ToKindDetails()) 29 | { 30 | // Invalid conversion 31 | _, err := MicroTimestamp{}.Convert("invalid") 32 | assert.ErrorContains(t, err, "expected type int64, got string") 33 | } 34 | { 35 | // Valid conversion 36 | converted, err := MicroTimestamp{}.Convert(int64(1_712_609_795_827_923)) 37 | assert.NoError(t, err) 38 | assert.Equal(t, "2024-04-08T20:56:35.827923", converted.(time.Time).Format(ext.RFC3339NoTZ)) 39 | } 40 | } 41 | 42 | func TestNanoTimestamp_Converter(t *testing.T) { 43 | assert.Equal(t, typing.TimestampNTZ, NanoTimestamp{}.ToKindDetails()) 44 | { 45 | // Invalid conversion 46 | _, err := NanoTimestamp{}.Convert("invalid") 47 | assert.ErrorContains(t, err, "expected type int64, got string") 48 | } 49 | { 50 | // Valid conversion 51 | converted, err := NanoTimestamp{}.Convert(int64(1_712_609_795_827_001_000)) 52 | assert.NoError(t, 
err) 53 | assert.Equal(t, "2024-04-08T20:56:35.827001", converted.(time.Time).Format(ext.RFC3339NoTZ)) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /lib/debezium/types_bench_test.go: -------------------------------------------------------------------------------- 1 | package debezium 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "github.com/stretchr/testify/require" 8 | 9 | dbzConverters "github.com/artie-labs/transfer/lib/debezium/converters" 10 | "github.com/artie-labs/transfer/lib/typing/decimal" 11 | ) 12 | 13 | func BenchmarkDecodeDecimal_P64_S10(b *testing.B) { 14 | parameters := map[string]any{ 15 | "scale": 10, 16 | KafkaDecimalPrecisionKey: 64, 17 | } 18 | field := Field{Parameters: parameters} 19 | for i := 0; i < b.N; i++ { 20 | bytes, err := dbzConverters.Bytes{}.Convert("AwBGAw8m9GLXrCGifrnVP/8jPHrNEtd1r4rS") 21 | assert.NoError(b, err) 22 | 23 | converter, err := field.ToValueConverter() 24 | assert.NoError(b, err) 25 | 26 | dec, err := converter.Convert(bytes.([]byte)) 27 | assert.NoError(b, err) 28 | assert.Equal(b, "123456789012345678901234567890123456789012345678901234.1234567890", dec.(*decimal.Decimal).String()) 29 | require.NoError(b, err) 30 | } 31 | } 32 | 33 | func BenchmarkDecodeDecimal_P38_S2(b *testing.B) { 34 | parameters := map[string]any{ 35 | "scale": 2, 36 | KafkaDecimalPrecisionKey: 38, 37 | } 38 | field := Field{Parameters: parameters} 39 | for i := 0; i < b.N; i++ { 40 | bytes, err := dbzConverters.Bytes{}.Convert(`AMCXznvJBxWzS58P/////w==`) 41 | assert.NoError(b, err) 42 | 43 | converter, err := field.ToValueConverter() 44 | assert.NoError(b, err) 45 | 46 | dec, err := converter.Convert(bytes.([]byte)) 47 | assert.NoError(b, err) 48 | assert.Equal(b, "9999999999999999999999999999999999.99", dec.(*decimal.Decimal).String()) 49 | } 50 | } 51 | 52 | func BenchmarkDecodeDecimal_P5_S2(b *testing.B) { 53 | parameters := map[string]any{ 54 | 
"scale": 2, 55 | KafkaDecimalPrecisionKey: 5, 56 | } 57 | 58 | field := Field{Parameters: parameters} 59 | for i := 0; i < b.N; i++ { 60 | bytes, err := dbzConverters.Bytes{}.Convert(`AOHJ`) 61 | assert.NoError(b, err) 62 | 63 | converter, err := field.ToValueConverter() 64 | assert.NoError(b, err) 65 | 66 | dec, err := converter.Convert(bytes.([]byte)) 67 | assert.NoError(b, err) 68 | assert.Equal(b, "578.01", dec.(*decimal.Decimal).String()) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /lib/destination/ddl/ddl_suite_test.go: -------------------------------------------------------------------------------- 1 | package ddl_test // to avoid go import cycles. 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/artie-labs/transfer/clients/redshift" 7 | 8 | "github.com/artie-labs/transfer/lib/config" 9 | 10 | "github.com/artie-labs/transfer/clients/bigquery" 11 | "github.com/artie-labs/transfer/clients/snowflake" 12 | "github.com/artie-labs/transfer/lib/db" 13 | "github.com/artie-labs/transfer/lib/mocks" 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/suite" 16 | ) 17 | 18 | type DDLTestSuite struct { 19 | suite.Suite 20 | fakeBigQueryStore *mocks.FakeStore 21 | bigQueryStore *bigquery.Store 22 | bigQueryCfg config.Config 23 | 24 | fakeSnowflakeStagesStore *mocks.FakeStore 25 | snowflakeStagesStore *snowflake.Store 26 | 27 | fakeRedshiftStore *mocks.FakeStore 28 | redshiftStore *redshift.Store 29 | } 30 | 31 | func (d *DDLTestSuite) SetupTest() { 32 | d.bigQueryCfg = config.Config{ 33 | BigQuery: &config.BigQuery{ 34 | ProjectID: "artie-project", 35 | }, 36 | } 37 | 38 | d.fakeBigQueryStore = &mocks.FakeStore{} 39 | bqStore := db.Store(d.fakeBigQueryStore) 40 | 41 | var err error 42 | d.bigQueryStore, err = bigquery.LoadBigQuery(d.T().Context(), d.bigQueryCfg, &bqStore) 43 | assert.NoError(d.T(), err) 44 | 45 | d.fakeSnowflakeStagesStore = &mocks.FakeStore{} 46 | snowflakeStagesStore := 
db.Store(d.fakeSnowflakeStagesStore) 47 | snowflakeCfg := config.Config{ 48 | Snowflake: &config.Snowflake{}, 49 | } 50 | d.snowflakeStagesStore, err = snowflake.LoadSnowflake(d.T().Context(), snowflakeCfg, &snowflakeStagesStore) 51 | assert.NoError(d.T(), err) 52 | 53 | d.fakeRedshiftStore = &mocks.FakeStore{} 54 | redshiftStore := db.Store(d.fakeRedshiftStore) 55 | redshiftCfg := config.Config{Redshift: &config.Redshift{}} 56 | d.redshiftStore, err = redshift.LoadRedshift(d.T().Context(), redshiftCfg, &redshiftStore) 57 | assert.NoError(d.T(), err) 58 | } 59 | 60 | func TestDDLTestSuite(t *testing.T) { 61 | suite.Run(t, new(DDLTestSuite)) 62 | } 63 | -------------------------------------------------------------------------------- /lib/destination/ddl/expiry.go: -------------------------------------------------------------------------------- 1 | package ddl 2 | 3 | import ( 4 | "log/slog" 5 | "strconv" 6 | "strings" 7 | "time" 8 | 9 | "github.com/artie-labs/transfer/lib/config/constants" 10 | ) 11 | 12 | func ShouldDeleteFromName(name string) bool { 13 | parts := strings.Split(strings.ToLower(name), constants.ArtiePrefix) 14 | if len(parts) != 2 { 15 | return false 16 | } 17 | 18 | suffixParts := strings.Split(parts[1], "_") 19 | if len(suffixParts) != 3 { 20 | return false 21 | } 22 | 23 | part := suffixParts[2] 24 | if strings.EqualFold(part, "msm") { 25 | return false 26 | } 27 | 28 | unix, err := strconv.Atoi(part) 29 | if err != nil { 30 | slog.Error("Failed to parse unix string", slog.Any("err", err), slog.String("tableName", name), slog.String("part", part)) 31 | return false 32 | } 33 | 34 | return time.Now().UTC().After(time.Unix(int64(unix), 0)) 35 | } 36 | -------------------------------------------------------------------------------- /lib/destination/ddl/expiry_test.go: -------------------------------------------------------------------------------- 1 | package ddl 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | "time" 8 | 9 | 
"github.com/artie-labs/transfer/lib/config/constants" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestShouldDeleteFromName(t *testing.T) { 15 | { 16 | // Tables to not drop 17 | tablesToNotDrop := []string{ 18 | "foo", 19 | "transactions", 20 | fmt.Sprintf("future_tbl___artie_suffix_%d", time.Now().Add(constants.TemporaryTableTTL).Unix()), 21 | fmt.Sprintf("future_tbl___notartie_%d", time.Now().Add(-1*time.Hour).Unix()), 22 | fmt.Sprintf("%s_foo_msm", constants.ArtiePrefix), 23 | fmt.Sprintf("%s_foo_MSM", constants.ArtiePrefix), 24 | fmt.Sprintf("%s_foo_foo_MSM", constants.ArtiePrefix), 25 | } 26 | 27 | for _, tblToNotDelete := range tablesToNotDrop { 28 | assert.False(t, ShouldDeleteFromName(strings.ToLower(tblToNotDelete)), tblToNotDelete) 29 | assert.False(t, ShouldDeleteFromName(strings.ToUpper(tblToNotDelete)), tblToNotDelete) 30 | assert.False(t, ShouldDeleteFromName(tblToNotDelete), tblToNotDelete) 31 | } 32 | } 33 | { 34 | // Tables that are eligible to be dropped 35 | tablesToDrop := []string{ 36 | "transactions___ARTIE_48GJC_1723663043", 37 | fmt.Sprintf("expired_tbl_%s_suffix_%d", constants.ArtiePrefix, time.Now().Add(-1*constants.TemporaryTableTTL).Unix()), 38 | fmt.Sprintf("artie_%s_suffix_%d", constants.ArtiePrefix, time.Now().Add(-1*constants.TemporaryTableTTL).Unix()), 39 | } 40 | 41 | for _, tblToDelete := range tablesToDrop { 42 | assert.True(t, ShouldDeleteFromName(strings.ToLower(tblToDelete)), tblToDelete) 43 | assert.True(t, ShouldDeleteFromName(strings.ToUpper(tblToDelete)), tblToDelete) 44 | assert.True(t, ShouldDeleteFromName(tblToDelete), tblToDelete) 45 | } 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /lib/destination/types/types.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/artie-labs/transfer/lib/config" 7 | "github.com/artie-labs/transfer/lib/sql" 8 | ) 
9 | 10 | type DestinationTableConfigMap struct { 11 | fqNameToConfigMap map[string]*DestinationTableConfig 12 | sync.RWMutex 13 | } 14 | 15 | func (d *DestinationTableConfigMap) GetTableConfig(tableID sql.TableIdentifier) *DestinationTableConfig { 16 | d.RLock() 17 | defer d.RUnlock() 18 | 19 | tableConfig, isOk := d.fqNameToConfigMap[tableID.FullyQualifiedName()] 20 | if !isOk { 21 | return nil 22 | } 23 | 24 | return tableConfig 25 | } 26 | 27 | func (d *DestinationTableConfigMap) RemoveTable(tableID sql.TableIdentifier) { 28 | d.Lock() 29 | defer d.Unlock() 30 | 31 | delete(d.fqNameToConfigMap, tableID.FullyQualifiedName()) 32 | } 33 | 34 | func (d *DestinationTableConfigMap) AddTable(tableID sql.TableIdentifier, config *DestinationTableConfig) { 35 | d.Lock() 36 | defer d.Unlock() 37 | 38 | if d.fqNameToConfigMap == nil { 39 | d.fqNameToConfigMap = make(map[string]*DestinationTableConfig) 40 | } 41 | 42 | d.fqNameToConfigMap[tableID.FullyQualifiedName()] = config 43 | } 44 | 45 | type MergeOpts struct { 46 | AdditionalEqualityStrings []string 47 | ColumnSettings config.SharedDestinationColumnSettings 48 | RetryColBackfill bool 49 | SubQueryDedupe bool 50 | 51 | // Multi-step merge settings 52 | PrepareTemporaryTable bool 53 | UseBuildMergeQueryIntoStagingTable bool 54 | } 55 | 56 | type AdditionalSettings struct { 57 | AdditionalCopyClause string 58 | ColumnSettings config.SharedDestinationColumnSettings 59 | 60 | // These settings are used for the `Append` method. 61 | UseTempTable bool 62 | TempTableID sql.TableIdentifier 63 | } 64 | -------------------------------------------------------------------------------- /lib/destination/types/types_test.go: -------------------------------------------------------------------------------- 1 | package types_test 2 | 3 | // We are using a different pkg name because we are importing `mocks.TableIdentifier`, doing so will avoid a cyclical dependency. 
4 | 5 | import ( 6 | "math/rand" 7 | "sync" 8 | "testing" 9 | "time" 10 | 11 | "github.com/artie-labs/transfer/lib/destination/types" 12 | "github.com/artie-labs/transfer/lib/mocks" 13 | "github.com/artie-labs/transfer/lib/typing" 14 | "github.com/artie-labs/transfer/lib/typing/columns" 15 | "github.com/stretchr/testify/assert" 16 | ) 17 | 18 | func generateDestinationTableConfig() *types.DestinationTableConfig { 19 | var cols []columns.Column 20 | for _, col := range []string{"a", "b", "c", "d"} { 21 | cols = append(cols, columns.NewColumn(col, typing.String)) 22 | } 23 | 24 | tableCfg := types.NewDestinationTableConfig(cols, false) 25 | colsToDelete := make(map[string]time.Time) 26 | for _, col := range []string{"foo", "bar", "abc", "xyz"} { 27 | colsToDelete[col] = time.Now() 28 | } 29 | 30 | tableCfg.SetColumnsToDeleteForTest(colsToDelete) 31 | return tableCfg 32 | } 33 | 34 | func TestDwhToTablesConfigMap_TableConfigBasic(t *testing.T) { 35 | dwh := &types.DestinationTableConfigMap{} 36 | dwhTableConfig := generateDestinationTableConfig() 37 | fakeTableID := &mocks.FakeTableIdentifier{} 38 | dwh.AddTable(fakeTableID, dwhTableConfig) 39 | assert.Equal(t, dwhTableConfig, dwh.GetTableConfig(fakeTableID)) 40 | } 41 | 42 | // TestDwhToTablesConfigMap_Concurrency - has a bunch of concurrent go-routines that are rapidly adding and reading from the tableConfig. 
43 | func TestDwhToTablesConfigMap_Concurrency(t *testing.T) { 44 | dwh := &types.DestinationTableConfigMap{} 45 | fakeTableID := &mocks.FakeTableIdentifier{} 46 | dwhTableCfg := generateDestinationTableConfig() 47 | dwh.AddTable(fakeTableID, dwhTableCfg) 48 | var wg sync.WaitGroup 49 | // Write 50 | wg.Add(1) 51 | go func() { 52 | defer wg.Done() 53 | for i := 0; i < 1000; i++ { 54 | time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond) 55 | dwh.AddTable(fakeTableID, dwhTableCfg) 56 | } 57 | }() 58 | 59 | // Read 60 | wg.Add(1) 61 | go func() { 62 | defer wg.Done() 63 | for i := 0; i < 1000; i++ { 64 | time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond) 65 | assert.Equal(t, dwhTableCfg, dwh.GetTableConfig(fakeTableID)) 66 | } 67 | }() 68 | 69 | wg.Wait() 70 | } 71 | -------------------------------------------------------------------------------- /lib/destination/utils/load.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/artie-labs/transfer/clients/bigquery" 8 | "github.com/artie-labs/transfer/clients/databricks" 9 | "github.com/artie-labs/transfer/clients/iceberg" 10 | "github.com/artie-labs/transfer/clients/mssql" 11 | "github.com/artie-labs/transfer/clients/redshift" 12 | "github.com/artie-labs/transfer/clients/s3" 13 | "github.com/artie-labs/transfer/clients/snowflake" 14 | "github.com/artie-labs/transfer/lib/config" 15 | "github.com/artie-labs/transfer/lib/config/constants" 16 | "github.com/artie-labs/transfer/lib/db" 17 | "github.com/artie-labs/transfer/lib/destination" 18 | ) 19 | 20 | func IsOutputBaseline(cfg config.Config) bool { 21 | return cfg.Output == constants.S3 || cfg.Output == constants.Iceberg 22 | } 23 | 24 | func LoadBaseline(ctx context.Context, cfg config.Config) (destination.Baseline, error) { 25 | switch cfg.Output { 26 | case constants.S3: 27 | store, err := s3.LoadStore(ctx, cfg) 28 | if err != nil { 29 | 
return nil, fmt.Errorf("failed to load S3: %w", err) 30 | } 31 | 32 | return store, nil 33 | case constants.Iceberg: 34 | store, err := iceberg.LoadStore(ctx, cfg) 35 | if err != nil { 36 | return nil, fmt.Errorf("failed to load Iceberg: %w", err) 37 | } 38 | return store, nil 39 | } 40 | 41 | return nil, fmt.Errorf("invalid baseline output source specified: %q", cfg.Output) 42 | } 43 | 44 | func LoadDestination(ctx context.Context, cfg config.Config, store *db.Store) (destination.Destination, error) { 45 | switch cfg.Output { 46 | case constants.Snowflake: 47 | return snowflake.LoadSnowflake(ctx, cfg, store) 48 | case constants.BigQuery: 49 | return bigquery.LoadBigQuery(ctx, cfg, store) 50 | case constants.Databricks: 51 | return databricks.LoadStore(cfg) 52 | case constants.MSSQL: 53 | return mssql.LoadStore(cfg) 54 | case constants.Redshift: 55 | return redshift.LoadRedshift(ctx, cfg, store) 56 | } 57 | 58 | return nil, fmt.Errorf("invalid destination: %q", cfg.Output) 59 | } 60 | -------------------------------------------------------------------------------- /lib/environ/environment.go: -------------------------------------------------------------------------------- 1 | package environ 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "strings" 7 | ) 8 | 9 | func MustGetEnv(envVars ...string) error { 10 | var invalidParts []string 11 | for _, envVar := range envVars { 12 | if os.Getenv(envVar) == "" { 13 | invalidParts = append(invalidParts, envVar) 14 | } 15 | } 16 | 17 | if len(invalidParts) > 0 { 18 | return fmt.Errorf("required environment variables %q are not set", strings.Join(invalidParts, ", ")) 19 | } 20 | 21 | return nil 22 | } 23 | -------------------------------------------------------------------------------- /lib/environ/environment_test.go: -------------------------------------------------------------------------------- 1 | package environ 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestMustGetEnv(t 
*testing.T) { 10 | { 11 | // Single environment variable is set 12 | t.Setenv("TEST_ENV_VAR", "test") 13 | assert.NoError(t, MustGetEnv("TEST_ENV_VAR")) 14 | } 15 | { 16 | // Multiple environment variables are set 17 | t.Setenv("TEST_ENV_VAR_2", "test2") 18 | t.Setenv("TEST_ENV_VAR_3", "test3") 19 | assert.NoError(t, MustGetEnv("TEST_ENV_VAR_2", "TEST_ENV_VAR_3")) 20 | } 21 | { 22 | // Environment variable is not set 23 | assert.ErrorContains(t, MustGetEnv("NONEXISTENT_ENV_VAR"), `required environment variables "NONEXISTENT_ENV_VAR" are not set`) 24 | } 25 | { 26 | // Multiple environment variables, some not set 27 | t.Setenv("TEST_ENV_VAR_4", "test4") 28 | assert.ErrorContains(t, MustGetEnv("TEST_ENV_VAR_4", "NONEXISTENT_ENV_VAR_2", "NONEXISTENT_ENV_VAR_3"), `required environment variables "NONEXISTENT_ENV_VAR_2, NONEXISTENT_ENV_VAR_3" are not set`) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /lib/jitter/sleep.go: -------------------------------------------------------------------------------- 1 | package jitter 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | "time" 7 | ) 8 | 9 | const DefaultMaxMs = 3500 10 | 11 | // safePowerOfTwo calculates 2 ** n without panicking for values of n below 0 or above 62. 12 | func safePowerOfTwo(n int64) int64 { 13 | if n < 0 { 14 | return 0 15 | } else if n > 62 { 16 | return math.MaxInt64 // 2 ** n will overflow 17 | } 18 | return 1 << n // equal to 2 ** n 19 | } 20 | 21 | // computeJitterUpperBoundMs calculates min(maxMs, baseMs * 2 ** attempt). 22 | func computeJitterUpperBoundMs(baseMs, maxMs, attempts int64) int64 { 23 | if maxMs <= 0 { 24 | return 0 25 | } 26 | 27 | powerOfTwo := safePowerOfTwo(attempts) 28 | if powerOfTwo > math.MaxInt64/baseMs { // check for overflow 29 | return maxMs 30 | } 31 | return min(maxMs, baseMs*powerOfTwo) 32 | } 33 | 34 | // Jitter implements exponential backoff + jitter. 
35 | // See: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ 36 | // Algorithm: sleep = random_between(0, min(cap, base * 2 ** attempt)) 37 | func Jitter(baseMs, maxMs, attempts int) time.Duration { 38 | upperBoundMs := computeJitterUpperBoundMs(int64(baseMs), int64(maxMs), int64(attempts)) 39 | if upperBoundMs <= 0 { 40 | return time.Duration(0) 41 | } 42 | return time.Duration(rand.Int63n(upperBoundMs)) * time.Millisecond 43 | } 44 | -------------------------------------------------------------------------------- /lib/jitter/sleep_test.go: -------------------------------------------------------------------------------- 1 | package jitter 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestSafePowerOfTwo(t *testing.T) { 12 | assert.Equal(t, int64(0), safePowerOfTwo(-2)) 13 | assert.Equal(t, int64(0), safePowerOfTwo(-1)) 14 | assert.Equal(t, int64(1), safePowerOfTwo(0)) 15 | assert.Equal(t, int64(2), safePowerOfTwo(1)) 16 | assert.Equal(t, int64(4), safePowerOfTwo(2)) 17 | assert.Equal(t, int64(4611686018427387904), safePowerOfTwo(62)) 18 | assert.Equal(t, int64(math.MaxInt64), safePowerOfTwo(63)) 19 | assert.Equal(t, int64(math.MaxInt64), safePowerOfTwo(64)) 20 | assert.Equal(t, int64(math.MaxInt64), safePowerOfTwo(100)) 21 | } 22 | 23 | func TestComputeJitterUpperBoundMs(t *testing.T) { 24 | // A maxMs that is <= 0 returns 0. 25 | assert.Equal(t, int64(0), computeJitterUpperBoundMs(0, 0, 0)) 26 | assert.Equal(t, int64(0), computeJitterUpperBoundMs(10, 0, 0)) 27 | assert.Equal(t, int64(0), computeJitterUpperBoundMs(10, 0, 100)) 28 | assert.Equal(t, int64(0), computeJitterUpperBoundMs(10, -1, 0)) 29 | assert.Equal(t, int64(0), computeJitterUpperBoundMs(10, -1, 100)) 30 | 31 | // Increasing attempts with a baseMs of 10 and essentially no maxMs. 
// UnmarshalPayload decodes a JSON string into a generic Go value.
// An empty string is passed through as "" without error.
func UnmarshalPayload(val string) (any, error) {
	// There are edge cases for when an empty payload may happen.
	// Example: JSONB column in a table in Postgres where the table replica
	// identity is set to `default` and it was a delete event.
	if val == "" {
		return "", nil
	}

	var parsed any
	err := json.Unmarshal([]byte(val), &parsed)
	if err != nil {
		return nil, err
	}

	return parsed, nil
}
"", "", DefaultTimeout) 12 | assert.Equal(t, Plain, c.Mechanism()) 13 | } 14 | { 15 | c := NewConnection(false, false, "username", "password", DefaultTimeout) 16 | assert.Equal(t, ScramSha512, c.Mechanism()) 17 | 18 | // Username and password are set but AWS IAM is enabled 19 | c = NewConnection(true, false, "username", "password", DefaultTimeout) 20 | assert.Equal(t, AwsMskIam, c.Mechanism()) 21 | } 22 | { 23 | c := NewConnection(true, false, "", "", DefaultTimeout) 24 | assert.Equal(t, AwsMskIam, c.Mechanism()) 25 | } 26 | { 27 | // not setting timeout 28 | c := NewConnection(false, false, "", "", 0) 29 | assert.Equal(t, DefaultTimeout, c.timeout) 30 | } 31 | } 32 | 33 | func TestConnection_Dialer(t *testing.T) { 34 | ctx := t.Context() 35 | { 36 | // Plain 37 | c := NewConnection(false, false, "", "", DefaultTimeout) 38 | dialer, err := c.Dialer(ctx) 39 | assert.NoError(t, err) 40 | assert.Nil(t, dialer.TLS) 41 | assert.Nil(t, dialer.SASLMechanism) 42 | } 43 | { 44 | // SCRAM enabled with TLS 45 | c := NewConnection(false, false, "username", "password", DefaultTimeout) 46 | dialer, err := c.Dialer(ctx) 47 | assert.NoError(t, err) 48 | assert.NotNil(t, dialer.TLS) 49 | assert.NotNil(t, dialer.SASLMechanism) 50 | 51 | // w/o TLS 52 | c = NewConnection(false, true, "username", "password", DefaultTimeout) 53 | dialer, err = c.Dialer(ctx) 54 | assert.NoError(t, err) 55 | assert.Nil(t, dialer.TLS) 56 | assert.NotNil(t, dialer.SASLMechanism) 57 | } 58 | { 59 | // AWS IAM w/ TLS 60 | c := NewConnection(true, false, "", "", DefaultTimeout) 61 | dialer, err := c.Dialer(ctx) 62 | assert.NoError(t, err) 63 | assert.NotNil(t, dialer.TLS) 64 | assert.NotNil(t, dialer.SASLMechanism) 65 | 66 | // w/o TLS (still enabled because AWS doesn't support not having TLS) 67 | c = NewConnection(true, true, "", "", DefaultTimeout) 68 | dialer, err = c.Dialer(ctx) 69 | assert.NoError(t, err) 70 | assert.NotNil(t, dialer.TLS) 71 | assert.NotNil(t, dialer.SASLMechanism) 72 | } 73 | } 74 | 
// ValidPartitionTypes enumerates the partition strategies we currently accept.
var ValidPartitionTypes = []string{
	"time",
}

// TODO: We should be able to support different partition by fields in the future.
// https://cloud.google.com/bigquery/docs/partitioned-tables#partition_decorators
var ValidPartitionBy = []string{
	"daily",
}

// We need the JSON annotations here for our dashboard to import the settings correctly.

// MergePredicates names the column used as a merge predicate.
type MergePredicates struct {
	PartitionField string `yaml:"partitionField" json:"partitionField"`
}

// BigQuerySettings captures how a BigQuery table should be partitioned.
type BigQuerySettings struct {
	PartitionType  string `yaml:"partitionType" json:"partitionType"`
	PartitionField string `yaml:"partitionField" json:"partitionField"`
	PartitionBy    string `yaml:"partitionBy" json:"partitionBy"`
}

// Valid returns an error describing the first validation failure, or nil when
// the settings are fully populated with supported values.
func (b *BigQuerySettings) Valid() error {
	if b == nil {
		return fmt.Errorf("bigQuerySettings is nil")
	}

	switch {
	case b.PartitionType == "":
		return fmt.Errorf("partitionTypes cannot be empty")
	case b.PartitionField == "":
		return fmt.Errorf("partitionField cannot be empty")
	case b.PartitionBy == "":
		return fmt.Errorf("partitionBy cannot be empty")
	case !slices.Contains(ValidPartitionTypes, b.PartitionType):
		return fmt.Errorf("partitionType must be one of: %v", ValidPartitionTypes)
	case !slices.Contains(ValidPartitionBy, b.PartitionBy):
		return fmt.Errorf("partitionBy must be one of: %v", ValidPartitionBy)
	}

	return nil
}
Empty partition by 27 | settings := &BigQuerySettings{PartitionType: "time", PartitionField: "created_at"} 28 | assert.ErrorContains(t, settings.Valid(), "partitionBy cannot be empty") 29 | } 30 | { 31 | // Invalid partition type 32 | settings := &BigQuerySettings{PartitionType: "invalid", PartitionField: "created_at", PartitionBy: "daily"} 33 | assert.ErrorContains(t, settings.Valid(), "partitionType must be one of:") 34 | } 35 | { 36 | // Invalid partition by 37 | settings := &BigQuerySettings{PartitionType: "time", PartitionField: "created_at", PartitionBy: "invalid"} 38 | assert.ErrorContains(t, settings.Valid(), "partitionBy must be one of:") 39 | } 40 | { 41 | // Valid 42 | settings := &BigQuerySettings{PartitionType: "time", PartitionField: "created_at", PartitionBy: "daily"} 43 | assert.NoError(t, settings.Valid()) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /lib/logger/log.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "log/slog" 5 | "os" 6 | "time" 7 | 8 | "github.com/getsentry/sentry-go" 9 | "github.com/lmittmann/tint" 10 | "github.com/mattn/go-isatty" 11 | slogmulti "github.com/samber/slog-multi" 12 | slogsentry "github.com/samber/slog-sentry/v2" 13 | 14 | "github.com/artie-labs/transfer/lib/config" 15 | ) 16 | 17 | var handlersToTerminate []func() 18 | 19 | func NewLogger(verbose bool, sentryCfg *config.Sentry, version string) (*slog.Logger, func()) { 20 | tintLogLevel := slog.LevelInfo 21 | if verbose { 22 | tintLogLevel = slog.LevelDebug 23 | } 24 | 25 | handler := tint.NewHandler(os.Stderr, &tint.Options{ 26 | Level: tintLogLevel, 27 | NoColor: !isatty.IsTerminal(os.Stderr.Fd()), 28 | }) 29 | if sentryCfg != nil && sentryCfg.DSN != "" { 30 | if err := sentry.Init(sentry.ClientOptions{ 31 | Dsn: sentryCfg.DSN, 32 | Release: "artie-transfer@" + version, 33 | }); err != nil { 34 | slog.New(handler).Warn("Failed to enable 
// GetKeyFromMap returns obj[key], falling back to defaultValue when the map is
// empty/nil or the key is absent.
func GetKeyFromMap(obj map[string]any, key string, defaultValue any) any {
	if len(obj) == 0 {
		return defaultValue
	}

	value, found := obj[key]
	if !found {
		return defaultValue
	}

	return value
}

// GetInt32FromMap reads obj[key] and coerces it to an int32. It errors when
// the map is empty, the key is missing, or the value does not parse as a
// 32-bit integer.
func GetInt32FromMap(obj map[string]any, key string) (int32, error) {
	if len(obj) == 0 {
		return 0, fmt.Errorf("object is empty")
	}

	raw, found := obj[key]
	if !found {
		return 0, fmt.Errorf("key: %s does not exist in object", key)
	}

	// Stringify first so numeric-looking values of any underlying type parse.
	parsed, err := strconv.ParseInt(fmt.Sprint(raw), 10, 32)
	if err != nil {
		return 0, fmt.Errorf("key: %s is not type integer: %w", key, err)
	}

	return int32(parsed), nil
}
not exist in object", key) 46 | } 47 | 48 | return typing.AssertType[T](value) 49 | } 50 | -------------------------------------------------------------------------------- /lib/maputil/ordered_map.go: -------------------------------------------------------------------------------- 1 | package maputil 2 | 3 | import ( 4 | "iter" 5 | "slices" 6 | "strings" 7 | ) 8 | 9 | func removeFromSlice[T any](slice []T, s int) []T { 10 | return append(slice[:s], slice[s+1:]...) 11 | } 12 | 13 | type OrderedMap[T any] struct { 14 | keys []string 15 | // data - Important: Do not ever expose `data` out, always use Get, Add, Remove methods as it will cause corruption between `data` and `keys` 16 | data map[string]T 17 | // caseSensitive - if true - will preserve original casing, else it will lowercase everything 18 | caseSensitive bool 19 | } 20 | 21 | func NewOrderedMap[T any](caseSensitive bool) *OrderedMap[T] { 22 | return &OrderedMap[T]{ 23 | keys: []string{}, 24 | data: make(map[string]T), 25 | caseSensitive: caseSensitive, 26 | } 27 | } 28 | 29 | func (o *OrderedMap[T]) Remove(key string) (removed bool) { 30 | if !o.caseSensitive { 31 | key = strings.ToLower(key) 32 | } 33 | 34 | if index := slices.Index(o.keys, key); index >= 0 { 35 | delete(o.data, key) 36 | o.keys = removeFromSlice(o.keys, index) 37 | return true 38 | } 39 | 40 | return false 41 | } 42 | 43 | func (o *OrderedMap[T]) Add(key string, value T) { 44 | if !o.caseSensitive { 45 | key = strings.ToLower(key) 46 | } 47 | 48 | // Does the key already exist? 
49 | // Only add it to `keys` if it doesn't exist 50 | if _, isOk := o.Get(key); !isOk { 51 | o.keys = append(o.keys, key) 52 | } 53 | 54 | o.data[key] = value 55 | } 56 | 57 | func (o *OrderedMap[T]) Get(key string) (T, bool) { 58 | if !o.caseSensitive { 59 | key = strings.ToLower(key) 60 | } 61 | 62 | val, ok := o.data[key] 63 | return val, ok 64 | } 65 | 66 | func (o *OrderedMap[T]) NotEmpty() bool { 67 | return len(o.data) > 0 68 | } 69 | 70 | func (o *OrderedMap[T]) Keys() []string { 71 | return slices.Clone(o.keys) 72 | } 73 | 74 | // All returns an in-order iterator over key-value pairs. 75 | func (o *OrderedMap[T]) All() iter.Seq2[string, T] { 76 | return func(yield func(string, T) bool) { 77 | for _, key := range o.keys { 78 | if value, ok := o.Get(key); ok { 79 | if !yield(key, value) { 80 | break 81 | } 82 | } 83 | } 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /lib/mocks/generate.go: -------------------------------------------------------------------------------- 1 | package mocks 2 | 3 | //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate 4 | //counterfeiter:generate -o=db.store.mock.go ../db Store 5 | //counterfeiter:generate -o=kafkalib.consumer.mock.go ../kafkalib Consumer 6 | 7 | //counterfeiter:generate -o=destination.mock.go ../destination Destination 8 | //counterfeiter:generate -o=baseline.mock.go ../destination Baseline 9 | //counterfeiter:generate -o=tableid.mock.go ../sql TableIdentifier 10 | 11 | //counterfeiter:generate -o=event.mock.go ../cdc Event 12 | -------------------------------------------------------------------------------- /lib/numbers/numbers.go: -------------------------------------------------------------------------------- 1 | package numbers 2 | 3 | import "github.com/cockroachdb/apd/v3" 4 | 5 | // BetweenEq - Looks something like this. 
// BetweenEq reports whether number lies in the inclusive range:
// start <= number <= end.
func BetweenEq[T int | int32 | int64](start, end, number T) bool {
	if number < start {
		return false
	}
	return number <= end
}
// sumApproxSizes totals GetApproxSize over every element of a slice.
func sumApproxSizes[T any](values []T) int {
	total := 0
	for _, value := range values {
		total += GetApproxSize(value)
	}
	return total
}

// GetApproxSize estimates the in-memory footprint of value in bytes.
// We chose not to use unsafe.SizeOf or reflect.Type.Size (both are akin) because they do not do recursive traversal.
// We also chose not to use gob.NewEncoder because it does not work for all data types and had a huge computational overhead.
// Another plus here is that this will not error out.
func GetApproxSize(value any) int {
	if value == nil {
		return 0
	}

	switch v := value.(type) {
	case string:
		return len(v)
	case []byte:
		return len(v)
	case bool, int8, uint8:
		return 1
	case int16, uint16:
		return 2
	case int32, uint32, float32:
		return 4
	case int, int64, uint, uint64, uintptr, float64, complex64:
		// int, uint are platform dependent - but to be safe, let's over approximate and assume 64-bit system
		return 8
	case complex128:
		return 16
	case map[string]any:
		total := 0
		for _, nested := range v {
			total += GetApproxSize(nested)
		}
		return total
	case []map[string]any:
		return sumApproxSizes(v)
	case []string:
		return sumApproxSizes(v)
	case []any:
		return sumApproxSizes(v)
	case [][]byte:
		return sumApproxSizes(v)
	}

	// Fallback: approximate any other type by its printed representation.
	return len([]byte(fmt.Sprint(value)))
}
Integer semper orci justo, nec rhoncus libero convallis sed.", 56 | "lorem_ipsum2": "Fusce vitae elementum tortor. Vestibulum consectetur ante id nibh ullamcorper, quis sodales turpis tempor. Duis pellentesque suscipit nibh porta posuere. In libero massa, efficitur at ultricies sit amet, vulputate ac ante. In euismod erat eget nulla blandit pretium. Ut tempor ante vel congue venenatis. Vestibulum at metus nec nibh iaculis consequat suscipit ac leo. Maecenas vitae rutrum nulla, quis ultrices justo. Aliquam ipsum ex, luctus ac diam eget, tempor tempor risus.", 57 | } 58 | } 59 | 60 | for n := 0; n < b.N; n++ { 61 | GetApproxSize(rowsData) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /lib/size/size_test.go: -------------------------------------------------------------------------------- 1 | package size 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestGetApproxSize(t *testing.T) { 11 | rowsData := make(map[string]any) // pk -> { col -> val } 12 | for i := 0; i < 500; i++ { 13 | rowsData[fmt.Sprintf("key-%v", i)] = map[string]any{ 14 | "id": fmt.Sprintf("key-%v", i), 15 | "artie": "transfer", 16 | "dusty": "the mini aussie", 17 | "next_puppy": true, 18 | "foo": []any{"bar", "baz", "qux"}, 19 | "team": []string{"charlie", "robin", "jacqueline"}, 20 | "arrays": []string{"foo", "bar", "baz"}, 21 | "nested": map[string]any{ 22 | "foo": "bar", 23 | "abc": "xyz", 24 | }, 25 | "array_of_maps": []map[string]any{ 26 | { 27 | "foo": "bar", 28 | }, 29 | { 30 | "abc": "xyz", 31 | }, 32 | }, 33 | } 34 | } 35 | 36 | size := GetApproxSize(rowsData) 37 | 38 | // Check if size is non-zero and seems plausible 39 | assert.NotZero(t, size, "Size should not be zero") 40 | assert.Greater(t, size, 1000, "Size should be reasonably large for the given data structure") 41 | } 42 | -------------------------------------------------------------------------------- /lib/sql/rows.go: 
// RowsToObjects drains rows into a slice of column-name -> value maps,
// one map per row. The rows are always closed before returning.
func RowsToObjects(rows *sql.Rows) ([]map[string]any, error) {
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, err
	}

	var objects []map[string]any
	for rows.Next() {
		// Scan requires pointers, so build a parallel slice of addresses.
		values := make([]any, len(columns))
		scanArgs := make([]any, len(columns))
		for i := range values {
			scanArgs[i] = &values[i]
		}

		if err = rows.Scan(scanArgs...); err != nil {
			return nil, err
		}

		object := make(map[string]any, len(columns))
		for i, column := range columns {
			object[column] = values[i]
		}
		objects = append(objects, object)
	}

	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate over rows: %w", err)
	}

	return objects, nil
}
// ParseDataTypeDefinition parses a column type definition returning the type and parameters.
// "TEXT" -> "TEXT", {}
// "VARCHAR(1234)" -> "VARCHAR", {"1234"}
// "NUMERIC(5, 1)" -> "NUMERIC", {"5", "1"}
func ParseDataTypeDefinition(value string) (string, []string, error) {
	value = strings.TrimSpace(value)

	openIdx := strings.Index(value, "(")
	if openIdx <= 0 {
		// No parameter list (or the definition starts with one) - return as-is.
		return value, nil, nil
	}

	if !strings.HasSuffix(value, ")") {
		return "", nil, fmt.Errorf("missing closing parenthesis")
	}

	parameters := strings.Split(value[openIdx+1:len(value)-1], ",")
	for i := range parameters {
		parameters[i] = strings.TrimSpace(parameters[i])
	}
	return strings.TrimSpace(value[:openIdx]), parameters, nil
}
// EscapeBackslashes doubles every backslash in value so it survives a later
// escaping pass (e.g. inside a SQL literal).
func EscapeBackslashes(value string) string {
	return strings.ReplaceAll(value, `\`, `\\`)
}

// Empty reports whether any of the provided strings is the empty string.
func Empty(vals ...string) bool {
	for _, val := range vals {
		if len(val) == 0 {
			return true
		}
	}
	return false
}

// EscapeSpaces replaces every space in col with "__" and reports whether any
// replacement was needed.
func EscapeSpaces(col string) (escaped bool, newString string) {
	const subStr = " "
	escaped = strings.Contains(col, subStr)
	newString = strings.ReplaceAll(col, subStr, "__")
	return escaped, newString
}

// stringWithCharset builds a random string of the given length whose
// characters are drawn uniformly from charset.
func stringWithCharset(length int, charset string) string {
	var builder strings.Builder
	builder.Grow(length)
	for i := 0; i < length; i++ {
		builder.WriteByte(charset[rand.Intn(len(charset))])
	}
	return builder.String()
}

// Random returns a random alphanumeric string of the given length.
func Random(length int) string {
	return stringWithCharset(length, "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
}
// Client is the interface a metrics exporter must implement.
type Client interface {
	// Incr records an increment for the named counter.
	Incr(name string, tags map[string]string)
	// Count adds value to the named counter.
	Count(name string, value int64, tags map[string]string)
	// Gauge records an instantaneous value for the named gauge.
	Gauge(name string, value float64, tags map[string]string)
	// GaugeWithSample records a gauge value at the given sample rate.
	GaugeWithSample(name string, value float64, tags map[string]string, sample float64)
	// Timing records a duration for the named metric.
	Timing(name string, value time.Duration, tags map[string]string)
}
float64(1)) 14 | assert.Equal(t, getSampleRate(0.33), 0.33) 15 | assert.Equal(t, getSampleRate(0), float64(DefaultSampleRate)) 16 | assert.Equal(t, getSampleRate(-0.55), float64(DefaultSampleRate)) 17 | } 18 | 19 | func TestGetTags(t *testing.T) { 20 | assert.Equal(t, getTags(nil), []string{}) 21 | assert.Equal(t, getTags([]string{}), []string{}) 22 | assert.Equal(t, getTags([]any{"env:bar", "a:b"}), []string{"env:bar", "a:b"}) 23 | } 24 | 25 | func TestNewDatadogClient(t *testing.T) { 26 | client, err := NewDatadogClient(map[string]any{ 27 | Tags: []string{ 28 | "env:production", 29 | }, 30 | Namespace: "dusty.", 31 | Sampling: 0.255, 32 | }) 33 | 34 | assert.NoError(t, err) 35 | mtr, isOk := client.(*statsClient) 36 | assert.True(t, isOk) 37 | assert.Equal(t, 0.255, mtr.rate) 38 | 39 | clientValue := reflect.ValueOf(mtr.client).Elem() 40 | assert.Equal(t, "dusty.", clientValue.FieldByName("namespace").String()) 41 | tagsField := clientValue.FieldByName("tags") 42 | assert.Equal(t, 1, tagsField.Len()) 43 | assert.Equal(t, "env:production", tagsField.Index(0).String()) 44 | } 45 | -------------------------------------------------------------------------------- /lib/telemetry/metrics/datadog/tags.go: -------------------------------------------------------------------------------- 1 | package datadog 2 | 3 | import ( 4 | "fmt" 5 | 6 | "gopkg.in/yaml.v3" 7 | ) 8 | 9 | func getTags(tags any) []string { 10 | // Yaml parses lists as a sequence, so we'll unpack it again with the same library. 
// toDatadogTags flattens a key/value map into Datadog's "key:value" tag format.
// NOTE: map iteration order is not deterministic, so the ordering of the
// returned slice may vary between calls.
func toDatadogTags(tags map[string]string) []string {
	var ddTags []string
	for name, value := range tags {
		ddTags = append(ddTags, fmt.Sprintf("%s:%s", name, value))
	}

	return ddTags
}
exporterKindValid(kind constants.ExporterKind) bool { 15 | var valid bool 16 | for _, supportedExporterKind := range supportedExporterKinds { 17 | valid = kind == supportedExporterKind 18 | if valid { 19 | break 20 | } 21 | } 22 | 23 | return valid 24 | } 25 | 26 | func LoadExporter(cfg config.Config) base.Client { 27 | kind := cfg.Telemetry.Metrics.Provider 28 | ddSettings := cfg.Telemetry.Metrics.Settings 29 | if !exporterKindValid(kind) { 30 | slog.Info("Invalid or no exporter kind passed in, skipping...", slog.Any("exporterKind", kind)) 31 | } 32 | 33 | switch kind { 34 | case constants.Datadog: 35 | statsClient, exportErr := datadog.NewDatadogClient(ddSettings) 36 | if exportErr != nil { 37 | slog.Error("Metrics client error", slog.Any("err", exportErr), slog.Any("provider", kind)) 38 | } else { 39 | slog.Info("Metrics client loaded", slog.Any("provider", kind)) 40 | return statsClient 41 | } 42 | } 43 | 44 | return NullMetricsProvider{} 45 | } 46 | -------------------------------------------------------------------------------- /lib/telemetry/metrics/stats_test.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/artie-labs/transfer/lib/config" 8 | "github.com/artie-labs/transfer/lib/config/constants" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestExporterKindValid(t *testing.T) { 13 | exporterKindToResultsMap := map[constants.ExporterKind]bool{ 14 | constants.Datadog: true, 15 | constants.ExporterKind("daaaa"): false, 16 | constants.ExporterKind("daaaa231321"): false, 17 | constants.ExporterKind("honeycomb.io"): false, 18 | } 19 | 20 | for exporterKind, expectedResults := range exporterKindToResultsMap { 21 | assert.Equal(t, expectedResults, exporterKindValid(exporterKind), 22 | fmt.Sprintf("kind: %v should have been %v", exporterKind, expectedResults)) 23 | } 24 | } 25 | 26 | func TestLoadExporter(t *testing.T) { 27 | // 
Datadog should not be a NullMetricsProvider 28 | exporterKindToResultMap := map[constants.ExporterKind]bool{ 29 | constants.Datadog: false, 30 | constants.ExporterKind("invalid"): true, 31 | } 32 | 33 | for kind, result := range exporterKindToResultMap { 34 | // Wipe and create a new ctx per run 35 | cfg := config.Config{ 36 | Telemetry: struct { 37 | Metrics struct { 38 | Provider constants.ExporterKind `yaml:"provider"` 39 | Settings map[string]any `yaml:"settings,omitempty"` 40 | } 41 | }{ 42 | Metrics: struct { 43 | Provider constants.ExporterKind `yaml:"provider"` 44 | Settings map[string]any `yaml:"settings,omitempty"` 45 | }{ 46 | Provider: kind, 47 | Settings: map[string]any{ 48 | "url": "localhost:8125", 49 | }, 50 | }, 51 | }, 52 | } 53 | 54 | client := LoadExporter(cfg) 55 | _, isOk := client.(NullMetricsProvider) 56 | assert.Equal(t, result, isOk) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /lib/typing/README.md: -------------------------------------------------------------------------------- 1 | # Typing 2 | 3 | Typing is a core utility within Transfer, as such - we have created a lot of utilities and strayed away from using other client libraries as much as possible. 4 | 5 | Once our schema detection detects a change, we will need to take the first not-null value from the CDC stream and infer the type. 6 | This is where our library comes in: 7 | * We will figure out the type (we support a variety of date time formats) 8 | * Based on the type, we will then call DWH and create a column with the inferred type. 9 | * This is necessary as there are transactional DBs that are schemaless (MongoDB, Bigtable, DynamoDB to name a few...) 10 | 11 | ## Performance 12 | 13 | As part of this being a core utility within Artie, we decided to write our own Typing library.
// AssertType casts val to T, returning a descriptive error when the dynamic
// type does not match.
func AssertType[T any](val any) (T, error) {
	if typedVal, ok := val.(T); ok {
		return typedVal, nil
	}

	var zero T
	return zero, fmt.Errorf("expected type %T, got %T", zero, val)
}

// AssertTypeOptional - will return zero if the value is nil, otherwise it will assert the type
func AssertTypeOptional[T any](val any) (T, error) {
	if val != nil {
		return AssertType[T](val)
	}

	var zero T
	return zero, nil
}
String to boolean 29 | _, err := AssertType[bool]("true") 30 | assert.ErrorContains(t, err, "expected type bool, got string") 31 | } 32 | } 33 | 34 | func TestAssertTypeOptional(t *testing.T) { 35 | { 36 | // String to string 37 | val, err := AssertTypeOptional[string]("hello") 38 | assert.NoError(t, err) 39 | assert.Equal(t, "hello", val) 40 | } 41 | { 42 | // Nil to string 43 | val, err := AssertTypeOptional[string](nil) 44 | assert.NoError(t, err) 45 | assert.Equal(t, "", val) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /lib/typing/columns/diff.go: -------------------------------------------------------------------------------- 1 | package columns 2 | 3 | import ( 4 | "slices" 5 | "strings" 6 | 7 | "github.com/artie-labs/transfer/lib/config/constants" 8 | "github.com/artie-labs/transfer/lib/maputil" 9 | ) 10 | 11 | func shouldSkipColumn(colName string, columnsToKeep []string) bool { 12 | if slices.Contains(columnsToKeep, colName) { 13 | return false 14 | } 15 | 16 | if colName == constants.OnlySetDeleteColumnMarker { 17 | // We never want to create this column in the destination table 18 | return true 19 | } 20 | 21 | return strings.Contains(colName, constants.ArtiePrefix) 22 | } 23 | 24 | type DiffResults struct { 25 | SourceColumnsMissing []Column 26 | TargetColumnsMissing []Column 27 | } 28 | 29 | func Diff(sourceColumns []Column, targetColumns []Column) DiffResults { 30 | src := buildColumnsMap(sourceColumns) 31 | targ := buildColumnsMap(targetColumns) 32 | 33 | for _, colName := range src.Keys() { 34 | if _, isOk := targ.Get(colName); isOk { 35 | targ.Remove(colName) 36 | src.Remove(colName) 37 | } 38 | } 39 | 40 | var targetColumnsMissing []Column 41 | for _, col := range src.All() { 42 | targetColumnsMissing = append(targetColumnsMissing, col) 43 | } 44 | 45 | var sourceColumnsMissing []Column 46 | for _, col := range targ.All() { 47 | sourceColumnsMissing = append(sourceColumnsMissing, col) 48 | } 
// Int64Converter converts strings and signed integer values to int64.
type Int64Converter struct{}

// Convert returns value as an int64. Strings are parsed base-10; int, int16 and
// int32 are widened; int64 passes through. Any other type yields an error.
func (Int64Converter) Convert(value any) (int64, error) {
	switch v := value.(type) {
	case int64:
		return v, nil
	case int:
		return int64(v), nil
	case int32:
		return int64(v), nil
	case int16:
		return int64(v), nil
	case string:
		parsed, err := strconv.ParseInt(v, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("failed to parse string to int64: %w", err)
		}
		return parsed, nil
	default:
		return 0, fmt.Errorf("expected string/int/int16/int32/int64 got %T with value: %v", value, value)
	}
}
// Float64ToString formats a float64 with the minimum number of digits needed
// to round-trip the value exactly.
func Float64ToString(value float64) string {
	return strconv.FormatFloat(value, 'f', -1, 64)
}

// Float32ToString formats a float32 with the minimum number of digits needed
// to round-trip the value exactly (32-bit rounding).
func Float32ToString(value float32) string {
	return strconv.FormatFloat(float64(value), 'f', -1, 32)
}

// BooleanToBit maps a boolean to its bit representation: true -> 1, false -> 0.
func BooleanToBit(val bool) int {
	// Early return instead of if/else - idiomatic Go keeps the happy path flat.
	if val {
		return 1
	}
	return 0
}
-------------------------------------------------------------------------------- 1 | package decimal 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestDecimal_IsNumeric(t *testing.T) { 10 | { 11 | // Valid numeric with small scale 12 | assert.True(t, NewDetails(10, 2).isNumeric(), "should be valid numeric with small scale") 13 | } 14 | { 15 | // Valid numeric with max scale 16 | assert.True(t, NewDetails(38, 9).isNumeric(), "should be valid numeric with max scale") 17 | } 18 | { 19 | // Invalid - precision not specified 20 | assert.False(t, NewDetails(PrecisionNotSpecified, 2).isNumeric(), "should be invalid when precision is not specified") 21 | } 22 | { 23 | // Invalid - scale too large 24 | assert.False(t, NewDetails(10, 10).isNumeric(), "should be invalid when scale is too large") 25 | } 26 | { 27 | // Valid - precision equals scale 28 | assert.True(t, NewDetails(2, 2).isNumeric(), "should be valid when precision equals scale") 29 | } 30 | { 31 | // Invalid - precision too large 32 | assert.False(t, NewDetails(40, 2).isNumeric(), "should be invalid when precision is too large") 33 | } 34 | { 35 | // Valid - minimum valid case 36 | assert.True(t, NewDetails(1, 0).isNumeric(), "should be valid with minimum precision and scale") 37 | } 38 | { 39 | // Valid - scale equals precision 40 | assert.True(t, NewDetails(5, 5).isNumeric(), "should be valid when scale equals precision") 41 | } 42 | } 43 | 44 | func TestDecimal_IsBigNumeric(t *testing.T) { 45 | { 46 | // Valid bignumeric with small scale 47 | assert.True(t, NewDetails(40, 2).isBigNumeric(), "should be valid bignumeric with small scale") 48 | } 49 | { 50 | // Valid bignumeric with max scale 51 | assert.True(t, NewDetails(76, 38).isBigNumeric(), "should be valid bignumeric with max scale") 52 | } 53 | { 54 | // Invalid - precision not specified 55 | assert.False(t, NewDetails(PrecisionNotSpecified, 2).isBigNumeric(), "should be invalid when precision is not 
specified") 56 | } 57 | { 58 | // Invalid - scale too large 59 | assert.False(t, NewDetails(77, 39).isBigNumeric(), "should be invalid when scale is too large") 60 | } 61 | { 62 | // Valid - numeric precision 63 | assert.True(t, NewDetails(38, 2).isBigNumeric(), "should be valid with numeric precision") 64 | } 65 | { 66 | // Valid - precision equals scale 67 | assert.True(t, NewDetails(40, 2).isBigNumeric(), "should be valid when precision equals scale") 68 | } 69 | { 70 | // Valid - scale equals max 71 | assert.True(t, NewDetails(40, 38).isBigNumeric(), "should be valid when scale equals max allowed") 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /lib/typing/decimal/decimal.go: -------------------------------------------------------------------------------- 1 | package decimal 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | "github.com/cockroachdb/apd/v3" 7 | ) 8 | 9 | const PrecisionNotSpecified int32 = -1 10 | 11 | // Decimal is Artie's wrapper around [*apd.Decimal] which can store large numbers w/ no precision loss. 12 | type Decimal struct { 13 | precision int32 14 | value *apd.Decimal 15 | } 16 | 17 | func (d Decimal) MarshalJSON() ([]byte, error) { 18 | return json.Marshal(d.String()) 19 | } 20 | 21 | func NewDecimalWithPrecision(value *apd.Decimal, precision int32) *Decimal { 22 | return &Decimal{ 23 | precision: precision, 24 | value: value, 25 | } 26 | } 27 | 28 | func NewDecimal(value *apd.Decimal) *Decimal { 29 | return NewDecimalWithPrecision(value, PrecisionNotSpecified) 30 | } 31 | 32 | func (d *Decimal) Value() *apd.Decimal { 33 | return d.value 34 | } 35 | 36 | // String() is used to override fmt.Sprint(val), where val type is *decimal.Decimal 37 | // This is particularly useful for Snowflake because we're writing all the values as STRINGS into TSV format. 38 | // This function guarantees backwards compatibility. 
// UnsupportedDataTypeError signals that a data type cannot be handled.
type UnsupportedDataTypeError struct {
	message string
}

// NewUnsupportedDataTypeError builds an [UnsupportedDataTypeError] carrying message.
func NewUnsupportedDataTypeError(message string) UnsupportedDataTypeError {
	return UnsupportedDataTypeError{message: message}
}

// Error implements the error interface.
func (u UnsupportedDataTypeError) Error() string {
	return u.message
}

// IsUnsupportedDataTypeError reports whether err (or any error it wraps) is an
// [UnsupportedDataTypeError].
func IsUnsupportedDataTypeError(err error) bool {
	var target UnsupportedDataTypeError
	return errors.As(err, &target)
}
[]string{ 11 | // RFC 3339 12 | time.RFC3339, 13 | time.RFC3339Nano, 14 | RFC3339Millisecond, 15 | RFC3339Microsecond, 16 | RFC3339Nanosecond, 17 | // Others 18 | "2006-01-02T15:04:05.999999999-07:00", 19 | "2006-01-02T15:04:05.000-07:00", 20 | time.Layout, 21 | time.ANSIC, 22 | time.UnixDate, 23 | time.RubyDate, 24 | time.RFC822, 25 | time.RFC822Z, 26 | time.RFC850, 27 | time.RFC1123, 28 | time.RFC1123Z, 29 | } 30 | 31 | var supportedDateFormats = []string{ 32 | time.DateOnly, 33 | } 34 | 35 | var SupportedTimeFormats = []string{ 36 | PostgresTimeFormat, 37 | PostgresTimeFormatNoTZ, 38 | } 39 | 40 | const TimezoneOffsetFormat = "Z07:00" 41 | 42 | // RFC3339 variants 43 | const ( 44 | // Max precision up to microseconds (will trim away the trailing zeros) 45 | RFC3339MicroTZNoTZ = "2006-01-02T15:04:05.999999" 46 | RFC3339MicroTZ = RFC3339MicroTZNoTZ + TimezoneOffsetFormat 47 | 48 | RFC3339NoTZ = "2006-01-02T15:04:05.999999999" 49 | 50 | RFC3339MillisecondNoTZ = "2006-01-02T15:04:05.000" 51 | RFC3339Millisecond = RFC3339MillisecondNoTZ + TimezoneOffsetFormat 52 | 53 | RFC3339MicrosecondNoTZ = "2006-01-02T15:04:05.000000" 54 | RFC3339Microsecond = RFC3339MicrosecondNoTZ + TimezoneOffsetFormat 55 | 56 | RFC3339NanosecondNoTZ = "2006-01-02T15:04:05.000000000" 57 | RFC3339Nanosecond = RFC3339NanosecondNoTZ + TimezoneOffsetFormat 58 | ) 59 | -------------------------------------------------------------------------------- /lib/typing/ext/variables_test.go: -------------------------------------------------------------------------------- 1 | package ext 2 | 3 | import "testing" 4 | 5 | func TestSupportedDateTimeLayoutsUniqueness(t *testing.T) { 6 | layouts := make(map[string]bool) 7 | for _, layout := range supportedDateTimeLayouts { 8 | if _, ok := layouts[layout]; ok { 9 | t.Errorf("layout %q is duplicated", layout) 10 | } 11 | layouts[layout] = true 12 | } 13 | } 14 | -------------------------------------------------------------------------------- 
/lib/typing/mongo/bson_bench_test.go: -------------------------------------------------------------------------------- 1 | package mongo 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func BenchmarkJSONEToMap(b *testing.B) { 10 | bsonData := []byte(` 11 | { 12 | "_id": { 13 | "$numberLong": "10004" 14 | }, 15 | "order_date": { 16 | "$date": 1456012800000 17 | }, 18 | "purchaser_id": { 19 | "$numberLong": "1003" 20 | }, 21 | "number_int": 30, 22 | "quantity": 1, 23 | "product_id": { 24 | "$numberLong": "107" 25 | }, 26 | "profilePic": { 27 | "$binary": "123456ABCDEF", 28 | "$type": "00" 29 | }, 30 | "compiledFunction": { 31 | "$binary": "cHJpbnQoJ0hlbGxvIFdvcmxkJyk=", 32 | "$type": "01" 33 | }, 34 | "unique_id": { 35 | "$binary": "hW5W/8uwQR6FWpiwi4dRQA==", 36 | "$type": "04" 37 | }, 38 | "fileChecksum": { 39 | "$binary": "1B2M2Y8AsgTpgAmY7PhCfg==", 40 | "$type": "05" 41 | }, 42 | "secureData": { 43 | "$binary": "YWJjZGVmZ2hpamtsbW5vcA==", 44 | "$type": "06" 45 | }, 46 | "full_name": "Robin Tang", 47 | "test_bool_false": false, 48 | "test_bool_true": true, 49 | "object_id": {"$oid": "63793b4014f7f28f570c524e"}, 50 | "test_decimal": {"$numberDecimal": "13.37"}, 51 | "test_decimal_2": 13.37, 52 | "test_int": 1337, 53 | "test_foo": "bar", 54 | "test_null": null, 55 | "test_list": [1.0,2.0,3.0,4.0,"hello"], 56 | "test_nested_object": { 57 | "a": { 58 | "b": { 59 | "c": "hello" 60 | } 61 | } 62 | }, 63 | "test_timestamp": { 64 | "$timestamp": { "t": 1678929517, "i": 1 } 65 | }, 66 | "test_nan": NaN, 67 | "test_nan_string": "NaN", 68 | "test_nan_string33": "NaNaNaNa", 69 | "test_infinity": Infinity, 70 | "test_infinity_string": "Infinity", 71 | "test_infinity_string1": "Infinity123", 72 | "test_negative_infinity": -Infinity, 73 | "test_negative_infinity_string": "-Infinity", 74 | "test_negative_infinity_string1": "-Infinity123", 75 | "maxValue": {"$maxKey": 1}, 76 | "minValue": {"$minKey": 1}, 77 | "calcDiscount": {"$code": 
"function() {return 0.10;}"}, 78 | "emailPattern": {"$regex": "@example\\.com$","$options": ""} 79 | }`) 80 | 81 | for i := 0; i < b.N; i++ { 82 | _, err := JSONEToMap(bsonData) 83 | assert.NoError(b, err) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /lib/typing/numeric.go: -------------------------------------------------------------------------------- 1 | package typing 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | 8 | "github.com/artie-labs/transfer/lib/typing/decimal" 9 | ) 10 | 11 | func ParseNumeric(parts []string) (KindDetails, error) { 12 | if len(parts) == 0 || len(parts) > 2 { 13 | return Invalid, fmt.Errorf("invalid number of parts: %d", len(parts)) 14 | } 15 | 16 | var parsedNumbers []int32 17 | for _, part := range parts { 18 | parsedNumber, err := strconv.ParseInt(strings.TrimSpace(part), 10, 32) 19 | if err != nil { 20 | return Invalid, fmt.Errorf("failed to parse number: %w", err) 21 | } 22 | 23 | parsedNumbers = append(parsedNumbers, int32(parsedNumber)) 24 | } 25 | 26 | // If scale is 0 or not specified, then number is an int. 27 | if len(parsedNumbers) == 1 || parsedNumbers[1] == 0 { 28 | return NewDecimalDetailsFromTemplate( 29 | EDecimal, 30 | decimal.NewDetails(parsedNumbers[0], 0), 31 | ), nil 32 | } 33 | 34 | return NewDecimalDetailsFromTemplate( 35 | EDecimal, 36 | decimal.NewDetails(parsedNumbers[0], parsedNumbers[1]), 37 | ), nil 38 | } 39 | -------------------------------------------------------------------------------- /lib/typing/parse.go: -------------------------------------------------------------------------------- 1 | package typing 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "time" 7 | 8 | "github.com/artie-labs/transfer/lib/typing/decimal" 9 | ) 10 | 11 | // MustParseValue - panics if the value cannot be parsed. This is used only for tests. 
12 | func MustParseValue(key string, optionalSchema map[string]KindDetails, val any) KindDetails { 13 | kindDetail, err := ParseValue(key, optionalSchema, val) 14 | if err != nil { 15 | panic(err) 16 | } 17 | 18 | return kindDetail 19 | } 20 | 21 | func ParseValue(key string, optionalSchema map[string]KindDetails, val any) (KindDetails, error) { 22 | if kindDetail, isOk := optionalSchema[key]; isOk { 23 | return kindDetail, nil 24 | } 25 | 26 | switch convertedVal := val.(type) { 27 | case nil: 28 | return Invalid, nil 29 | case uint, int, uint8, uint16, uint32, uint64, int8, int16, int32, int64: 30 | return Integer, nil 31 | case float32, float64: 32 | // Integers will be parsed as Floats if they come from JSON 33 | // This is a limitation with JSON - https://github.com/golang/go/issues/56719 34 | // UNLESS Transfer is provided with a schema object, and we deliberately typecast the value to an integer 35 | // before calling ParseValue(). 36 | return Float, nil 37 | case bool: 38 | return Boolean, nil 39 | case string: 40 | if IsJSON(convertedVal) { 41 | return Struct, nil 42 | } 43 | 44 | return String, nil 45 | case *decimal.Decimal: 46 | extendedDetails := convertedVal.Details() 47 | return KindDetails{ 48 | Kind: EDecimal.Kind, 49 | ExtendedDecimalDetails: &extendedDetails, 50 | }, nil 51 | case time.Time: 52 | return TimestampTZ, nil 53 | default: 54 | // Check if the val is one of our custom-types 55 | if reflect.TypeOf(val).Kind() == reflect.Slice { 56 | return Array, nil 57 | } else if reflect.TypeOf(val).Kind() == reflect.Map { 58 | return Struct, nil 59 | } 60 | } 61 | 62 | return Invalid, fmt.Errorf("unknown type: %T, value: %v", val, val) 63 | } 64 | -------------------------------------------------------------------------------- /lib/typing/ptr.go: -------------------------------------------------------------------------------- 1 | package typing 2 | 3 | func ToPtr[T any](v T) *T { 4 | return &v 5 | } 6 | 7 | func DefaultValueFromPtr[T any](value *T, 
defaultValue T) T { 8 | if value == nil { 9 | return defaultValue 10 | } 11 | 12 | return *value 13 | } 14 | -------------------------------------------------------------------------------- /lib/typing/ptr_test.go: -------------------------------------------------------------------------------- 1 | package typing 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestDefaultValueFromPtr(t *testing.T) { 10 | { 11 | // ptr is not set 12 | assert.Equal(t, int32(5), DefaultValueFromPtr[int32](nil, int32(5))) 13 | } 14 | { 15 | // ptr is set 16 | assert.Equal(t, int32(10), DefaultValueFromPtr[int32](ToPtr(int32(10)), int32(5))) 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /lib/typing/typing_bench_test.go: -------------------------------------------------------------------------------- 1 | package typing 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func BenchmarkLargeMapLengthQuery(b *testing.B) { 10 | retMap := make(map[string]any) 11 | for i := 0; i < 15000; i++ { 12 | retMap[fmt.Sprintf("key-%v", i)] = true 13 | } 14 | 15 | for n := 0; n < b.N; n++ { 16 | _ = uint(len(retMap)) 17 | } 18 | } 19 | 20 | func BenchmarkLargeMapLengthQuery_WithMassiveValues(b *testing.B) { 21 | retMap := make(map[string]any) 22 | for i := 0; i < 15000; i++ { 23 | retMap[fmt.Sprintf("key-%v", i)] = map[string]any{ 24 | "foo": "bar", 25 | "hello": "world", 26 | "true": true, 27 | "false": false, 28 | "array": []string{"abc", "def"}, 29 | } 30 | } 31 | 32 | for n := 0; n < b.N; n++ { 33 | _ = uint(len(retMap)) 34 | } 35 | } 36 | 37 | func BenchmarkParseValueIntegerArtie(b *testing.B) { 38 | for n := 0; n < b.N; n++ { 39 | MustParseValue("", nil, 45456312) 40 | } 41 | } 42 | 43 | func BenchmarkParseValueIntegerGo(b *testing.B) { 44 | for n := 0; n < b.N; n++ { 45 | reflect.TypeOf(45456312).Kind() 46 | } 47 | } 48 | 49 | func BenchmarkParseValueBooleanArtie(b *testing.B) { 
50 | for n := 0; n < b.N; n++ { 51 | MustParseValue("", nil, true) 52 | } 53 | } 54 | 55 | func BenchmarkParseValueBooleanGo(b *testing.B) { 56 | for n := 0; n < b.N; n++ { 57 | reflect.TypeOf(true).Kind() 58 | } 59 | } 60 | 61 | func BenchmarkParseValueFloatArtie(b *testing.B) { 62 | for n := 0; n < b.N; n++ { 63 | MustParseValue("", nil, 7.44) 64 | } 65 | } 66 | 67 | func BenchmarkParseValueFloatGo(b *testing.B) { 68 | for n := 0; n < b.N; n++ { 69 | reflect.TypeOf(7.44).Kind() 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /lib/typing/typing_test.go: -------------------------------------------------------------------------------- 1 | package typing 2 | 3 | import ( 4 | "encoding/json" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func Test_IsJSON(t *testing.T) { 12 | { 13 | invalidValues := []string{ 14 | `{"hello": "world"`, 15 | `{"hello": "world"}}`, 16 | `{null}`, 17 | "", 18 | "foo", 19 | " ", 20 | "{", 21 | "[", 22 | "12345", 23 | } 24 | 25 | for _, invalidValue := range invalidValues { 26 | assert.False(t, IsJSON(invalidValue), invalidValue) 27 | } 28 | } 29 | { 30 | validValues := []string{ 31 | "{}", 32 | `{"hello": "world"}`, 33 | `{ 34 | "hello": { 35 | "world": { 36 | "nested_value": true 37 | } 38 | }, 39 | "add_a_list_here": [1, 2, 3, 4], 40 | "number": 7.5, 41 | "integerNum": 7 42 | }`, 43 | "[]", 44 | "[1, 2, 3, 4]", 45 | } 46 | 47 | for _, validValue := range validValues { 48 | assert.True(t, IsJSON(validValue), validValue) 49 | } 50 | } 51 | } 52 | 53 | func BenchmarkIsJSON(b *testing.B) { 54 | values := []string{"hello world", `{"hello": "world"}`, `{"hello": "world"}}`, `{null}`, "", "foo", " ", "12345"} 55 | b.Run("OldMethod", func(b *testing.B) { 56 | for i := 0; i < b.N; i++ { 57 | for _, value := range values { 58 | oldIsJSON(value) 59 | } 60 | } 61 | }) 62 | 63 | b.Run("NewMethod", func(b *testing.B) { 64 | for i := 0; i < b.N; i++ { 65 | for 
// oldIsJSON is the previous implementation of IsJSON, kept solely as the
// baseline for BenchmarkIsJSON. It trims whitespace, requires matching
// object/array delimiters, then validates via json.Unmarshal.
func oldIsJSON(str string) bool {
	str = strings.TrimSpace(str)
	if len(str) < 2 {
		return false
	}

	runes := []rune(str)
	first, last := runes[0], runes[len(runes)-1]

	if (first == '{' && last == '}') || (first == '[' && last == ']') {
		var js json.RawMessage
		return json.Unmarshal([]byte(str), &js) == nil
	}

	return false
}
// TableData is a wrapper around *optimization.TableData which stores the actual underlying tableData.
// The wrapper here is just to have a mutex. Any of the ptr methods on *TableData will require callers to use their own locks.
// We did this because certain operations require different locking patterns
type TableData struct {
	*optimization.TableData
	// lastFlushTime is stamped with time.Now() by Wipe() and compared against
	// the cooldown in ShouldSkipFlush to throttle time-based flushes.
	lastFlushTime time.Time
	sync.Mutex
}
25 | // We want to add this in so that it can strike a balance between the Flush and Consumer go-routines on when to merge. 26 | // Say our flush interval is 5 min, and it flushed 4 min ago based on size or rows - we don't want to flush right after since the buffer would be mostly empty. 27 | func (t *TableData) ShouldSkipFlush(cooldown time.Duration) bool { 28 | if cooldown > 1*time.Minute { 29 | confidenceInterval := 0.25 30 | confidenceDuration := time.Duration(confidenceInterval * float64(cooldown)) 31 | 32 | // Subtract the confidenceDuration from the cooldown to get the adjusted cooldown 33 | cooldown = cooldown - confidenceDuration 34 | } 35 | 36 | return time.Since(t.lastFlushTime) < cooldown 37 | } 38 | 39 | func (t *TableData) Empty() bool { 40 | return t.TableData == nil 41 | } 42 | 43 | func (t *TableData) SetTableData(td *optimization.TableData) { 44 | t.TableData = td 45 | } 46 | 47 | type DatabaseData struct { 48 | tableData map[string]*TableData 49 | sync.RWMutex 50 | } 51 | 52 | func NewMemoryDB() *DatabaseData { 53 | tableData := make(map[string]*TableData) 54 | return &DatabaseData{ 55 | tableData: tableData, 56 | } 57 | } 58 | 59 | func (d *DatabaseData) GetOrCreateTableData(tableName string) *TableData { 60 | d.Lock() 61 | defer d.Unlock() 62 | 63 | table, exists := d.tableData[tableName] 64 | if !exists { 65 | table = &TableData{ 66 | Mutex: sync.Mutex{}, 67 | } 68 | d.tableData[tableName] = table 69 | } 70 | 71 | return table 72 | } 73 | 74 | func (d *DatabaseData) ClearTableConfig(tableName string) { 75 | d.Lock() 76 | defer d.Unlock() 77 | d.tableData[tableName].Wipe() 78 | } 79 | 80 | func (d *DatabaseData) TableData() map[string]*TableData { 81 | return d.tableData 82 | } 83 | -------------------------------------------------------------------------------- /models/memory_flush_test.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | 
// TestShouldSkipMerge exercises TableData.ShouldSkipFlush on both the
// short-cooldown path (cooldown <= 1 minute, used as-is) and the long-cooldown
// path (cooldown > 1 minute is shrunk by a 25% confidence interval).
func TestShouldSkipMerge(t *testing.T) {
	// 5 seconds
	coolDown := 5 * time.Second
	checkInterval := 200 * time.Millisecond

	td := TableData{
		TableData: &optimization.TableData{},
	}

	// Before wiping, we should not skip the flush since ts did not get set yet.
	assert.False(t, td.ShouldSkipFlush(coolDown))
	td.Wipe()
	// Wipe() stamps lastFlushTime = now, so for the next ~2s (10 x 200ms) we
	// remain inside the 5s cooldown and every check should skip.
	for i := 0; i < 10; i++ {
		assert.True(t, td.ShouldSkipFlush(coolDown))
		time.Sleep(checkInterval)
	}

	// ~2s of polling above plus 3s here pushes us past the 5s cooldown.
	time.Sleep(3 * time.Second)
	assert.False(t, td.ShouldSkipFlush(coolDown))

	// 5 minutes now
	coolDown = 5 * time.Minute
	now := time.Now()

	// We flushed 4 min ago, so let's test the confidence interval.
	// Cooldowns over 1 minute are reduced by 25%, giving an effective cooldown
	// of 3m45s; 4 minutes ago falls outside it, so the flush should proceed.
	td.lastFlushTime = now.Add(-4 * time.Minute)
	assert.False(t, td.ShouldSkipFlush(coolDown))

	// Let's try if we flushed 2 min ago, we should skip.
	// 2 minutes is still inside the adjusted 3m45s cooldown.
	td.lastFlushTime = now.Add(-2 * time.Minute)
	assert.True(t, td.ShouldSkipFlush(coolDown))
}
34 | td.SetTableData(&optimization.TableData{}) 35 | assert.False(t, td.Empty()) 36 | 37 | db.ClearTableConfig(tableName) 38 | assert.True(t, td.Empty()) 39 | } 40 | -------------------------------------------------------------------------------- /processes/consumer/configs.go: -------------------------------------------------------------------------------- 1 | package consumer 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "sync" 7 | 8 | "github.com/artie-labs/transfer/lib/artie" 9 | "github.com/artie-labs/transfer/lib/cdc" 10 | "github.com/artie-labs/transfer/lib/kafkalib" 11 | ) 12 | 13 | type TcFmtMap struct { 14 | tc map[string]TopicConfigFormatter 15 | sync.Mutex 16 | } 17 | 18 | func NewTcFmtMap() *TcFmtMap { 19 | return &TcFmtMap{ 20 | tc: make(map[string]TopicConfigFormatter), 21 | } 22 | } 23 | 24 | func (t *TcFmtMap) Add(topic string, fmt TopicConfigFormatter) { 25 | t.Lock() 26 | defer t.Unlock() 27 | t.tc[topic] = fmt 28 | } 29 | 30 | func (t *TcFmtMap) GetTopicFmt(topic string) (TopicConfigFormatter, bool) { 31 | t.Lock() 32 | defer t.Unlock() 33 | tcFmt, isOk := t.tc[topic] 34 | return tcFmt, isOk 35 | } 36 | 37 | type TopicConfigFormatter struct { 38 | tc kafkalib.TopicConfig 39 | cdc.Format 40 | } 41 | 42 | func commitOffset(ctx context.Context, topic string, partitionsToOffset map[string]artie.Message) error { 43 | for _, msg := range partitionsToOffset { 44 | if msg.KafkaMsg != nil { 45 | if err := topicToConsumer.Get(topic).CommitMessages(ctx, *msg.KafkaMsg); err != nil { 46 | return err 47 | } 48 | 49 | slog.Info("Successfully committed Kafka offset", slog.String("topic", topic), slog.Int("partition", msg.KafkaMsg.Partition), slog.Int64("offset", msg.KafkaMsg.Offset)) 50 | } 51 | } 52 | 53 | return nil 54 | } 55 | -------------------------------------------------------------------------------- /processes/consumer/flush_suite_test.go: -------------------------------------------------------------------------------- 1 | package consumer 2 | 3 | 
// SetupTest builds a fresh environment for each flush test: a Postgres-CDC
// Kafka topic config, a replication-mode config targeting snowflake output,
// a fake baseline destination, an empty in-memory DB, and a fake Kafka
// consumer registered under the topic key "foo".
func (f *FlushTestSuite) SetupTest() {
	tc := &kafkalib.TopicConfig{
		Database:     "db",
		Schema:       "schema",
		Topic:        "topic",
		CDCFormat:    constants.DBZPostgresFormat,
		CDCKeyFormat: kafkalib.JSONKeyFmt,
	}

	// NOTE(review): Load() presumably finalizes derived fields on the topic
	// config — confirm against kafkalib.TopicConfig.
	tc.Load()

	f.cfg = config.Config{
		Mode: config.Replication,
		Kafka: &config.Kafka{
			BootstrapServer: "foo",
			GroupID:         "bar",
			Username:        "user",
			Password:        "abc",
			TopicConfigs:    []*kafkalib.TopicConfig{tc},
		},
		Queue:                constants.Kafka,
		Output:               "snowflake",
		BufferRows:           500,
		FlushIntervalSeconds: 60,
		FlushSizeKb:          500,
	}

	f.fakeBaseline = &mocks.FakeBaseline{}
	f.baseline = f.fakeBaseline
	f.db = models.NewMemoryDB()
	f.fakeConsumer = &mocks.FakeConsumer{}
	SetKafkaConsumer(map[string]kafkalib.Consumer{"foo": f.fakeConsumer})
}
| "log/slog" 6 | "time" 7 | 8 | "github.com/artie-labs/transfer/lib/destination" 9 | "github.com/artie-labs/transfer/lib/telemetry/metrics/base" 10 | "github.com/artie-labs/transfer/lib/typing" 11 | "github.com/artie-labs/transfer/models" 12 | "github.com/artie-labs/transfer/processes/consumer" 13 | ) 14 | 15 | func StartPool(ctx context.Context, inMemDB *models.DatabaseData, dest destination.Baseline, metricsClient base.Client, td time.Duration) { 16 | slog.Info("Starting pool timer...") 17 | ticker := time.NewTicker(td) 18 | for range ticker.C { 19 | slog.Info("Flushing via pool...") 20 | if err := consumer.Flush(ctx, inMemDB, dest, metricsClient, consumer.Args{Reason: "time", CoolDown: typing.ToPtr(td)}); err != nil { 21 | slog.Error("Failed to flush via pool", slog.Any("err", err)) 22 | } 23 | } 24 | } 25 | --------------------------------------------------------------------------------